blob: 75f8a1ac08d70d5c1431dc9f9066d444ca61fc85 [file] [log] [blame]
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package main
16
17import (
18 "bytes"
19 "flag"
20 "fmt"
21 "io"
22 "math"
23 "os"
24)
25
// Command-line flags. -i, -o, -s and -v must all be non-empty; -from may be
// left empty to skip verification of the symbol's current contents.
var (
	// input is the path of the binary to read.
	input = flag.String("i", "", "input file")
	// output is the path the modified binary is written to.
	output = flag.String("o", "", "output file")
	// symbol names the symbol whose contents are replaced.
	symbol = flag.String("s", "", "symbol to inject into")
	// from, when set, is the value the symbol must currently hold.
	from = flag.String("from", "", "optional existing value of the symbol for verification")
	// value is the string written into the symbol.
	value = flag.String("v", "", "value to inject into symbol")
)

// maxUint64 holds math.MaxUint64 as a typed uint64 variable.
var maxUint64 = uint64(math.MaxUint64)
35
// cantParseError wraps the error a format-specific symbol finder returns when
// the input is not in that format. injectSymbol treats it as "try the next
// format" rather than as a hard failure.
type cantParseError struct{ error }
39
40func main() {
41 flag.Parse()
42
43 usageError := func(s string) {
44 fmt.Fprintln(os.Stderr, s)
45 flag.Usage()
46 os.Exit(1)
47 }
48
49 if *input == "" {
50 usageError("-i is required")
51 }
52
53 if *output == "" {
54 usageError("-o is required")
55 }
56
57 if *symbol == "" {
58 usageError("-s is required")
59 }
60
61 if *value == "" {
62 usageError("-v is required")
63 }
64
65 r, err := os.Open(*input)
66 if err != nil {
67 fmt.Fprintln(os.Stderr, err.Error())
68 os.Exit(2)
69 }
70 defer r.Close()
71
72 w, err := os.OpenFile(*output, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
73 if err != nil {
74 fmt.Fprintln(os.Stderr, err.Error())
75 os.Exit(3)
76 }
77 defer w.Close()
78
79 err = injectSymbol(r, w, *symbol, *value, *from)
80 if err != nil {
81 fmt.Fprintln(os.Stderr, err.Error())
82 os.Remove(*output)
83 os.Exit(2)
84 }
85}
86
// ReadSeekerAt is an input that supports both positional reads (io.ReaderAt)
// and sequential reads with seeking (io.ReadSeeker). injectSymbol uses ReadAt
// to check a symbol's existing contents, and copyAndInject seeks and reads
// sequentially while copying.
type ReadSeekerAt interface {
	io.ReadSeeker
	io.ReaderAt
}
91
92func injectSymbol(r ReadSeekerAt, w io.Writer, symbol, value, from string) error {
93 var offset, size uint64
94 var err error
95
96 offset, size, err = findElfSymbol(r, symbol)
97 if elfError, ok := err.(cantParseError); ok {
98 // Try as a mach-o file
99 offset, size, err = findMachoSymbol(r, symbol)
100 if _, ok := err.(cantParseError); ok {
101 // Try as a windows PE file
102 offset, size, err = findPESymbol(r, symbol)
103 if _, ok := err.(cantParseError); ok {
104 // Can't parse as elf, macho, or PE, return the elf error
105 return elfError
106 }
107 }
108 }
109 if err != nil {
110 return err
111 }
112
113 if uint64(len(value))+1 > size {
114 return fmt.Errorf("value length %d overflows symbol size %d", len(value), size)
115 }
116
117 if from != "" {
118 // Read the exsting symbol contents and verify they match the expected value
119 expected := make([]byte, size)
120 existing := make([]byte, size)
121 copy(expected, from)
122 _, err := r.ReadAt(existing, int64(offset))
123 if err != nil {
124 return err
125 }
126 if bytes.Compare(existing, expected) != 0 {
127 return fmt.Errorf("existing symbol contents %q did not match expected value %q",
128 string(existing), string(expected))
129 }
130 }
131
132 return copyAndInject(r, w, offset, size, value)
133}
134
// copyAndInject streams all of r to w, except that the size bytes starting at
// offset are replaced by value followed by zero bytes up to exactly size bytes
// (value is truncated if longer than size). r is rewound to the start first.
//
// Running out of input anywhere is reported as io.ErrUnexpectedEOF, because
// the input is then shorter than offset and size imply. Errors are returned
// directly rather than via panic/recover. Recovering error-typed panics would
// also swallow runtime.Error panics (for example a nil dereference inside a
// Writer) and hide real bugs.
func copyAndInject(r io.ReadSeeker, w io.Writer, offset, size uint64, value string) error {
	// fail maps io.EOF to io.ErrUnexpectedEOF: hitting the end of the input
	// before the copy is complete means the input was truncated.
	fail := func(err error) error {
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}

	// The replacement is always exactly size bytes; make zero-fills the tail.
	buf := make([]byte, size)
	copy(buf, value)

	// Reset the input file; earlier parsing may have moved its position.
	if _, err := r.Seek(0, io.SeekStart); err != nil {
		return fail(err)
	}
	// Copy the bytes that precede the symbol verbatim.
	if _, err := io.CopyN(w, r, int64(offset)); err != nil {
		return fail(err)
	}
	// Skip over the symbol's old contents in the input.
	if _, err := r.Seek(int64(size), io.SeekCurrent); err != nil {
		return fail(err)
	}
	// Write the injected value in their place.
	if _, err := w.Write(buf); err != nil {
		return fail(err)
	}
	// Copy the remainder of the file.
	if _, err := io.Copy(w, r); err != nil {
		return fail(err)
	}
	return nil
}
176}