Refactor code for partitions c srcs

To support protos (and other srcs that generate sources), we need to
partition further. Separate out into a separate common function.

Bug: 200601772
Test: build/bazel/ci/bp2build.sh
Change-Id: I7bf4cd96fd9a9fca4ccb3c96f21a04303201f891
diff --git a/bazel/Android.bp b/bazel/Android.bp
index b68d65b..80af2bd 100644
--- a/bazel/Android.bp
+++ b/bazel/Android.bp
@@ -14,6 +14,7 @@
     testSrcs: [
         "aquery_test.go",
         "properties_test.go",
+        "testing.go",
     ],
     pluginFor: [
         "soong_build",
diff --git a/bazel/properties.go b/bazel/properties.go
index d3b40a2..bd8ef0d 100644
--- a/bazel/properties.go
+++ b/bazel/properties.go
@@ -19,6 +19,9 @@
 	"path/filepath"
 	"regexp"
 	"sort"
+	"strings"
+
+	"github.com/google/blueprint"
 )
 
 // BazelTargetModuleProperties contain properties and metadata used for
@@ -182,76 +185,6 @@
 	return strings
 }
 
-// Map a function over all labels in a LabelList.
-func MapLabelList(mapOver LabelList, mapFn func(string) string) LabelList {
-	var includes []Label
-	for _, inc := range mapOver.Includes {
-		mappedLabel := Label{Label: mapFn(inc.Label), OriginalModuleName: inc.OriginalModuleName}
-		includes = append(includes, mappedLabel)
-	}
-	// mapFn is not applied over excludes, but they are propagated as-is.
-	return LabelList{Includes: includes, Excludes: mapOver.Excludes}
-}
-
-// Map a function over all Labels in a LabelListAttribute
-func MapLabelListAttribute(mapOver LabelListAttribute, mapFn func(string) string) LabelListAttribute {
-	var result LabelListAttribute
-
-	result.Value = MapLabelList(mapOver.Value, mapFn)
-
-	for axis, configToLabels := range mapOver.ConfigurableValues {
-		for config, value := range configToLabels {
-			result.SetSelectValue(axis, config, MapLabelList(value, mapFn))
-		}
-	}
-
-	return result
-}
-
-// Return all needles in a given haystack, where needleFn is true for needles.
-func FilterLabelList(haystack LabelList, needleFn func(string) bool) LabelList {
-	var includes []Label
-	for _, inc := range haystack.Includes {
-		if needleFn(inc.Label) {
-			includes = append(includes, inc)
-		}
-	}
-	// needleFn is not applied over excludes, but they are propagated as-is.
-	return LabelList{Includes: includes, Excludes: haystack.Excludes}
-}
-
-// Return all needles in a given haystack, where needleFn is true for needles.
-func FilterLabelListAttribute(haystack LabelListAttribute, needleFn func(string) bool) LabelListAttribute {
-	result := MakeLabelListAttribute(FilterLabelList(haystack.Value, needleFn))
-
-	for config, selects := range haystack.ConfigurableValues {
-		newSelects := make(labelListSelectValues, len(selects))
-		for k, v := range selects {
-			newSelects[k] = FilterLabelList(v, needleFn)
-		}
-		result.ConfigurableValues[config] = newSelects
-	}
-
-	return result
-}
-
-// Subtract needle from haystack
-func SubtractBazelLabelListAttribute(haystack LabelListAttribute, needle LabelListAttribute) LabelListAttribute {
-	result := MakeLabelListAttribute(SubtractBazelLabelList(haystack.Value, needle.Value))
-
-	for config, selects := range haystack.ConfigurableValues {
-		newSelects := make(labelListSelectValues, len(selects))
-		needleSelects := needle.ConfigurableValues[config]
-
-		for k, v := range selects {
-			newSelects[k] = SubtractBazelLabelList(v, needleSelects[k])
-		}
-		result.ConfigurableValues[config] = newSelects
-	}
-
-	return result
-}
-
 // Subtract needle from haystack
 func SubtractBazelLabels(haystack []Label, needle []Label) []Label {
 	// This is really a set
@@ -624,6 +557,144 @@
 	}
 }
 
