Merge "Implement topology normalization" into main
diff --git a/services/core/java/com/android/server/display/DisplayTopology.java b/services/core/java/com/android/server/display/DisplayTopology.java
index b01d617..fdadafe 100644
--- a/services/core/java/com/android/server/display/DisplayTopology.java
+++ b/services/core/java/com/android/server/display/DisplayTopology.java
@@ -16,7 +16,13 @@
 
 package com.android.server.display;
 
+import static com.android.server.display.DisplayTopology.TreeNode.Position.POSITION_BOTTOM;
+import static com.android.server.display.DisplayTopology.TreeNode.Position.POSITION_LEFT;
+import static com.android.server.display.DisplayTopology.TreeNode.Position.POSITION_TOP;
+import static com.android.server.display.DisplayTopology.TreeNode.Position.POSITION_RIGHT;
+
 import android.annotation.Nullable;
+import android.graphics.RectF;
 import android.util.IndentingPrintWriter;
 import android.util.Pair;
 import android.util.Slog;
@@ -25,16 +31,21 @@
 import com.android.internal.annotations.VisibleForTesting;
 
 import java.io.PrintWriter;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
-import java.util.LinkedList;
+import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Queue;
 
 /**
  * Represents the relative placement of extended displays.
+ * Does not support concurrent calls, so a lock should be held when calling into this class.
  */
 class DisplayTopology {
     private static final String TAG = "DisplayTopology";
+    private static final float EPSILON = 0.0001f;
 
     /**
      * The topology tree
@@ -58,7 +69,7 @@
      * @param width The width of the display
      * @param height The height of the display
      */
-    void addDisplay(int displayId, double width, double height) {
+    void addDisplay(int displayId, float width, float height) {
         addDisplay(displayId, width, height, /* shouldLog= */ true);
     }
 
@@ -69,10 +80,10 @@
      * @param displayId The logical display ID
      */
     void removeDisplay(int displayId) {
-        if (!isDisplayPresent(displayId, mRoot)) {
+        if (findDisplay(displayId, mRoot) == null) {
             return;
         }
-        Queue<TreeNode> queue = new LinkedList<>();
+        Queue<TreeNode> queue = new ArrayDeque<>();
         queue.add(mRoot);
         mRoot = null;
         while (!queue.isEmpty()) {
@@ -115,7 +126,11 @@
         }
     }
 
-    private void addDisplay(int displayId, double width, double height, boolean shouldLog) {
+    private void addDisplay(int displayId, float width, float height, boolean shouldLog) {
+        if (findDisplay(displayId, mRoot) != null) {
+            throw new IllegalArgumentException(
+                    "DisplayTopology: attempting to add a display that already exists");
+        }
         if (mRoot == null) {
             mRoot = new TreeNode(displayId, width, height, /* position= */ null, /* offset= */ 0);
             mPrimaryDisplayId = displayId;
@@ -124,9 +139,8 @@
             }
         } else if (mRoot.mChildren.isEmpty()) {
             // This is the 2nd display. Align the middles of the top and bottom edges.
-            double offset = mRoot.mWidth / 2 - width / 2;
-            TreeNode display = new TreeNode(displayId, width, height,
-                    TreeNode.Position.POSITION_TOP, offset);
+            float offset = mRoot.mWidth / 2 - width / 2;
+            TreeNode display = new TreeNode(displayId, width, height, POSITION_TOP, offset);
             mRoot.mChildren.add(display);
             if (shouldLog) {
                 Slog.i(TAG, "Second display added: " + display + ", parent ID: "
@@ -134,8 +148,8 @@
             }
         } else {
             TreeNode rightMostDisplay = findRightMostDisplay(mRoot, mRoot.mWidth).first;
-            TreeNode newDisplay = new TreeNode(displayId, width, height,
-                    TreeNode.Position.POSITION_RIGHT, /* offset= */ 0);
+            TreeNode newDisplay = new TreeNode(displayId, width, height, POSITION_RIGHT,
+                    /* offset= */ 0);
             rightMostDisplay.mChildren.add(newDisplay);
             if (shouldLog) {
                 Slog.i(TAG, "Display added: " + newDisplay + ", parent ID: "
@@ -150,11 +164,11 @@
      * @return The display that is the furthest to the right and the x position of the right edge
      * of that display
      */
-    private Pair<TreeNode, Double> findRightMostDisplay(TreeNode display, double xPos) {
-        Pair<TreeNode, Double> result = new Pair<>(display, xPos);
+    private static Pair<TreeNode, Float> findRightMostDisplay(TreeNode display, float xPos) {
+        Pair<TreeNode, Float> result = new Pair<>(display, xPos);
         for (TreeNode child : display.mChildren) {
             // The x position of the right edge of the child
-            double childXPos;
+            float childXPos;
             switch (child.mPosition) {
                 case POSITION_LEFT -> childXPos = xPos - display.mWidth;
                 case POSITION_TOP, POSITION_BOTTOM ->
@@ -164,7 +178,7 @@
             }
 
             // Recursive call - find the rightmost display starting from the child
-            Pair<TreeNode, Double> childResult = findRightMostDisplay(child, childXPos);
+            Pair<TreeNode, Float> childResult = findRightMostDisplay(child, childXPos);
             // Check if the one found is further right
             if (childResult.second > result.second) {
                 result = new Pair<>(childResult.first, childResult.second);
@@ -173,19 +187,200 @@
         return result;
     }
 
-    private boolean isDisplayPresent(int displayId, TreeNode node) {
-        if (node == null) {
-            return false;
+    @Nullable // Null when no display in the subtree rooted at startingNode has displayId
+    private static TreeNode findDisplay(int displayId, TreeNode startingNode) {
+        if (startingNode == null) { // e.g. mRoot of an empty topology
+            return null;
         }
-        if (node.mDisplayId == displayId) {
-            return true;
+        if (startingNode.mDisplayId == displayId) {
+            return startingNode;
         }
-        for (TreeNode child : node.mChildren) {
-            if (isDisplayPresent(displayId, child)) {
-                return true;
+        for (TreeNode child : startingNode.mChildren) {
+            TreeNode display = findDisplay(displayId, child); // Depth-first, in child order
+            if (display != null) {
+                return display;
             }
         }
-        return false;
+        return null;
+    }
+
+    /**
+     * Recursively walks the subtree rooted at {@code display} and fills in the given maps.
+     * @param bounds Output map for the absolute bounds of each display in the subtree
+     * @param depths Output map for the depth in the tree of each display in the subtree
+     * @param parents Output map for the parent of each display below {@code display}
+     * @param display The root of the subtree to walk
+     * @param x The x position of the left edge of {@code display}
+     * @param y The y position of the top edge of {@code display}
+     * @param depth The depth of {@code display} in the tree
+     */
+    private static void getInfo(Map<TreeNode, RectF> bounds, Map<TreeNode, Integer> depths,
+            Map<TreeNode, TreeNode> parents, TreeNode display, float x, float y, int depth) {
+        bounds.put(display, new RectF(x, y, x + display.mWidth, y + display.mHeight));
+        depths.put(display, depth);
+        for (TreeNode child : display.mChildren) {
+            parents.put(child, display);
+            float childX;
+            float childY;
+            if (child.mPosition == POSITION_LEFT || child.mPosition == POSITION_RIGHT) {
+                // Attached to a vertical edge of the parent; the offset runs along the y axis
+                childX = child.mPosition == POSITION_LEFT ? x - child.mWidth : x + display.mWidth;
+                childY = y + child.mOffset;
+            } else if (child.mPosition == POSITION_TOP || child.mPosition == POSITION_BOTTOM) {
+                // Attached to a horizontal edge of the parent; the offset runs along the x axis
+                childX = x + child.mOffset;
+                childY = child.mPosition == POSITION_TOP ? y - child.mHeight : y + display.mHeight;
+            } else {
+                continue;
+            }
+            getInfo(bounds, depths, parents, child, childX, childY, depth + 1);
+        }
+    }
+
+    /**
+     * Update the topology to remove any overlaps between displays.
+     * Displays are processed starting at the root. A display that overlaps a display processed
+     * before it is pushed out along the axis that needs the smaller move, and its offset in the
+     * tree is updated. If it then no longer touches its parent on the edge recorded in the tree,
+     * it is re-parented to the last display it was pushed out of. The descendants of a moved
+     * display move along with it.
+     */
+    @VisibleForTesting
+    void normalize() {
+        if (mRoot == null) {
+            return;
+        }
+        Map<TreeNode, RectF> bounds = new HashMap<>();
+        Map<TreeNode, Integer> depths = new HashMap<>();
+        Map<TreeNode, TreeNode> parents = new HashMap<>();
+        getInfo(bounds, depths, parents, mRoot, /* x= */ 0, /* y= */ 0, /* depth= */ 0);
+
+        // Sort the displays first by their depth in the tree, then by the distance of their top
+        // left point from the root display's origin (0, 0). This way we process the displays
+        // starting at the root and we push out a display if necessary. Remaining ties are broken
+        // by display ID so that the result doesn't depend on the HashMap iteration order.
+        Comparator<TreeNode> comparator = Comparator
+                .<TreeNode>comparingInt(depths::get)
+                .thenComparingDouble(d -> Math.hypot(bounds.get(d).left, bounds.get(d).top))
+                .thenComparingInt(d -> d.mDisplayId);
+        List<TreeNode> displays = new ArrayList<>(bounds.keySet());
+        displays.sort(comparator);
+
+        for (int i = 1; i < displays.size(); i++) {
+            TreeNode targetDisplay = displays.get(i);
+            TreeNode lastIntersectingSourceDisplay = null;
+            float lastOffsetX = 0;
+            float lastOffsetY = 0;
+
+            for (int j = 0; j < i; j++) {
+                TreeNode sourceDisplay = displays.get(j);
+                RectF sourceBounds = bounds.get(sourceDisplay);
+                RectF targetBounds = bounds.get(targetDisplay);
+
+                if (!RectF.intersects(sourceBounds, targetBounds)) {
+                    continue;
+                }
+
+                // Find the offset by which to move the display. Pick the smaller one among the x
+                // and y axes.
+                float offsetX = targetBounds.left >= 0
+                        ? sourceBounds.right - targetBounds.left
+                        : sourceBounds.left - targetBounds.right;
+                float offsetY = targetBounds.top >= 0
+                        ? sourceBounds.bottom - targetBounds.top
+                        : sourceBounds.top - targetBounds.bottom;
+                if (Math.abs(offsetX) <= Math.abs(offsetY)) {
+                    targetBounds.left += offsetX;
+                    targetBounds.right += offsetX;
+                    // We need to also update the offset in the tree
+                    if (targetDisplay.mPosition == POSITION_TOP
+                            || targetDisplay.mPosition == POSITION_BOTTOM) {
+                        targetDisplay.mOffset += offsetX;
+                    }
+                    offsetY = 0;
+                } else {
+                    targetBounds.top += offsetY;
+                    targetBounds.bottom += offsetY;
+                    // We need to also update the offset in the tree
+                    if (targetDisplay.mPosition == POSITION_LEFT
+                            || targetDisplay.mPosition == POSITION_RIGHT) {
+                        targetDisplay.mOffset += offsetY;
+                    }
+                    offsetX = 0;
+                }
+
+                lastIntersectingSourceDisplay = sourceDisplay;
+                lastOffsetX = offsetX;
+                lastOffsetY = offsetY;
+            }
+
+            if (lastIntersectingSourceDisplay == null) {
+                // There was no overlap.
+                continue;
+            }
+
+            // Now re-parent the target display to the last intersecting source display if it no
+            // longer touches its parent on the edge recorded in the tree. This is also checked if
+            // the last intersecting display is the parent itself, as the display may then have
+            // been pushed out to a different edge of its parent.
+            TreeNode parent = parents.get(targetDisplay);
+            RectF childBounds = bounds.get(targetDisplay);
+            RectF parentBounds = bounds.get(parent);
+            // Check that the edges are on the same line
+            boolean areTouching = switch (targetDisplay.mPosition) {
+                case POSITION_LEFT -> floatEquals(parentBounds.left, childBounds.right);
+                case POSITION_RIGHT -> floatEquals(parentBounds.right, childBounds.left);
+                case POSITION_TOP -> floatEquals(parentBounds.top, childBounds.bottom);
+                case POSITION_BOTTOM -> floatEquals(parentBounds.bottom, childBounds.top);
+            };
+            // Check that the offset is within bounds
+            areTouching &= switch (targetDisplay.mPosition) {
+                case POSITION_LEFT, POSITION_RIGHT ->
+                        childBounds.bottom + EPSILON >= parentBounds.top
+                                && childBounds.top <= parentBounds.bottom + EPSILON;
+                case POSITION_TOP, POSITION_BOTTOM ->
+                        childBounds.right + EPSILON >= parentBounds.left
+                                && childBounds.left <= parentBounds.right + EPSILON;
+            };
+
+            if (!areTouching) {
+                // Re-parent the display.
+                parent.mChildren.remove(targetDisplay);
+                RectF lastIntersectingSourceDisplayBounds =
+                        bounds.get(lastIntersectingSourceDisplay);
+                lastIntersectingSourceDisplay.mChildren.add(targetDisplay);
+                parents.put(targetDisplay, lastIntersectingSourceDisplay);
+
+                if (lastOffsetX != 0) {
+                    // Pushed out horizontally: it's now to the left or right of the source
+                    targetDisplay.mPosition = lastOffsetX > 0 ? POSITION_RIGHT : POSITION_LEFT;
+                    targetDisplay.mOffset =
+                            childBounds.top - lastIntersectingSourceDisplayBounds.top;
+                } else if (lastOffsetY != 0) {
+                    // Pushed out vertically: it's now above or below the source
+                    targetDisplay.mPosition = lastOffsetY > 0 ? POSITION_BOTTOM : POSITION_TOP;
+                    targetDisplay.mOffset =
+                            childBounds.left - lastIntersectingSourceDisplayBounds.left;
+                }
+            }
+
+            // Offsets are relative to the parent, so the descendants of the moved display have
+            // moved with it, but their bounds computed above still hold their old positions.
+            // Recompute them so that they're checked against the other displays where they are.
+            getInfo(bounds, depths, parents, targetDisplay, childBounds.left, childBounds.top,
+                    depths.get(targetDisplay));
+        }
+    }
+
+    /**
+     * Tests whether two float values, such as display coordinates, are within {@link #EPSILON}
+     * of each other. Two NaN values are also considered equal.
+     * @param a first float to compare
+     * @param b second float to compare
+     * @return whether the two values are equal within {@link #EPSILON}, or are both NaN
+     */
+    public static boolean floatEquals(float a, float b) {
+        return a == b || (Float.isNaN(a) && Float.isNaN(b)) || Math.abs(a - b) < EPSILON;
     }
 
     @VisibleForTesting
@@ -201,13 +396,13 @@
          * The width of the display in density-independent pixels (dp).
          */
         @VisibleForTesting
-        double mWidth;
+        float mWidth;
 
         /**
          * The height of the display in density-independent pixels (dp).
          */
         @VisibleForTesting
-        double mHeight;
+        float mHeight;
 
         /**
          * The position of this display relative to its parent.
@@ -222,13 +417,12 @@
          * used is density-independent pixels (dp).
          */
         @VisibleForTesting
-        double mOffset;
+        float mOffset;
 
         @VisibleForTesting
         final List<TreeNode> mChildren = new ArrayList<>();
 
-        TreeNode(int displayId, double width, double height, Position position,
-                double offset) {
+        TreeNode(int displayId, float width, float height, Position position, float offset) {
             mDisplayId = displayId;
             mWidth = width;
             mHeight = height;
diff --git a/services/core/java/com/android/server/display/DisplayTopologyCoordinator.java b/services/core/java/com/android/server/display/DisplayTopologyCoordinator.java
index 46358dfd..b101e58 100644
--- a/services/core/java/com/android/server/display/DisplayTopologyCoordinator.java
+++ b/services/core/java/com/android/server/display/DisplayTopologyCoordinator.java
@@ -89,8 +89,8 @@
      * @param info The display info
      * @return The width of the display in dp
      */
-    private double getWidth(DisplayInfo info) {
-        return info.logicalWidth * (double) DisplayMetrics.DENSITY_DEFAULT
+    private float getWidth(DisplayInfo info) {
+        return info.logicalWidth * (float) DisplayMetrics.DENSITY_DEFAULT
                 / info.logicalDensityDpi;
     }
 
@@ -98,8 +98,8 @@
      * @param info The display info
      * @return The height of the display in dp
      */
-    private double getHeight(DisplayInfo info) {
-        return info.logicalHeight * (double) DisplayMetrics.DENSITY_DEFAULT
+    private float getHeight(DisplayInfo info) {
+        return info.logicalHeight * (float) DisplayMetrics.DENSITY_DEFAULT
                 / info.logicalDensityDpi;
     }
 
diff --git a/services/tests/displayservicetests/src/com/android/server/display/DisplayTopologyCoordinatorTest.kt b/services/tests/displayservicetests/src/com/android/server/display/DisplayTopologyCoordinatorTest.kt
index 17af633..85e7356 100644
--- a/services/tests/displayservicetests/src/com/android/server/display/DisplayTopologyCoordinatorTest.kt
+++ b/services/tests/displayservicetests/src/com/android/server/display/DisplayTopologyCoordinatorTest.kt
@@ -21,7 +21,7 @@
 import android.view.DisplayInfo
 import org.junit.Before
 import org.junit.Test
-import org.mockito.ArgumentMatchers.anyDouble
+import org.mockito.ArgumentMatchers.anyFloat
 import org.mockito.ArgumentMatchers.anyInt
 import org.mockito.kotlin.mock
 import org.mockito.kotlin.never
@@ -55,9 +55,9 @@
 
         coordinator.onDisplayAdded(displayInfo)
 
-        val widthDp = displayInfo.logicalWidth * (DisplayMetrics.DENSITY_DEFAULT.toDouble()
+        val widthDp = displayInfo.logicalWidth * (DisplayMetrics.DENSITY_DEFAULT.toFloat()
                 / displayInfo.logicalDensityDpi)
-        val heightDp = displayInfo.logicalHeight * (DisplayMetrics.DENSITY_DEFAULT.toDouble()
+        val heightDp = displayInfo.logicalHeight * (DisplayMetrics.DENSITY_DEFAULT.toFloat()
                 / displayInfo.logicalDensityDpi)
         verify(mockTopology).addDisplay(displayInfo.displayId, widthDp, heightDp)
     }
@@ -68,7 +68,7 @@
 
         coordinator.onDisplayAdded(displayInfo)
 
-        verify(mockTopology, never()).addDisplay(anyInt(), anyDouble(), anyDouble())
+        verify(mockTopology, never()).addDisplay(anyInt(), anyFloat(), anyFloat())
     }
 
     @Test
@@ -78,6 +78,6 @@
 
         coordinator.onDisplayAdded(displayInfo)
 
-        verify(mockTopology, never()).addDisplay(anyInt(), anyDouble(), anyDouble())
+        verify(mockTopology, never()).addDisplay(anyInt(), anyFloat(), anyFloat())
     }
 }
\ No newline at end of file
diff --git a/services/tests/displayservicetests/src/com/android/server/display/DisplayTopologyTest.kt b/services/tests/displayservicetests/src/com/android/server/display/DisplayTopologyTest.kt
index f3a8d841..cd8c26d 100644
--- a/services/tests/displayservicetests/src/com/android/server/display/DisplayTopologyTest.kt
+++ b/services/tests/displayservicetests/src/com/android/server/display/DisplayTopologyTest.kt
@@ -17,6 +17,9 @@
 package com.android.server.display
 
 import android.view.Display
+import com.android.server.display.DisplayTopology.TreeNode.Position.POSITION_BOTTOM
+import com.android.server.display.DisplayTopology.TreeNode.Position.POSITION_TOP
+import com.android.server.display.DisplayTopology.TreeNode.Position.POSITION_RIGHT
 import com.google.common.truth.Truth.assertThat
 import org.junit.Test
 
@@ -26,8 +29,8 @@
     @Test
     fun addOneDisplay() {
         val displayId = 1
-        val width = 800.0
-        val height = 600.0
+        val width = 800f
+        val height = 600f
 
         topology.addDisplay(displayId, width, height)
 
@@ -43,12 +46,12 @@
     @Test
     fun addTwoDisplays() {
         val displayId1 = 1
-        val width1 = 800.0
-        val height1 = 600.0
+        val width1 = 800f
+        val height1 = 600f
 
         val displayId2 = 2
-        val width2 = 1000.0
-        val height2 = 1500.0
+        val width2 = 1000f
+        val height2 = 1500f
 
         topology.addDisplay(displayId1, width1, height1)
         topology.addDisplay(displayId2, width2, height2)
@@ -66,20 +69,19 @@
         assertThat(display2.mWidth).isEqualTo(width2)
         assertThat(display2.mHeight).isEqualTo(height2)
         assertThat(display2.mChildren).isEmpty()
-        assertThat(display2.mPosition).isEqualTo(
-            DisplayTopology.TreeNode.Position.POSITION_TOP)
+        assertThat(display2.mPosition).isEqualTo(POSITION_TOP)
         assertThat(display2.mOffset).isEqualTo(width1 / 2 - width2 / 2)
     }
 
     @Test
     fun addManyDisplays() {
         val displayId1 = 1
-        val width1 = 800.0
-        val height1 = 600.0
+        val width1 = 800f
+        val height1 = 600f
 
         val displayId2 = 2
-        val width2 = 1000.0
-        val height2 = 1500.0
+        val width2 = 1000f
+        val height2 = 1500f
 
         topology.addDisplay(displayId1, width1, height1)
         topology.addDisplay(displayId2, width2, height2)
@@ -102,8 +104,7 @@
         assertThat(display2.mWidth).isEqualTo(width2)
         assertThat(display2.mHeight).isEqualTo(height2)
         assertThat(display2.mChildren).hasSize(1)
-        assertThat(display2.mPosition).isEqualTo(
-            DisplayTopology.TreeNode.Position.POSITION_TOP)
+        assertThat(display2.mPosition).isEqualTo(POSITION_TOP)
         assertThat(display2.mOffset).isEqualTo(width1 / 2 - width2 / 2)
 
         var display = display2
@@ -114,8 +115,7 @@
             assertThat(display.mHeight).isEqualTo(height1)
             // The last display should have no children
             assertThat(display.mChildren).hasSize(if (i < noOfDisplays) 1 else 0)
-            assertThat(display.mPosition).isEqualTo(
-                DisplayTopology.TreeNode.Position.POSITION_RIGHT)
+            assertThat(display.mPosition).isEqualTo(POSITION_RIGHT)
             assertThat(display.mOffset).isEqualTo(0)
         }
     }
@@ -123,12 +123,12 @@
     @Test
     fun removeDisplays() {
         val displayId1 = 1
-        val width1 = 800.0
-        val height1 = 600.0
+        val width1 = 800f
+        val height1 = 600f
 
         val displayId2 = 2
-        val width2 = 1000.0
-        val height2 = 1500.0
+        val width2 = 1000f
+        val height2 = 1500f
 
         topology.addDisplay(displayId1, width1, height1)
         topology.addDisplay(displayId2, width2, height2)
@@ -154,8 +154,7 @@
         assertThat(display2.mWidth).isEqualTo(width2)
         assertThat(display2.mHeight).isEqualTo(height2)
         assertThat(display2.mChildren).hasSize(1)
-        assertThat(display2.mPosition).isEqualTo(
-            DisplayTopology.TreeNode.Position.POSITION_TOP)
+        assertThat(display2.mPosition).isEqualTo(POSITION_TOP)
         assertThat(display2.mOffset).isEqualTo(width1 / 2 - width2 / 2)
 
         var display = display2
@@ -169,8 +168,7 @@
             assertThat(display.mHeight).isEqualTo(height1)
             // The last display should have no children
             assertThat(display.mChildren).hasSize(if (i < noOfDisplays) 1 else 0)
-            assertThat(display.mPosition).isEqualTo(
-                DisplayTopology.TreeNode.Position.POSITION_RIGHT)
+            assertThat(display.mPosition).isEqualTo(POSITION_RIGHT)
             assertThat(display.mOffset).isEqualTo(0)
         }
 
@@ -194,8 +192,7 @@
         assertThat(display2.mWidth).isEqualTo(width2)
         assertThat(display2.mHeight).isEqualTo(height2)
         assertThat(display2.mChildren).hasSize(1)
-        assertThat(display2.mPosition).isEqualTo(
-            DisplayTopology.TreeNode.Position.POSITION_TOP)
+        assertThat(display2.mPosition).isEqualTo(POSITION_TOP)
         assertThat(display2.mOffset).isEqualTo(width1 / 2 - width2 / 2)
 
         display = display2
@@ -209,8 +206,7 @@
             assertThat(display.mHeight).isEqualTo(height1)
             // The last display should have no children
             assertThat(display.mChildren).hasSize(if (i < noOfDisplays) 1 else 0)
-            assertThat(display.mPosition).isEqualTo(
-                DisplayTopology.TreeNode.Position.POSITION_RIGHT)
+            assertThat(display.mPosition).isEqualTo(POSITION_RIGHT)
             assertThat(display.mOffset).isEqualTo(0)
         }
     }
@@ -218,8 +214,8 @@
     @Test
     fun removeAllDisplays() {
         val displayId = 1
-        val width = 800.0
-        val height = 600.0
+        val width = 800f
+        val height = 600f
 
         topology.addDisplay(displayId, width, height)
         topology.removeDisplay(displayId)
@@ -231,8 +227,8 @@
     @Test
     fun removeDisplayThatDoesNotExist() {
         val displayId = 1
-        val width = 800.0
-        val height = 600.0
+        val width = 800f
+        val height = 600f
 
         topology.addDisplay(displayId, width, height)
         topology.removeDisplay(3)
@@ -245,4 +241,236 @@
         assertThat(display.mHeight).isEqualTo(height)
         assertThat(display.mChildren).isEmpty()
     }
+
+    @Test
+    fun removePrimaryDisplay() {
+        val displayId1 = 1
+        val displayId2 = 2
+        val width = 800f
+        val height = 600f
+        // Add two displays, make the second one primary, then remove it.
+        topology.addDisplay(displayId1, width, height)
+        topology.addDisplay(displayId2, width, height)
+        topology.mPrimaryDisplayId = displayId2
+        topology.removeDisplay(displayId2)
+        // The remaining display becomes both the primary display and the childless root.
+        assertThat(topology.mPrimaryDisplayId).isEqualTo(displayId1)
+        val display = topology.mRoot!!
+        assertThat(display.mDisplayId).isEqualTo(displayId1)
+        assertThat(display.mWidth).isEqualTo(width)
+        assertThat(display.mHeight).isEqualTo(height)
+        assertThat(display.mChildren).isEmpty()
+    }
+
+    @Test
+    fun normalization_noOverlaps_leavesTopologyUnchanged() {
+        val display1 = DisplayTopology.TreeNode(/* displayId= */ 1, /* width= */ 200f,
+            /* height= */ 600f, /* position= */ null, /* offset= */ 0f)
+        topology.mRoot = display1 // x: 0..200, y: 0..600
+        // Bounds below assume display 1 at the origin and offsets along the parent's edge.
+        val display2 = DisplayTopology.TreeNode(/* displayId= */ 2, /* width= */ 600f,
+            /* height= */ 200f, POSITION_RIGHT, /* offset= */ 0f)
+        display1.mChildren.add(display2) // x: 200..800, y: 0..200
+
+        val primaryDisplayId = 3
+        val display3 = DisplayTopology.TreeNode(primaryDisplayId, /* width= */ 600f,
+            /* height= */ 200f, POSITION_RIGHT, /* offset= */ 400f)
+        display1.mChildren.add(display3) // x: 200..800, y: 400..600
+        topology.mPrimaryDisplayId = primaryDisplayId
+
+        val display4 = DisplayTopology.TreeNode(/* displayId= */ 4, /* width= */ 200f,
+            /* height= */ 600f, POSITION_RIGHT, /* offset= */ 0f)
+        display2.mChildren.add(display4) // x: 800..1000, y: 0..600
+        // No two displays overlap, so normalize() is expected to leave the tree unchanged.
+        topology.normalize()
+
+        assertThat(topology.mPrimaryDisplayId).isEqualTo(primaryDisplayId)
+
+        val actualDisplay1 = topology.mRoot!!
+        assertThat(actualDisplay1.mDisplayId).isEqualTo(1)
+        assertThat(actualDisplay1.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay1.mHeight).isEqualTo(600f)
+        assertThat(actualDisplay1.mChildren).hasSize(2)
+
+        val actualDisplay2 = actualDisplay1.mChildren[0]
+        assertThat(actualDisplay2.mDisplayId).isEqualTo(2)
+        assertThat(actualDisplay2.mWidth).isEqualTo(600f)
+        assertThat(actualDisplay2.mHeight).isEqualTo(200f)
+        assertThat(actualDisplay2.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay2.mOffset).isEqualTo(0f)
+        assertThat(actualDisplay2.mChildren).hasSize(1)
+
+        val actualDisplay3 = actualDisplay1.mChildren[1]
+        assertThat(actualDisplay3.mDisplayId).isEqualTo(3)
+        assertThat(actualDisplay3.mWidth).isEqualTo(600f)
+        assertThat(actualDisplay3.mHeight).isEqualTo(200f)
+        assertThat(actualDisplay3.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay3.mOffset).isEqualTo(400f)
+        assertThat(actualDisplay3.mChildren).isEmpty()
+
+        val actualDisplay4 = actualDisplay2.mChildren[0]
+        assertThat(actualDisplay4.mDisplayId).isEqualTo(4)
+        assertThat(actualDisplay4.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay4.mHeight).isEqualTo(600f)
+        assertThat(actualDisplay4.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay4.mOffset).isEqualTo(0f)
+        assertThat(actualDisplay4.mChildren).isEmpty()
+    }
+
+    @Test
+    fun normalization_moveDisplayWithoutReparenting() {
+        val display1 = DisplayTopology.TreeNode(/* displayId= */ 1, /* width= */ 200f,
+            /* height= */ 600f, /* position= */ null, /* offset= */ 0f)
+        topology.mRoot = display1 // x: 0..200, y: 0..600
+        // Bounds below assume display 1 at the origin and offsets along the parent's edge.
+        val display2 = DisplayTopology.TreeNode(/* displayId= */ 2, /* width= */ 200f,
+            /* height= */ 600f, POSITION_RIGHT, /* offset= */ 0f)
+        display1.mChildren.add(display2) // x: 200..400, y: 0..600
+
+        val primaryDisplayId = 3
+        val display3 = DisplayTopology.TreeNode(primaryDisplayId, /* width= */ 600f,
+            /* height= */ 200f, POSITION_RIGHT, /* offset= */ 10f)
+        display1.mChildren.add(display3) // x: 200..800, y: 10..210, overlaps displays 2 and 4
+        topology.mPrimaryDisplayId = primaryDisplayId
+
+        val display4 = DisplayTopology.TreeNode(/* displayId= */ 4, /* width= */ 200f,
+            /* height= */ 600f, POSITION_RIGHT, /* offset= */ 0f)
+        display2.mChildren.add(display4) // x: 400..600, y: 0..600
+
+        // Display 3 becomes a child of display 2. Display 4 gets moved without changing its parent.
+        topology.normalize()
+
+        assertThat(topology.mPrimaryDisplayId).isEqualTo(primaryDisplayId)
+
+        val actualDisplay1 = topology.mRoot!!
+        assertThat(actualDisplay1.mDisplayId).isEqualTo(1)
+        assertThat(actualDisplay1.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay1.mHeight).isEqualTo(600f)
+        assertThat(actualDisplay1.mChildren).hasSize(1)
+
+        val actualDisplay2 = actualDisplay1.mChildren[0]
+        assertThat(actualDisplay2.mDisplayId).isEqualTo(2)
+        assertThat(actualDisplay2.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay2.mHeight).isEqualTo(600f)
+        assertThat(actualDisplay2.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay2.mOffset).isEqualTo(0f)
+        assertThat(actualDisplay2.mChildren).hasSize(2)
+
+        val actualDisplay3 = actualDisplay2.mChildren[1]
+        assertThat(actualDisplay3.mDisplayId).isEqualTo(3)
+        assertThat(actualDisplay3.mWidth).isEqualTo(600f)
+        assertThat(actualDisplay3.mHeight).isEqualTo(200f)
+        assertThat(actualDisplay3.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay3.mOffset).isEqualTo(10f) // x: 400..1000, y: 10..210
+        assertThat(actualDisplay3.mChildren).isEmpty()
+
+        val actualDisplay4 = actualDisplay2.mChildren[0]
+        assertThat(actualDisplay4.mDisplayId).isEqualTo(4)
+        assertThat(actualDisplay4.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay4.mHeight).isEqualTo(600f)
+        assertThat(actualDisplay4.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay4.mOffset).isEqualTo(210f) // x: 400..600, y: 210..810
+        assertThat(actualDisplay4.mChildren).isEmpty()
+    }
+
+    @Test
+    fun normalization_moveDisplayWithoutReparenting_offsetOutOfBounds() {
+        val display1 = DisplayTopology.TreeNode(/* displayId= */ 1, /* width= */ 200f,
+            /* height= */ 50f, /* position= */ null, /* offset= */ 0f)
+        topology.mRoot = display1 // x: 0..200, y: 0..50
+        // Bounds below assume display 1 at the origin and offsets along the parent's edge.
+        val display2 = DisplayTopology.TreeNode(/* displayId= */ 2, /* width= */ 600f,
+            /* height= */ 200f, POSITION_RIGHT, /* offset= */ 0f)
+        display1.mChildren.add(display2) // x: 200..800, y: 0..200
+
+        val primaryDisplayId = 3
+        val display3 = DisplayTopology.TreeNode(primaryDisplayId, /* width= */ 600f,
+            /* height= */ 200f, POSITION_RIGHT, /* offset= */ 10f)
+        display1.mChildren.add(display3) // x: 200..800, y: 10..210, overlaps display 2
+        topology.mPrimaryDisplayId = primaryDisplayId
+
+        // Display 3 gets moved below Display 2. Its left side is still on the same line as
+        // the right side of Display 1, but Display 1 is only 50 tall, so the offset would be
+        // out of bounds and they would no longer touch. Display 2 becomes its new parent.
+        topology.normalize()
+
+        assertThat(topology.mPrimaryDisplayId).isEqualTo(primaryDisplayId)
+
+        val actualDisplay1 = topology.mRoot!!
+        assertThat(actualDisplay1.mDisplayId).isEqualTo(1)
+        assertThat(actualDisplay1.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay1.mHeight).isEqualTo(50f)
+        assertThat(actualDisplay1.mChildren).hasSize(1)
+
+        val actualDisplay2 = actualDisplay1.mChildren[0]
+        assertThat(actualDisplay2.mDisplayId).isEqualTo(2)
+        assertThat(actualDisplay2.mWidth).isEqualTo(600f)
+        assertThat(actualDisplay2.mHeight).isEqualTo(200f)
+        assertThat(actualDisplay2.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay2.mOffset).isEqualTo(0f)
+        assertThat(actualDisplay2.mChildren).hasSize(1)
+
+        val actualDisplay3 = actualDisplay2.mChildren[0]
+        assertThat(actualDisplay3.mDisplayId).isEqualTo(3)
+        assertThat(actualDisplay3.mWidth).isEqualTo(600f)
+        assertThat(actualDisplay3.mHeight).isEqualTo(200f)
+        assertThat(actualDisplay3.mPosition).isEqualTo(POSITION_BOTTOM)
+        assertThat(actualDisplay3.mOffset).isEqualTo(0f) // x: 200..800, y: 200..400
+        assertThat(actualDisplay3.mChildren).isEmpty()
+    }
+
+    @Test
+    fun normalization_moveAndReparentDisplay() {
+        val display1 = DisplayTopology.TreeNode(/* displayId= */ 1, /* width= */ 200f,
+            /* height= */ 600f, /* position= */ null, /* offset= */ 0f)
+        topology.mRoot = display1 // x: 0..200, y: 0..600
+        // Bounds below assume display 1 at the origin and offsets along the parent's edge.
+        val display2 = DisplayTopology.TreeNode(/* displayId= */ 2, /* width= */ 200f,
+            /* height= */ 600f, POSITION_RIGHT, /* offset= */ 0f)
+        display1.mChildren.add(display2) // x: 200..400, y: 0..600
+
+        val primaryDisplayId = 3
+        val display3 = DisplayTopology.TreeNode(primaryDisplayId, /* width= */ 600f,
+            /* height= */ 200f, POSITION_RIGHT, /* offset= */ 400f)
+        display1.mChildren.add(display3) // x: 200..800, y: 400..600, overlaps displays 2 and 4
+        topology.mPrimaryDisplayId = primaryDisplayId
+
+        val display4 = DisplayTopology.TreeNode(/* displayId= */ 4, /* width= */ 200f,
+            /* height= */ 600f, POSITION_RIGHT, /* offset= */ 0f)
+        display2.mChildren.add(display4) // x: 400..600, y: 0..600
+        // Display 3 becomes a child of display 2; display 4 is moved and reparented to display 3.
+        topology.normalize()
+
+        assertThat(topology.mPrimaryDisplayId).isEqualTo(primaryDisplayId)
+
+        val actualDisplay1 = topology.mRoot!!
+        assertThat(actualDisplay1.mDisplayId).isEqualTo(1)
+        assertThat(actualDisplay1.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay1.mHeight).isEqualTo(600f)
+        assertThat(actualDisplay1.mChildren).hasSize(1)
+
+        val actualDisplay2 = actualDisplay1.mChildren[0]
+        assertThat(actualDisplay2.mDisplayId).isEqualTo(2)
+        assertThat(actualDisplay2.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay2.mHeight).isEqualTo(600f)
+        assertThat(actualDisplay2.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay2.mOffset).isEqualTo(0f)
+        assertThat(actualDisplay2.mChildren).hasSize(1)
+
+        val actualDisplay3 = actualDisplay2.mChildren[0]
+        assertThat(actualDisplay3.mDisplayId).isEqualTo(3)
+        assertThat(actualDisplay3.mWidth).isEqualTo(600f)
+        assertThat(actualDisplay3.mHeight).isEqualTo(200f)
+        assertThat(actualDisplay3.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay3.mOffset).isEqualTo(400f) // x: 400..1000, y: 400..600
+        assertThat(actualDisplay3.mChildren).hasSize(1)
+
+        val actualDisplay4 = actualDisplay3.mChildren[0]
+        assertThat(actualDisplay4.mDisplayId).isEqualTo(4)
+        assertThat(actualDisplay4.mWidth).isEqualTo(200f)
+        assertThat(actualDisplay4.mHeight).isEqualTo(600f)
+        assertThat(actualDisplay4.mPosition).isEqualTo(POSITION_RIGHT)
+        assertThat(actualDisplay4.mOffset).isEqualTo(-400f) // x: 1000..1200, y: 0..600
+        assertThat(actualDisplay4.mChildren).isEmpty()
+    }
 }
\ No newline at end of file