Extract primary apk from apk set zip

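Extract the primary apk and install it normally, then unzip the
remaining apks as a post-install command. extract_apks now writes the
primary apk to the -o file and the other selected apks to the zip file
given by the new -zip flag.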

Bug: 204136549
Test: app_set_test.go
Change-Id: I17437ff27f49df6bc91bdbbea6173b46c7d3ec4e
diff --git a/cmd/extract_apks/main.go b/cmd/extract_apks/main.go
index 6e51a28..1cf64de 100644
--- a/cmd/extract_apks/main.go
+++ b/cmd/extract_apks/main.go
@@ -356,7 +356,7 @@
 
-// Writes out selected entries, renaming them as needed
+// Writes out selected entries, renaming them as needed. STEM.apk is written to outFile, all others to zipWriter.
 func (apkSet *ApkSet) writeApks(selected SelectionResult, config TargetConfig,
-	writer Zip2ZipWriter, partition string) ([]string, error) {
+	outFile io.Writer, zipWriter Zip2ZipWriter, partition string) ([]string, error) {
 	// Renaming rules:
 	//  splits/MODULE-master.apk to STEM.apk
 	// else
@@ -406,8 +406,14 @@
 				origin, inName, outName)
 		}
 		entryOrigin[outName] = inName
-		if err := writer.CopyFrom(apkFile, outName); err != nil {
-			return nil, err
+		if outName == config.stem+".apk" {
+			if err := writeZipEntryToFile(outFile, apkFile); err != nil {
+				return nil, err
+			}
+		} else {
+			if err := zipWriter.CopyFrom(apkFile, outName); err != nil {
+				return nil, err
+			}
 		}
 		if partition != "" {
 			apkcerts = append(apkcerts, fmt.Sprintf(
@@ -426,14 +432,13 @@
 	if !ok {
 		return fmt.Errorf("Couldn't find apk path %s", selected.entries[0])
 	}
-	inputReader, _ := apk.Open()
-	_, err := io.Copy(outFile, inputReader)
-	return err
+	return writeZipEntryToFile(outFile, apk)
 }
 
 // Arguments parsing
 var (
-	outputFile   = flag.String("o", "", "output file containing extracted entries")
+	outputFile   = flag.String("o", "", "output file for primary entry")
+	zipFile      = flag.String("zip", "", "output file containing additional extracted entries")
 	targetConfig = TargetConfig{
 		screenDpi: map[android_bundle_proto.ScreenDensity_DensityAlias]bool{},
 		abis:      map[android_bundle_proto.Abi_AbiAlias]int{},
@@ -494,7 +499,8 @@
 
 func processArgs() {
 	flag.Usage = func() {
-		fmt.Fprintln(os.Stderr, `usage: extract_apks -o <output-file> -sdk-version value -abis value `+
+		fmt.Fprintln(os.Stderr, `usage: extract_apks -o <output-file> [-zip <output-zip-file>] `+
+			`-sdk-version value -abis value `+
 			`-screen-densities value {-stem value | -extract-single} [-allow-prereleased] `+
 			`[-apkcerts <apkcerts output file> -partition <partition>] <APK set>`)
 		flag.PrintDefaults()
@@ -510,7 +516,8 @@
 	flag.StringVar(&targetConfig.stem, "stem", "", "output entries base name in the output zip file")
 	flag.Parse()
 	if (*outputFile == "") || len(flag.Args()) != 1 || *version == 0 ||
-		(targetConfig.stem == "" && !*extractSingle) || (*apkcertsOutput != "" && *partition == "") {
+		((targetConfig.stem == "" || *zipFile == "") && !*extractSingle) ||
+		(*apkcertsOutput != "" && *partition == "") {
 		flag.Usage()
 	}
 	targetConfig.sdkVersion = int32(*version)
@@ -542,13 +549,20 @@
 	if *extractSingle {
 		err = apkSet.extractAndCopySingle(sel, outFile)
 	} else {
-		writer := zip.NewWriter(outFile)
+		zipOutputFile, createErr := os.Create(*zipFile)
+		if createErr != nil {
+			log.Fatal(createErr)
+		}
+		defer zipOutputFile.Close()
+
+		zipWriter := zip.NewWriter(zipOutputFile)
 		defer func() {
-			if err := writer.Close(); err != nil {
+			if err := zipWriter.Close(); err != nil {
 				log.Fatal(err)
 			}
 		}()
-		apkcerts, err := apkSet.writeApks(sel, targetConfig, writer, *partition)
+		var apkcerts []string
+		apkcerts, err = apkSet.writeApks(sel, targetConfig, outFile, zipWriter, *partition)
 		if err == nil && *apkcertsOutput != "" {
 			apkcertsFile, err := os.Create(*apkcertsOutput)
 			if err != nil {
@@ -567,3 +581,17 @@
 		log.Fatal(err)
 	}
 }
+
+// writeZipEntryToFile streams the decompressed contents of zipEntry into outFile.
+// Errors are annotated with the entry name so a failure points at the offending apk.
+func writeZipEntryToFile(outFile io.Writer, zipEntry *zip.File) error {
+	reader, err := zipEntry.Open()
+	if err != nil {
+		return fmt.Errorf("opening %s: %w", zipEntry.Name, err)
+	}
+	defer reader.Close()
+	if _, err := io.Copy(outFile, reader); err != nil {
+		return fmt.Errorf("extracting %s: %w", zipEntry.Name, err)
+	}
+	return nil
+}
diff --git a/cmd/extract_apks/main_test.go b/cmd/extract_apks/main_test.go
index 9fcf324..f5e4046 100644
--- a/cmd/extract_apks/main_test.go
+++ b/cmd/extract_apks/main_test.go
@@ -15,6 +15,7 @@
 package main
 
 import (
+	"bytes"
 	"fmt"
 	"reflect"
 	"testing"
@@ -437,8 +438,8 @@
 	stem       string
 	partition  string
 	// what we write from what
-	expectedZipEntries map[string]string
-	expectedApkcerts   []string
+	zipEntries       map[string]string
+	expectedApkcerts []string
 }
 
 func TestWriteApks(t *testing.T) {
@@ -448,7 +449,7 @@
 			moduleName: "mybase",
 			stem:       "Foo",
 			partition:  "system",
-			expectedZipEntries: map[string]string{
+			zipEntries: map[string]string{
 				"Foo.apk":       "splits/mybase-master.apk",
 				"Foo-xhdpi.apk": "splits/mybase-xhdpi.apk",
 			},
@@ -462,7 +463,7 @@
 			moduleName: "base",
 			stem:       "Bar",
 			partition:  "product",
-			expectedZipEntries: map[string]string{
+			zipEntries: map[string]string{
 				"Bar.apk": "universal.apk",
 			},
 			expectedApkcerts: []string{
@@ -471,23 +472,46 @@
 		},
 	}
 	for _, testCase := range testCases {
-		apkSet := ApkSet{entries: make(map[string]*zip.File)}
-		sel := SelectionResult{moduleName: testCase.moduleName}
-		for _, in := range testCase.expectedZipEntries {
-			apkSet.entries[in] = &zip.File{FileHeader: zip.FileHeader{Name: in}}
-			sel.entries = append(sel.entries, in)
-		}
-		writer := testZip2ZipWriter{make(map[string]string)}
-		config := TargetConfig{stem: testCase.stem}
-		apkcerts, err := apkSet.writeApks(sel, config, writer, testCase.partition)
-		if err != nil {
-			t.Error(err)
-		}
-		if !reflect.DeepEqual(testCase.expectedZipEntries, writer.entries) {
-			t.Errorf("expected zip entries %v, got %v", testCase.expectedZipEntries, writer.entries)
-		}
-		if !reflect.DeepEqual(testCase.expectedApkcerts, apkcerts) {
-			t.Errorf("expected apkcerts %v, got %v", testCase.expectedApkcerts, apkcerts)
-		}
+		t.Run(testCase.name, func(t *testing.T) {
+			testZipBuf := &bytes.Buffer{}
+			testZip := zip.NewWriter(testZipBuf)
+			for _, in := range testCase.zipEntries {
+				f, _ := testZip.Create(in)
+				f.Write([]byte(in))
+			}
+			testZip.Close()
+
+			zipReader, _ := zip.NewReader(bytes.NewReader(testZipBuf.Bytes()), int64(testZipBuf.Len()))
+
+			apkSet := ApkSet{entries: make(map[string]*zip.File)}
+			sel := SelectionResult{moduleName: testCase.moduleName}
+			for _, f := range zipReader.File {
+				apkSet.entries[f.Name] = f
+				sel.entries = append(sel.entries, f.Name)
+			}
+
+			zipWriter := testZip2ZipWriter{make(map[string]string)}
+			outWriter := &bytes.Buffer{}
+			config := TargetConfig{stem: testCase.stem}
+			apkcerts, err := apkSet.writeApks(sel, config, outWriter, zipWriter, testCase.partition)
+			if err != nil {
+				t.Fatal(err)
+			}
+			expectedZipEntries := make(map[string]string)
+			for k, v := range testCase.zipEntries {
+				if k != testCase.stem+".apk" {
+					expectedZipEntries[k] = v
+				}
+			}
+			if !reflect.DeepEqual(expectedZipEntries, zipWriter.entries) {
+				t.Errorf("expected zip entries %v, got %v", expectedZipEntries, zipWriter.entries)
+			}
+			if !reflect.DeepEqual(testCase.expectedApkcerts, apkcerts) {
+				t.Errorf("expected apkcerts %v, got %v", testCase.expectedApkcerts, apkcerts)
+			}
+			if g, w := outWriter.String(), testCase.zipEntries[testCase.stem+".apk"]; g != w {
+				t.Errorf("expected output file contents %q, got %q", w, g)
+			}
+		})
 	}
 }