+// OtherModuleContext is a limited context that has methods with information about other modules.
+type OtherModuleContext interface {
+	ModuleFromName(name string) (blueprint.Module, bool)
+	OtherModuleType(m blueprint.Module) string
+	OtherModuleName(m blueprint.Module) string
+	OtherModuleDir(m blueprint.Module) string
+	ModuleErrorf(fmt string, args ...interface{})
+}
+
+// LabelMapper is a function that takes a OtherModuleContext and returns a (potentially changed)
+// label and whether it was changed.
+type LabelMapper func(OtherModuleContext, string) (string, bool)
+
+// LabelPartition contains descriptions of a partition for labels
+type LabelPartition struct {
+	// Extensions to include in this partition
+	Extensions []string
+	// LabelMapper is a function that can map a label to a new label, and indicate whether to include
+	// the mapped label in the partition
+	LabelMapper LabelMapper
+	// Whether to store files not included in any other partition in a group of LabelPartitions
+	// Only one partition in a group of LabelPartitions can enabled Keep_remainder
+	Keep_remainder bool
+}
+
+// LabelPartitions is a map of partition name to a LabelPartition describing the elements of the
+// partition
+type LabelPartitions map[string]LabelPartition
+
+// filter returns a pointer to a label if the label should be included in the partition or nil if
+// not.
+func (lf LabelPartition) filter(ctx OtherModuleContext, label Label) *Label {
+	if lf.LabelMapper != nil {
+		if newLabel, changed := lf.LabelMapper(ctx, label.Label); changed {
+			return &Label{newLabel, label.OriginalModuleName}
+		}
+	}
+	for _, ext := range lf.Extensions {
+		if strings.HasSuffix(label.Label, ext) {
+			return &label
+		}
+	}
+
+	return nil
+}
+
+// PartitionToLabelListAttribute is map of partition name to a LabelListAttribute
+type PartitionToLabelListAttribute map[string]LabelListAttribute
+
+type partitionToLabelList map[string]*LabelList
+
+func (p partitionToLabelList) appendIncludes(partition string, label Label) {
+	if _, ok := p[partition]; !ok {
+		p[partition] = &LabelList{}
+	}
+	p[partition].Includes = append(p[partition].Includes, label)
+}
+
+func (p partitionToLabelList) excludes(partition string, excludes []Label) {
+	if _, ok := p[partition]; !ok {
+		p[partition] = &LabelList{}
+	}
+	p[partition].Excludes = excludes
+}
+
+// PartitionLabelListAttribute partitions a LabelListAttribute into the requested partitions
+func PartitionLabelListAttribute(ctx OtherModuleContext, lla *LabelListAttribute, partitions LabelPartitions) PartitionToLabelListAttribute {
+	ret := PartitionToLabelListAttribute{}
+	var partitionNames []string
+	// Stored as a pointer to distinguish nil (no remainder partition) from empty string partition
+	var remainderPartition *string
+	for p, f := range partitions {
+		partitionNames = append(partitionNames, p)
+		if f.Keep_remainder {
+			if remainderPartition != nil {
+				panic("only one partition can store the remainder")
+			}
+			// If we take the address of p in a loop, we'll end up with the last value of p in
+			// remainderPartition, we want the requested partition
+			capturePartition := p
+			remainderPartition = &capturePartition
+		}
+	}
+
+	partitionLabelList := func(axis ConfigurationAxis, config string) {
+		value := lla.SelectValue(axis, config)
+		partitionToLabels := partitionToLabelList{}
+		for _, item := range value.Includes {
+			wasFiltered := false
+			var inPartition *string
+			for partition, f := range partitions {
+				filtered := f.filter(ctx, item)
+				if filtered == nil {
+					// did not match this filter, keep looking
+					continue
+				}
+				wasFiltered = true
+				partitionToLabels.appendIncludes(partition, *filtered)
+				// don't need to check other partitions if this filter used the item,
+				// continue checking if mapped to another name
+				if *filtered == item {
+					if inPartition != nil {
+						ctx.ModuleErrorf("%q was found in multiple partitions: %q, %q", item.Label, *inPartition, partition)
+					}
+					capturePartition := partition
+					inPartition = &capturePartition
+				}
+			}
+
+			// if not specified in a partition, add to remainder partition if one exists
+			if !wasFiltered && remainderPartition != nil {
+				partitionToLabels.appendIncludes(*remainderPartition, item)
+			}
+		}
+
+		// ensure empty lists are maintained
+		if value.Excludes != nil {
+			for _, partition := range partitionNames {
+				partitionToLabels.excludes(partition, value.Excludes)
+			}
+		}
+
+		for partition, list := range partitionToLabels {
+			val := ret[partition]
+			(&val).SetSelectValue(axis, config, *list)
+			ret[partition] = val
+		}
+	}
+
+	partitionLabelList(NoConfigAxis, "")
+	for axis, configToList := range lla.ConfigurableValues {
+		for config, _ := range configToList {
+			partitionLabelList(axis, config)
+		}
+	}
+	return ret
+}
+
 // StringListAttribute corresponds to the string_list Bazel attribute type with
 // support for additional metadata, like configurations.
 type StringListAttribute struct {
diff --git a/bazel/properties_test.go b/bazel/properties_test.go
index 85596e2..f53fdc1 100644
--- a/bazel/properties_test.go
+++ b/bazel/properties_test.go
@@ -16,7 +16,10 @@
 
 import (
 	"reflect"
+	"strings"
 	"testing"
+
+	"github.com/google/blueprint/proptools"
 )
 
 func TestUniqueBazelLabels(t *testing.T) {
@@ -294,6 +297,222 @@
 	}
 }
 
