Colin Cross | 5498f85 | 2018-01-03 23:39:54 -0800 | [diff] [blame^] | 1 | // Copyright 2018 Google Inc. All rights reserved. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | package main |
| 16 | |
| 17 | import ( |
| 18 | "bytes" |
| 19 | "flag" |
| 20 | "fmt" |
| 21 | "io" |
| 22 | "math" |
| 23 | "os" |
| 24 | ) |
| 25 | |
// Command-line flags. main rejects an empty -i, -o, -s or -v; -from may be
// left empty to skip verification.
var (
	// input is the path of the file whose symbol is rewritten.
	input = flag.String("i", "", "input file")
	// output is the path the modified copy is written to.
	output = flag.String("o", "", "output file")
	// symbol names the symbol whose contents are replaced.
	symbol = flag.String("s", "", "symbol to inject into")
	// value is the string written into the symbol.
	value = flag.String("v", "", "value to inject into symbol")

	// from, when non-empty, must match the symbol's current contents before
	// they are replaced.
	from = flag.String("from", "", "optional existing value of the symbol for verification")
)
| 33 | |
// maxUint64 holds math.MaxUint64 as a uint64 variable (all 64 bits set).
var maxUint64 = uint64(math.MaxUint64)
| 35 | |
// cantParseError wraps the error from a symbol lookup that could not parse
// the input in the object format it tried. injectSymbol checks for this type
// to decide whether to fall back to the next format.
type cantParseError struct{ error }
| 39 | |
| 40 | func main() { |
| 41 | flag.Parse() |
| 42 | |
| 43 | usageError := func(s string) { |
| 44 | fmt.Fprintln(os.Stderr, s) |
| 45 | flag.Usage() |
| 46 | os.Exit(1) |
| 47 | } |
| 48 | |
| 49 | if *input == "" { |
| 50 | usageError("-i is required") |
| 51 | } |
| 52 | |
| 53 | if *output == "" { |
| 54 | usageError("-o is required") |
| 55 | } |
| 56 | |
| 57 | if *symbol == "" { |
| 58 | usageError("-s is required") |
| 59 | } |
| 60 | |
| 61 | if *value == "" { |
| 62 | usageError("-v is required") |
| 63 | } |
| 64 | |
| 65 | r, err := os.Open(*input) |
| 66 | if err != nil { |
| 67 | fmt.Fprintln(os.Stderr, err.Error()) |
| 68 | os.Exit(2) |
| 69 | } |
| 70 | defer r.Close() |
| 71 | |
| 72 | w, err := os.OpenFile(*output, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777) |
| 73 | if err != nil { |
| 74 | fmt.Fprintln(os.Stderr, err.Error()) |
| 75 | os.Exit(3) |
| 76 | } |
| 77 | defer w.Close() |
| 78 | |
| 79 | err = injectSymbol(r, w, *symbol, *value, *from) |
| 80 | if err != nil { |
| 81 | fmt.Fprintln(os.Stderr, err.Error()) |
| 82 | os.Remove(*output) |
| 83 | os.Exit(2) |
| 84 | } |
| 85 | } |
| 86 | |
// ReadSeekerAt is the input that injectSymbol works on. It combines
// sequential reading with seeking (io.ReadSeeker) and random-access reads
// (io.ReaderAt).
type ReadSeekerAt interface {
	io.ReadSeeker
	io.ReaderAt
}
| 91 | |
| 92 | func injectSymbol(r ReadSeekerAt, w io.Writer, symbol, value, from string) error { |
| 93 | var offset, size uint64 |
| 94 | var err error |
| 95 | |
| 96 | offset, size, err = findElfSymbol(r, symbol) |
| 97 | if elfError, ok := err.(cantParseError); ok { |
| 98 | // Try as a mach-o file |
| 99 | offset, size, err = findMachoSymbol(r, symbol) |
| 100 | if _, ok := err.(cantParseError); ok { |
| 101 | // Try as a windows PE file |
| 102 | offset, size, err = findPESymbol(r, symbol) |
| 103 | if _, ok := err.(cantParseError); ok { |
| 104 | // Can't parse as elf, macho, or PE, return the elf error |
| 105 | return elfError |
| 106 | } |
| 107 | } |
| 108 | } |
| 109 | if err != nil { |
| 110 | return err |
| 111 | } |
| 112 | |
| 113 | if uint64(len(value))+1 > size { |
| 114 | return fmt.Errorf("value length %d overflows symbol size %d", len(value), size) |
| 115 | } |
| 116 | |
| 117 | if from != "" { |
| 118 | // Read the exsting symbol contents and verify they match the expected value |
| 119 | expected := make([]byte, size) |
| 120 | existing := make([]byte, size) |
| 121 | copy(expected, from) |
| 122 | _, err := r.ReadAt(existing, int64(offset)) |
| 123 | if err != nil { |
| 124 | return err |
| 125 | } |
| 126 | if bytes.Compare(existing, expected) != 0 { |
| 127 | return fmt.Errorf("existing symbol contents %q did not match expected value %q", |
| 128 | string(existing), string(expected)) |
| 129 | } |
| 130 | } |
| 131 | |
| 132 | return copyAndInject(r, w, offset, size, value) |
| 133 | } |
| 134 | |
// copyAndInject copies all of r to w, replacing the size bytes starting at
// offset with value. value is zero-padded to size bytes, or truncated if it is
// longer. r is rewound first, so its current position does not matter.
//
// It returns io.ErrUnexpectedEOF if the input is too short to contain the
// whole range [offset, offset+size). Any other read, seek or write error is
// returned unchanged.
func copyAndInject(r io.ReadSeeker, w io.Writer, offset, size uint64, value string) (err error) {
	// Measure the input and bounds-check the symbol in uint64 arithmetic.
	// Seeking past EOF is not an error, so without this check a symbol that
	// runs off the end would silently produce a longer output. The check also
	// guarantees the int64 conversions below cannot overflow to negative
	// values.
	end, err := r.Seek(0, io.SeekEnd)
	if err != nil {
		return err
	}
	if offset > uint64(end) || size > uint64(end)-offset {
		return io.ErrUnexpectedEOF
	}

	buf := make([]byte, size)
	copy(buf, value)

	// Rewind the input.
	if _, err := r.Seek(0, io.SeekStart); err != nil {
		return err
	}
	// Copy the bytes that precede the symbol.
	if _, err := io.CopyN(w, r, int64(offset)); err != nil {
		if err == io.EOF {
			// Input ended early; report it the same way as the bounds check.
			return io.ErrUnexpectedEOF
		}
		return err
	}
	// Skip the symbol's original contents in the input.
	if _, err := r.Seek(int64(size), io.SeekCurrent); err != nil {
		return err
	}
	// Write the injected value in their place.
	if _, err := w.Write(buf); err != nil {
		return err
	}
	// Copy the remainder of the input.
	if _, err := io.Copy(w, r); err != nil {
		return err
	}
	return nil
}