Add a tool to inject data into an elf, macho, or PE symbol

Test: symbol_inject -i a.out -o a.out2 -s symbol -v value
Change-Id: I16cd8facbae754f679bef07ab0ba23638286e1d7
diff --git a/cmd/symbol_inject/Android.bp b/cmd/symbol_inject/Android.bp
new file mode 100644
index 0000000..5f9c4a4
--- /dev/null
+++ b/cmd/symbol_inject/Android.bp
@@ -0,0 +1,23 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+blueprint_go_binary {
+    name: "symbol_inject",
+    srcs: [
+        "symbol_inject.go", // command-line driver and the copy-with-injection logic
+        "elf.go", // symbol lookup in ELF files
+        "macho.go", // symbol lookup in Mach-O files
+        "pe.go", // symbol lookup in PE files
+    ],
+}
diff --git a/cmd/symbol_inject/elf.go b/cmd/symbol_inject/elf.go
new file mode 100644
index 0000000..1741a5b
--- /dev/null
+++ b/cmd/symbol_inject/elf.go
@@ -0,0 +1,76 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"debug/elf"
+	"fmt"
+	"io"
+)
+
+// findElfSymbol parses r as an ELF file and searches its symbol table for the object
+// symbol named symbol, returning the file offset of its contents and its size. A failure
+// to parse r is wrapped in cantParseError; maxUint64 is returned for both values on error.
+func findElfSymbol(r io.ReaderAt, symbol string) (uint64, uint64, error) {
+	parsed, err := elf.NewFile(r)
+	if err != nil {
+		return maxUint64, maxUint64, cantParseError{err}
+	}
+	table, err := parsed.Symbols()
+	if err != nil {
+		return maxUint64, maxUint64, err
+	}
+
+	for _, entry := range table {
+		if entry.Name != symbol || elf.ST_TYPE(entry.Info) != elf.STT_OBJECT {
+			continue
+		}
+		offset, err := calculateElfSymbolOffset(parsed, entry)
+		if err != nil {
+			return maxUint64, maxUint64, err
+		}
+		return offset, entry.Size, nil
+	}
+
+	return maxUint64, maxUint64, fmt.Errorf("symbol not found")
+}
+
+// calculateElfSymbolOffset converts symbol's st_value into an offset from the start of
+// the ELF file, using the section that symbol.Section identifies.
+func calculateElfSymbolOffset(file *elf.File, symbol elf.Symbol) (uint64, error) {
+	if symbol.Section == elf.SHN_UNDEF || int(symbol.Section) >= len(file.Sections) {
+		return maxUint64, fmt.Errorf("invalid section index %d", symbol.Section)
+	}
+	section := file.Sections[symbol.Section]
+	switch file.Type {
+	case elf.ET_REL:
+		// In relocatable files st_value is an offset from the start of the section (whose
+		// sh_addr is normally 0), so the file position is counted from the section's sh_offset.
+		return section.Offset + symbol.Value, nil
+	case elf.ET_EXEC, elf.ET_DYN:
+		// In executables and shared objects st_value is a virtual address; rebase it from
+		// the section's load address onto the section's file offset.
+		if symbol.Value < section.Addr {
+			return maxUint64, fmt.Errorf("symbol starts before the start of its section")
+		}
+		sectionOffset := symbol.Value - section.Addr
+		if sectionOffset+symbol.Size > section.Size {
+			return maxUint64, fmt.Errorf("symbol extends past the end of its section")
+		}
+		return section.Offset + sectionOffset, nil
+	default:
+		return maxUint64, fmt.Errorf("unsupported elf file type %d", file.Type)
+	}
+}
diff --git a/cmd/symbol_inject/macho.go b/cmd/symbol_inject/macho.go
new file mode 100644
index 0000000..0945293
--- /dev/null
+++ b/cmd/symbol_inject/macho.go
@@ -0,0 +1,65 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"debug/macho"
+	"fmt"
+	"io"
+)
+
+// findMachoSymbol parses r as a Mach-O file and returns the file offset and usable size
+// of the defined symbol symbolName; a failure to parse r is wrapped in cantParseError.
+func findMachoSymbol(r io.ReaderAt, symbolName string) (uint64, uint64, error) {
+	machoFile, err := macho.NewFile(r)
+	if err != nil {
+		return maxUint64, maxUint64, cantParseError{err}
+	} else if machoFile.Symtab == nil { // no symbol table load command: nothing to search
+		return maxUint64, maxUint64, fmt.Errorf("symbol not found")
+	}
+	syms := machoFile.Symtab.Syms
+	for _, s := range syms {
+		if s.Sect == 0 || s.Name != "_"+symbolName { // C names carry a leading '_' in Mach-O
+			continue
+		}
+		var next *macho.Symbol // nearest later symbol in s's section (table order need not follow addresses)
+		for j := range syms {
+			if c := &syms[j]; c.Sect == s.Sect && c.Value > s.Value && (next == nil || c.Value < next.Value) {
+				next = c
+			}
+		}
+		return calculateMachoSymbolOffset(machoFile, s, next)
+	}
+	return maxUint64, maxUint64, fmt.Errorf("symbol not found")
+}
+
+// calculateMachoSymbolOffset returns symbol's file offset and the room available to it: the
+// bytes up to nextSymbol (or to the end of the section), less one byte that is held back.
+func calculateMachoSymbolOffset(file *macho.File, symbol macho.Symbol, nextSymbol *macho.Symbol) (uint64, uint64, error) {
+	if symbol.Sect == 0 || int(symbol.Sect) > len(file.Sections) {
+		return maxUint64, maxUint64, fmt.Errorf("invalid section index %d", symbol.Sect)
+	}
+	section := file.Sections[symbol.Sect-1] // section numbers are 1-based
+	end := section.Addr + section.Size
+	if nextSymbol != nil && nextSymbol.Sect == symbol.Sect {
+		end = nextSymbol.Value
+	}
+	if symbol.Value < section.Addr || end <= symbol.Value { // would underflow below
+		return maxUint64, maxUint64, fmt.Errorf("invalid symbol bounds [%#x, %#x)", symbol.Value, end)
+	}
+	size := end - symbol.Value - 1
+	offset := uint64(section.Offset) + (symbol.Value - section.Addr)
+	return offset, size, nil
+}
diff --git a/cmd/symbol_inject/pe.go b/cmd/symbol_inject/pe.go
new file mode 100644
index 0000000..86f6162
--- /dev/null
+++ b/cmd/symbol_inject/pe.go
@@ -0,0 +1,67 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"debug/pe"
+	"fmt"
+	"io"
+	"sort"
+)
+
+// findPESymbol parses r as a PE file and returns the file offset and usable size of the
+// section-defined symbol symbolName; a failure to parse r is wrapped in cantParseError.
+func findPESymbol(r io.ReaderAt, symbolName string) (uint64, uint64, error) {
+	peFile, err := pe.NewFile(r)
+	if err != nil {
+		return maxUint64, maxUint64, cantParseError{err}
+	}
+	// Order by section, then by offset, so that a symbol's successor marks where it ends.
+	syms := peFile.Symbols
+	sort.Slice(syms, func(i, j int) bool {
+		a, b := syms[i], syms[j]
+		return a.SectionNumber < b.SectionNumber || (a.SectionNumber == b.SectionNumber && a.Value < b.Value)
+	})
+	for i, symbol := range syms {
+		if symbol.SectionNumber <= 0 || symbol.Name != symbolName { // skip symbols not defined in a section
+			continue
+		}
+		var nextSymbol *pe.Symbol
+		if i+1 < len(syms) {
+			nextSymbol = syms[i+1]
+		}
+		return calculatePESymbolOffset(peFile, symbol, nextSymbol)
+	}
+	return maxUint64, maxUint64, fmt.Errorf("symbol not found")
+}
+
+// calculatePESymbolOffset returns symbol's file offset and the room available to it: the
+// bytes up to nextSymbol (or up to section.Size), less one byte that is held back.
+func calculatePESymbolOffset(file *pe.File, symbol *pe.Symbol, nextSymbol *pe.Symbol) (uint64, uint64, error) {
+	if symbol.SectionNumber <= 0 || int(symbol.SectionNumber) > len(file.Sections) {
+		return maxUint64, maxUint64, fmt.Errorf("invalid section index %d", symbol.SectionNumber)
+	}
+	section := file.Sections[symbol.SectionNumber-1] // section numbers are 1-based
+	end := section.Size
+	if nextSymbol != nil && nextSymbol.SectionNumber == symbol.SectionNumber {
+		end = nextSymbol.Value
+	}
+	if end <= symbol.Value { // the subtraction below would wrap around
+		return maxUint64, maxUint64, fmt.Errorf("invalid symbol bounds [%#x, %#x)", symbol.Value, end)
+	}
+	size := end - symbol.Value - 1
+	offset := uint64(section.Offset) + uint64(symbol.Value) // widen first so the sum cannot wrap
+	return offset, uint64(size), nil
+}
diff --git a/cmd/symbol_inject/symbol_inject.go b/cmd/symbol_inject/symbol_inject.go
new file mode 100644
index 0000000..75f8a1a
--- /dev/null
+++ b/cmd/symbol_inject/symbol_inject.go
@@ -0,0 +1,176 @@
+// Copyright 2018 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"bytes"
+	"flag"
+	"fmt"
+	"io"
+	"math"
+	"os"
+)
+
+var (
+	input  = flag.String("i", "", "input file")
+	output = flag.String("o", "", "output file")
+	symbol = flag.String("s", "", "symbol to inject into")
+	from   = flag.String("from", "", "optional existing value of the symbol for verification")
+	value  = flag.String("v", "", "value to inject into symbol")
+
+	// maxUint64 is the placeholder offset/size returned together with an error.
+	maxUint64 uint64 = math.MaxUint64
+)
+
+// cantParseError wraps a parser's error to signal that the input is not in that parser's format.
+type cantParseError struct{ error }
+
+// main validates the flags, then copies the input file to the output file with the value
+// injected. Exit status: 1 usage error, 3 output not creatable, 2 any other failure.
+func main() {
+	flag.Parse()
+	usageError := func(s string) {
+		fmt.Fprintln(os.Stderr, s)
+		flag.Usage()
+		os.Exit(1)
+	}
+
+	switch {
+	case *input == "":
+		usageError("-i is required")
+	case *output == "":
+		usageError("-o is required")
+	case *symbol == "":
+		usageError("-s is required")
+	case *value == "":
+		usageError("-v is required")
+	}
+
+	r, err := os.Open(*input)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err.Error())
+		os.Exit(2)
+	}
+	defer r.Close()
+
+	w, err := os.OpenFile(*output, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err.Error())
+		os.Exit(3)
+	}
+
+	err = injectSymbol(r, w, *symbol, *value, *from)
+	// Close explicitly rather than with defer: os.Exit skips deferred calls, a failed close
+	// can mean lost writes, and Windows cannot remove a file that is still open.
+	if closeErr := w.Close(); err == nil {
+		err = closeErr
+	}
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err.Error())
+		os.Remove(*output)
+		os.Exit(2)
+	}
+}
+
+// ReadSeekerAt combines random-access reads (io.ReaderAt) with seekable sequential reads.
+type ReadSeekerAt interface {
+	io.ReaderAt
+	io.ReadSeeker
+}
+
+// injectSymbol locates symbol in r, trying ELF, then Mach-O, then PE, and copies r to w with
+// the symbol's contents replaced by value. When from is non-empty, the symbol's current
+// contents must equal from padded with zero bytes to the symbol's size, else an error results.
+func injectSymbol(r ReadSeekerAt, w io.Writer, symbol, value, from string) error {
+	offset, size, err := findElfSymbol(r, symbol)
+	if elfError, ok := err.(cantParseError); ok {
+		// Try as a mach-o file
+		offset, size, err = findMachoSymbol(r, symbol)
+		if _, ok := err.(cantParseError); ok {
+			// Try as a windows PE file
+			offset, size, err = findPESymbol(r, symbol)
+			if _, ok := err.(cantParseError); ok {
+				// Can't parse as elf, macho, or PE, return the elf error
+				return elfError
+			}
+		}
+	}
+	if err != nil {
+		return err
+	}
+
+	if uint64(len(value))+1 > size {
+		return fmt.Errorf("value length %d overflows symbol size %d", len(value), size)
+	}
+
+	if from != "" {
+		// Read the existing symbol contents and verify they match the expected value
+		expected := make([]byte, size)
+		existing := make([]byte, size)
+		copy(expected, from)
+		if _, err := r.ReadAt(existing, int64(offset)); err != nil {
+			return err
+		}
+		// A from longer than the symbol can never match; the copy above silently truncated it.
+		if uint64(len(from)) > size || !bytes.Equal(existing, expected) {
+			return fmt.Errorf("existing symbol contents %q did not match expected value %q", string(existing), from)
+		}
+	}
+
+	return copyAndInject(r, w, offset, size, value)
+}
+
+// copyAndInject copies r to w, replacing the size bytes at offset with value followed by
+// zero padding up to size bytes. It fails with io.ErrUnexpectedEOF when the input is too
+// short to contain the whole symbol.
+func copyAndInject(r io.ReadSeeker, w io.Writer, offset, size uint64, value string) (err error) {
+	// An EOF while copying means the input ended early; report it as unexpected.
+	defer func() {
+		if err == io.EOF {
+			err = io.ErrUnexpectedEOF
+		}
+	}()
+
+	// Check that the symbol lies inside the input before allocating size bytes for it.
+	length, err := r.Seek(0, io.SeekEnd)
+	if err != nil {
+		return err
+	}
+	if offset > uint64(length) || size > uint64(length)-offset {
+		return io.ErrUnexpectedEOF
+	}
+	buf := make([]byte, size)
+	copy(buf, value)
+
+	// Reset the input file
+	if _, err = r.Seek(0, io.SeekStart); err != nil {
+		return err
+	}
+	// Copy the first bytes up to the symbol offset
+	if _, err = io.CopyN(w, r, int64(offset)); err != nil {
+		return err
+	}
+	// Skip the symbol contents in the input file
+	if _, err = r.Seek(int64(size), io.SeekCurrent); err != nil {
+		return err
+	}
+	// Write the injected value in the output file
+	if _, err = w.Write(buf); err != nil {
+		return err
+	}
+	// Write the remainder of the file
+	_, err = io.Copy(w, r)
+	return err
+}