+// labelAddSuffixForTypeMapper returns a LabelMapper that adds suffix to label name for modules of
+// typ
+func labelAddSuffixForTypeMapper(suffix, typ string) LabelMapper {
+	return func(omc OtherModuleContext, label string) (string, bool) {
+		m, ok := omc.ModuleFromName(label)
+		if !ok {
+			return label, false
+		}
+		mTyp := omc.OtherModuleType(m)
+		if typ == mTyp {
+			return label + suffix, true
+		}
+		return label, false
+	}
+}
+
+func TestPartitionLabelListAttribute(t *testing.T) {
+	testCases := []struct {
+		name           string
+		ctx            *otherModuleTestContext
+		labelList      LabelListAttribute
+		filters        LabelPartitions
+		expected       PartitionToLabelListAttribute
+		expectedErrMsg *string
+	}{
+		{
+			name: "no configurable values",
+			ctx:  &otherModuleTestContext{},
+			labelList: LabelListAttribute{
+				Value: makeLabelList([]string{"a.a", "b.b", "c.c", "d.d", "e.e"}, []string{}),
+			},
+			filters: LabelPartitions{
+				"A": LabelPartition{Extensions: []string{".a"}},
+				"B": LabelPartition{Extensions: []string{".b"}},
+				"C": LabelPartition{Extensions: []string{".c"}},
+			},
+			expected: PartitionToLabelListAttribute{
+				"A": LabelListAttribute{Value: makeLabelList([]string{"a.a"}, []string{})},
+				"B": LabelListAttribute{Value: makeLabelList([]string{"b.b"}, []string{})},
+				"C": LabelListAttribute{Value: makeLabelList([]string{"c.c"}, []string{})},
+			},
+		},
+		{
+			name: "no configurable values, remainder partition",
+			ctx:  &otherModuleTestContext{},
+			labelList: LabelListAttribute{
+				Value: makeLabelList([]string{"a.a", "b.b", "c.c", "d.d", "e.e"}, []string{}),
+			},
+			filters: LabelPartitions{
+				"A": LabelPartition{Extensions: []string{".a"}, Keep_remainder: true},
+				"B": LabelPartition{Extensions: []string{".b"}},
+				"C": LabelPartition{Extensions: []string{".c"}},
+			},
+			expected: PartitionToLabelListAttribute{
+				"A": LabelListAttribute{Value: makeLabelList([]string{"a.a", "d.d", "e.e"}, []string{})},
+				"B": LabelListAttribute{Value: makeLabelList([]string{"b.b"}, []string{})},
+				"C": LabelListAttribute{Value: makeLabelList([]string{"c.c"}, []string{})},
+			},
+		},
+		{
+			name: "no configurable values, empty partition",
+			ctx:  &otherModuleTestContext{},
+			labelList: LabelListAttribute{
+				Value: makeLabelList([]string{"a.a", "c.c"}, []string{}),
+			},
+			filters: LabelPartitions{
+				"A": LabelPartition{Extensions: []string{".a"}},
+				"B": LabelPartition{Extensions: []string{".b"}},
+				"C": LabelPartition{Extensions: []string{".c"}},
+			},
+			expected: PartitionToLabelListAttribute{
+				"A": LabelListAttribute{Value: makeLabelList([]string{"a.a"}, []string{})},
+				"C": LabelListAttribute{Value: makeLabelList([]string{"c.c"}, []string{})},
+			},
+		},
+		{
+			name: "no configurable values, has map",
+			ctx: &otherModuleTestContext{
+				modules: []testModuleInfo{testModuleInfo{name: "srcs", typ: "fg", dir: "dir"}},
+			},
+			labelList: LabelListAttribute{
+				Value: makeLabelList([]string{"a.a", "srcs", "b.b", "c.c"}, []string{}),
+			},
+			filters: LabelPartitions{
+				"A": LabelPartition{Extensions: []string{".a"}, LabelMapper: labelAddSuffixForTypeMapper("_a", "fg")},
+				"B": LabelPartition{Extensions: []string{".b"}},
+				"C": LabelPartition{Extensions: []string{".c"}},
+			},
+			expected: PartitionToLabelListAttribute{
+				"A": LabelListAttribute{Value: makeLabelList([]string{"a.a", "srcs_a"}, []string{})},
+				"B": LabelListAttribute{Value: makeLabelList([]string{"b.b"}, []string{})},
+				"C": LabelListAttribute{Value: makeLabelList([]string{"c.c"}, []string{})},
+			},
+		},
+		{
+			name: "configurable values, keeps empty if excludes",
+			ctx:  &otherModuleTestContext{},
+			labelList: LabelListAttribute{
+				ConfigurableValues: configurableLabelLists{
+					ArchConfigurationAxis: labelListSelectValues{
+						"x86":    makeLabelList([]string{"a.a", "c.c"}, []string{}),
+						"arm":    makeLabelList([]string{"b.b"}, []string{}),
+						"x86_64": makeLabelList([]string{"b.b"}, []string{"d.d"}),
+					},
+				},
+			},
+			filters: LabelPartitions{
+				"A": LabelPartition{Extensions: []string{".a"}},
+				"B": LabelPartition{Extensions: []string{".b"}},
+				"C": LabelPartition{Extensions: []string{".c"}},
+			},
+			expected: PartitionToLabelListAttribute{
+				"A": LabelListAttribute{
+					ConfigurableValues: configurableLabelLists{
+						ArchConfigurationAxis: labelListSelectValues{
+							"x86":    makeLabelList([]string{"a.a"}, []string{}),
+							"x86_64": makeLabelList([]string{}, []string{"c.c"}),
+						},
+					},
+				},
+				"B": LabelListAttribute{
+					ConfigurableValues: configurableLabelLists{
+						ArchConfigurationAxis: labelListSelectValues{
+							"arm":    makeLabelList([]string{"b.b"}, []string{}),
+							"x86_64": makeLabelList([]string{"b.b"}, []string{"c.c"}),
+						},
+					},
+				},
+				"C": LabelListAttribute{
+					ConfigurableValues: configurableLabelLists{
+						ArchConfigurationAxis: labelListSelectValues{
+							"x86":    makeLabelList([]string{"c.c"}, []string{}),
+							"x86_64": makeLabelList([]string{}, []string{"c.c"}),
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "error for multiple partitions same value",
+			ctx:  &otherModuleTestContext{},
+			labelList: LabelListAttribute{
+				Value: makeLabelList([]string{"a.a", "b.b", "c.c", "d.d", "e.e"}, []string{}),
+			},
+			filters: LabelPartitions{
+				"A":       LabelPartition{Extensions: []string{".a"}},
+				"other A": LabelPartition{Extensions: []string{".a"}},
+			},
+			expected:       PartitionToLabelListAttribute{},
+			expectedErrMsg: proptools.StringPtr(`"a.a" was found in multiple partitions:`),
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got := PartitionLabelListAttribute(tc.ctx, &tc.labelList, tc.filters)
+
+			if hasErrors, expectsErr := len(tc.ctx.errors) > 0, tc.expectedErrMsg != nil; hasErrors != expectsErr {
+				t.Errorf("Unexpected error(s): %q, expected: %q", tc.ctx.errors, *tc.expectedErrMsg)
+			} else if tc.expectedErrMsg != nil {
+				found := false
+				for _, err := range tc.ctx.errors {
+					if strings.Contains(err, *tc.expectedErrMsg) {
+						found = true
+						break
+					}
+				}
+
+				if !found {
+					t.Errorf("Expected error message: %q, got %q", *tc.expectedErrMsg, tc.ctx.errors)
+				}
+				return
+			}
+
+			if len(tc.expected) != len(got) {
+				t.Errorf("Expected %d partitions, got %d partitions", len(tc.expected), len(got))
+			}
+			for partition, expectedLla := range tc.expected {
+				gotLla, ok := got[partition]
+				if !ok {
+					t.Errorf("Expected partition %q, but it was not found %v", partition, got)
+					continue
+				}
+				expectedLabelList := expectedLla.Value
+				gotLabelList := gotLla.Value
+				if !reflect.DeepEqual(expectedLabelList.Includes, gotLabelList.Includes) {
+					t.Errorf("Expected no config includes %v, got %v", expectedLabelList.Includes, gotLabelList.Includes)
+				}
+				expectedAxes := expectedLla.SortedConfigurationAxes()
+				gotAxes := gotLla.SortedConfigurationAxes()
+				if !reflect.DeepEqual(expectedAxes, gotAxes) {
+					t.Errorf("Expected axes %v, got %v (%#v)", expectedAxes, gotAxes, gotLla)
+				}
+				for _, axis := range expectedLla.SortedConfigurationAxes() {
+					if _, exists := gotLla.ConfigurableValues[axis]; !exists {
+						t.Errorf("Expected %s to be a supported axis, but it was not found", axis)
+					}
+					if expected, got := expectedLla.ConfigurableValues[axis], gotLla.ConfigurableValues[axis]; len(expected) != len(got) {
+						t.Errorf("For axis %q: expected configs %v, got %v", axis, expected, got)
+					}
+					for config, expectedLabelList := range expectedLla.ConfigurableValues[axis] {
+						gotLabelList, exists := gotLla.ConfigurableValues[axis][config]
+						if !exists {
+							t.Errorf("Expected %s to be a supported config, but config was not found", config)
+							continue
+						}
+						if !reflect.DeepEqual(expectedLabelList.Includes, gotLabelList.Includes) {
+							t.Errorf("Expected %s %s includes %v, got %v", axis, config, expectedLabelList.Includes, gotLabelList.Includes)
+						}
+					}
+				}
+			}
+		})
+	}
+}
+
 func TestDeduplicateAxesFromBase(t *testing.T) {
 	attr := StringListAttribute{
 		Value: []string{
diff --git a/bazel/testing.go b/bazel/testing.go
new file mode 100644
index 0000000..23c8350
--- /dev/null
+++ b/bazel/testing.go
@@ -0,0 +1,105 @@
+// Copyright 2021 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bazel
+
+import (
+	"fmt"
+
+	"github.com/google/blueprint"
+)
+
+// testModuleInfo implements blueprint.Module interface with sufficient information to mock a subset of
+// a blueprint ModuleContext
+type testModuleInfo struct {
+	name string
+	typ  string
+	dir  string
+}
+
+// Name returns name for testModuleInfo -- required to implement blueprint.Module
+func (mi testModuleInfo) Name() string {
+	return mi.name
+}
+
+// GenerateBuildActions unused, but required to implmeent blueprint.Module
+func (mi testModuleInfo) GenerateBuildActions(blueprint.ModuleContext) {}
+
+func (mi testModuleInfo) equals(other testModuleInfo) bool {
+	return mi.name == other.name && mi.typ == other.typ && mi.dir == other.dir
+}
+
+// ensure testModuleInfo implements blueprint.Module
+var _ blueprint.Module = testModuleInfo{}
+
+// otherModuleTestContext is a mock context that implements OtherModuleContext
+type otherModuleTestContext struct {
+	modules []testModuleInfo
+	errors  []string
+}
+
+// ModuleFromName retrieves the testModuleInfo corresponding to name, if it exists
+func (omc *otherModuleTestContext) ModuleFromName(name string) (blueprint.Module, bool) {
+	for _, m := range omc.modules {
+		if m.name == name {
+			return m, true
+		}
+	}
+	return testModuleInfo{}, false
+}
+
+// testModuleInfo returns the testModuleInfo corresponding to a blueprint.Module if it exists in omc
+func (omc *otherModuleTestContext) testModuleInfo(m blueprint.Module) (testModuleInfo, bool) {
+	mi, ok := m.(testModuleInfo)
+	if !ok {
+		return testModuleInfo{}, false
+	}
+	for _, other := range omc.modules {
+		if other.equals(mi) {
+			return mi, true
+		}
+	}
+	return testModuleInfo{}, false
+}
+
+// OtherModuleType returns type of m if it exists in omc
+func (omc *otherModuleTestContext) OtherModuleType(m blueprint.Module) string {
+	if mi, ok := omc.testModuleInfo(m); ok {
+		return mi.typ
+	}
+	return ""
+}
+
+// OtherModuleName returns name of m if it exists in omc
+func (omc *otherModuleTestContext) OtherModuleName(m blueprint.Module) string {
+	if mi, ok := omc.testModuleInfo(m); ok {
+		return mi.name
+	}
+	return ""
+}
+
+// OtherModuleDir returns dir of m if it exists in omc
+func (omc *otherModuleTestContext) OtherModuleDir(m blueprint.Module) string {
+	if mi, ok := omc.testModuleInfo(m); ok {
+		return mi.dir
+	}
+	return ""
+}
+
+func (omc *otherModuleTestContext) ModuleErrorf(format string, args ...interface{}) {
+	omc.errors = append(omc.errors, fmt.Sprintf(format, args...))
+}
+
+// Ensure otherModuleTestContext implements OtherModuleContext
+var _ OtherModuleContext = &otherModuleTestContext{}
diff --git a/bp2build/cc_library_conversion_test.go b/bp2build/cc_library_conversion_test.go
index b3a1053..ca44b98 100644
--- a/bp2build/cc_library_conversion_test.go
+++ b/bp2build/cc_library_conversion_test.go
@@ -580,7 +580,7 @@
     "both_source.c",
     "both_source.s",
     "both_source.S",
-        ":both_filegroup",
+    ":both_filegroup",
   ],
     static: {
         srcs: [
@@ -633,9 +633,9 @@
     local_includes = ["."],
     shared = {
         "srcs": [
-            ":shared_filegroup_cpp_srcs",
-            "shared_source.cc",
             "shared_source.cpp",
+            "shared_source.cc",
+            ":shared_filegroup_cpp_srcs",
         ],
         "srcs_as": [
             "shared_source.s",
@@ -648,9 +648,9 @@
         ],
     },
     srcs = [
-        ":both_filegroup_cpp_srcs",
-        "both_source.cc",
         "both_source.cpp",
+        "both_source.cc",
+        ":both_filegroup_cpp_srcs",
     ],
     srcs_as = [
         "both_source.s",
@@ -663,9 +663,9 @@
     ],
     static = {
         "srcs": [
-            ":static_filegroup_cpp_srcs",
-            "static_source.cc",
             "static_source.cpp",
+            "static_source.cc",
+            ":static_filegroup_cpp_srcs",
         ],
         "srcs_as": [
             "static_source.s",
diff --git a/cc/bp2build.go b/cc/bp2build.go
index e48f757..c281c0e 100644
--- a/cc/bp2build.go
+++ b/cc/bp2build.go
@@ -14,12 +14,12 @@
 package cc
 
 import (
-	"fmt"
 	"path/filepath"
 	"strings"
 
 	"android/soong/android"
 	"android/soong/bazel"
+
 	"github.com/google/blueprint"
 
 	"github.com/google/blueprint/proptools"
@@ -41,18 +41,6 @@
 }
 
 func groupSrcsByExtension(ctx android.TopDownMutatorContext, srcs bazel.LabelListAttribute) (cppSrcs, cSrcs, asSrcs bazel.LabelListAttribute) {
-	// Branch srcs into three language-specific groups.
-	// C++ is the "catch-all" group, and comprises generated sources because we don't
-	// know the language of these sources until the genrule is executed.
-	// TODO(b/190006308): Handle language detection of sources in a Bazel rule.
-	isCSrcOrFilegroup := func(s string) bool {
-		return strings.HasSuffix(s, ".c") || strings.HasSuffix(s, "_c_srcs")
-	}
-
-	isAsmSrcOrFilegroup := func(s string) bool {
-		return strings.HasSuffix(s, ".S") || strings.HasSuffix(s, ".s") || strings.HasSuffix(s, "_as_srcs")
-	}
-
 	// Check that a module is a filegroup type named <label>.
 	isFilegroupNamed := func(m android.Module, fullLabel string) bool {
 		if ctx.OtherModuleType(m) != "filegroup" {
@@ -61,54 +49,39 @@
 		labelParts := strings.Split(fullLabel, ":")
 		if len(labelParts) > 2 {
 			// There should not be more than one colon in a label.
-			panic(fmt.Errorf("%s is not a valid Bazel label for a filegroup", fullLabel))
-		} else {
-			return m.Name() == labelParts[len(labelParts)-1]
+			ctx.ModuleErrorf("%s is not a valid Bazel label for a filegroup", fullLabel)
 		}
+		return m.Name() == labelParts[len(labelParts)-1]
 	}
 
-	// Convert the filegroup dependencies into the extension-specific filegroups
-	// filtered in the filegroup.bzl macro.
-	cppFilegroup := func(label string) string {
-		m, exists := ctx.ModuleFromName(label)
-		if exists {
-			aModule, _ := m.(android.Module)
-			if isFilegroupNamed(aModule, label) {
-				label = label + "_cpp_srcs"
+	// Convert filegroup dependencies into extension-specific filegroups filtered in the filegroup.bzl
+	// macro.
+	addSuffixForFilegroup := func(suffix string) bazel.LabelMapper {
+		return func(ctx bazel.OtherModuleContext, label string) (string, bool) {
+			m, exists := ctx.ModuleFromName(label)
+			if !exists {
+				return label, false
 			}
-		}
-		return label
-	}
-	cFilegroup := func(label string) string {
-		m, exists := ctx.ModuleFromName(label)
-		if exists {
 			aModule, _ := m.(android.Module)
-			if isFilegroupNamed(aModule, label) {
-				label = label + "_c_srcs"
+			if !isFilegroupNamed(aModule, label) {
+				return label, false
 			}
+			return label + suffix, true
 		}
-		return label
-	}
-	asFilegroup := func(label string) string {
-		m, exists := ctx.ModuleFromName(label)
-		if exists {
-			aModule, _ := m.(android.Module)
-			if isFilegroupNamed(aModule, label) {
-				label = label + "_as_srcs"
-			}
-		}
-		return label
 	}
 
-	cSrcs = bazel.MapLabelListAttribute(srcs, cFilegroup)
-	cSrcs = bazel.FilterLabelListAttribute(cSrcs, isCSrcOrFilegroup)
+	// TODO(b/190006308): Handle language detection of sources in a Bazel rule.
+	partitioned := bazel.PartitionLabelListAttribute(ctx, &srcs, bazel.LabelPartitions{
+		"c":  bazel.LabelPartition{Extensions: []string{".c"}, LabelMapper: addSuffixForFilegroup("_c_srcs")},
+		"as": bazel.LabelPartition{Extensions: []string{".s", ".S"}, LabelMapper: addSuffixForFilegroup("_as_srcs")},
+		// C++ is the "catch-all" group, and comprises generated sources because we don't
+		// know the language of these sources until the genrule is executed.
+		"cpp": bazel.LabelPartition{Extensions: []string{".cpp", ".cc", ".cxx", ".mm"}, LabelMapper: addSuffixForFilegroup("_cpp_srcs"), Keep_remainder: true},
+	})
 
-	asSrcs = bazel.MapLabelListAttribute(srcs, asFilegroup)
-	asSrcs = bazel.FilterLabelListAttribute(asSrcs, isAsmSrcOrFilegroup)
-
-	cppSrcs = bazel.MapLabelListAttribute(srcs, cppFilegroup)
-	cppSrcs = bazel.SubtractBazelLabelListAttribute(cppSrcs, cSrcs)
-	cppSrcs = bazel.SubtractBazelLabelListAttribute(cppSrcs, asSrcs)
+	cSrcs = partitioned["c"]
+	asSrcs = partitioned["as"]
+	cppSrcs = partitioned["cpp"]
 	return
 }