diff --git a/core/res/res/values/config.xml b/core/res/res/values/config.xml
index c0c8618..249881e 100644
--- a/core/res/res/values/config.xml
+++ b/core/res/res/values/config.xml
@@ -4325,6 +4325,10 @@
          or empty if the default should be used. -->
     <string translatable="false" name="config_deviceSpecificDisplayAreaPolicyProvider"></string>
 
+    <!-- Fully-qualified name of a device-specific DeviceStatePolicy.Provider implementation
+        (needs a public zero-argument constructor), or empty to use the default. -->
+    <string translatable="false" name="config_deviceSpecificDeviceStatePolicyProvider"></string>
+
     <!-- Component name of media projection permission dialog -->
     <string name="config_mediaProjectionPermissionDialogComponent" translatable="false">com.android.systemui/com.android.systemui.media.MediaProjectionPermissionActivity</string>
 
diff --git a/core/res/res/values/symbols.xml b/core/res/res/values/symbols.xml
index 8c8ef12..bcc3a6d 100644
--- a/core/res/res/values/symbols.xml
+++ b/core/res/res/values/symbols.xml
@@ -4667,4 +4667,6 @@
   <java-symbol type="bool" name="config_enableSafetyCenter" />
 
   <java-symbol type="string" name="config_deviceManagerUpdater" />
+  <!-- Read by DeviceStatePolicy.Provider#fromResources. -->
+  <java-symbol type="string" name="config_deviceSpecificDeviceStatePolicyProvider" />
 </resources>
diff --git a/services/core/java/com/android/server/devicestate/DeviceStateManagerService.java b/services/core/java/com/android/server/devicestate/DeviceStateManagerService.java
index 709af91..c2ca3a5 100644
--- a/services/core/java/com/android/server/devicestate/DeviceStateManagerService.java
+++ b/services/core/java/com/android/server/devicestate/DeviceStateManagerService.java
@@ -51,7 +51,6 @@
 import com.android.server.DisplayThread;
 import com.android.server.LocalServices;
 import com.android.server.SystemService;
-import com.android.server.policy.DeviceStatePolicyImpl;
 import com.android.server.wm.ActivityTaskManagerInternal;
 import com.android.server.wm.WindowProcessController;
 
@@ -142,7 +141,10 @@
     private final SparseArray<ProcessRecord> mProcessRecords = new SparseArray<>();
 
     public DeviceStateManagerService(@NonNull Context context) {
-        this(context, new DeviceStatePolicyImpl(context));
+        // Instantiate the device-specific policy named by the
+        // config_deviceSpecificDeviceStatePolicyProvider resource, or the platform default.
+        this(context, DeviceStatePolicy.Provider.fromResources(context.getResources())
+                .instantiate(context));
     }
 
     @VisibleForTesting
diff --git a/services/core/java/com/android/server/devicestate/DeviceStatePolicy.java b/services/core/java/com/android/server/devicestate/DeviceStatePolicy.java
index 274b8e5..5c4e2f3 100644
--- a/services/core/java/com/android/server/devicestate/DeviceStatePolicy.java
+++ b/services/core/java/com/android/server/devicestate/DeviceStatePolicy.java
@@ -17,6 +17,11 @@
 package com.android.server.devicestate;
 
 import android.annotation.NonNull;
+import android.content.Context;
+import android.content.res.Resources;
+import android.text.TextUtils;
+
+import com.android.server.policy.DeviceStatePolicyImpl;
 
 /**
- * Interface for the component responsible for supplying the current device state as well as
+ * Base class for the component responsible for supplying the current device state as well as
@@ -24,9 +29,21 @@
  *
  * @see DeviceStateManagerService
  */
-public interface DeviceStatePolicy {
+public abstract class DeviceStatePolicy {
+    /** The context this policy was created with; available to subclasses. */
+    protected final Context mContext;
+
+    /**
+     * Creates a policy bound to the given context.
+     *
+     * @param context the context stored in {@link #mContext}
+     */
+    protected DeviceStatePolicy(@NonNull Context context) {
+        mContext = context;
+    }
+
     /** Returns the provider of device states. */
-    DeviceStateProvider getDeviceStateProvider();
+    public abstract DeviceStateProvider getDeviceStateProvider();
 
     /**
      * Configures the system into the provided state. Guaranteed not to be called again until the
@@ -36,5 +53,68 @@
      * @param onComplete a callback that must be triggered once the system has been properly
      *                   configured to match the supplied state.
      */
-    void configureDeviceForState(int state, @NonNull Runnable onComplete);
+    public abstract void configureDeviceForState(int state, @NonNull Runnable onComplete);
+
+    /** Provider for platform-default device state policy. */
+    static final class DefaultProvider implements DeviceStatePolicy.Provider {
+        @Override
+        public DeviceStatePolicy instantiate(@NonNull Context context) {
+            return new DeviceStatePolicyImpl(context);
+        }
+    }
+
+    /**
+     * Provider for {@link DeviceStatePolicy} instances.
+     *
+     * <p>By implementing this interface and overriding the
+     * {@code config_deviceSpecificDeviceStatePolicyProvider} resource, a device-specific
+     * implementation of {@link DeviceStatePolicy} can be supplied.
+     */
+    public interface Provider {
+        /**
+         * Instantiates a new {@link DeviceStatePolicy}.
+         *
+         * @see DeviceStatePolicy#DeviceStatePolicy
+         */
+        DeviceStatePolicy instantiate(@NonNull Context context);
+
+        /**
+         * Instantiates the device-specific {@link DeviceStatePolicy.Provider}.
+         *
+         * <p>Checks the {@code config_deviceSpecificDeviceStatePolicyProvider} resource to see if
+         * a device specific policy provider has been supplied. If so, returns an instance of that
+         * provider. If there is no value provided then the method returns the
+         * {@link DeviceStatePolicy.DefaultProvider}.
+         *
+         * @param res the resources to read the configured provider class name from
+         * @return the configured provider, or a new {@link DeviceStatePolicy.DefaultProvider}
+         *         if none is configured
+         * @throws IllegalStateException if a class name is configured but the class cannot be
+         *         found, does not implement {@link DeviceStatePolicy.Provider}, has no
+         *         accessible zero-argument constructor, or its constructor throws
+         */
+        static Provider fromResources(@NonNull Resources res) {
+            final String name = res.getString(
+                    com.android.internal.R.string.config_deviceSpecificDeviceStatePolicyProvider);
+            if (TextUtils.isEmpty(name)) {
+                return new DeviceStatePolicy.DefaultProvider();
+            }
+
+            try {
+                // asSubclass() rejects a class of the wrong type before it is constructed, and
+                // Constructor#newInstance() wraps constructor exceptions in an
+                // InvocationTargetException, unlike the deprecated Class#newInstance().
+                return Class.forName(name)
+                        .asSubclass(DeviceStatePolicy.Provider.class)
+                        .getDeclaredConstructor()
+                        .newInstance();
+            } catch (ReflectiveOperationException | ClassCastException e) {
+                throw new IllegalStateException("Couldn't instantiate class " + name
+                        + " for config_deviceSpecificDeviceStatePolicyProvider:"
+                        + " make sure it has a public zero-argument constructor"
+                        + " and implements DeviceStatePolicy.Provider", e);
+            }
+        }
+    }
+
 }
diff --git a/services/core/java/com/android/server/policy/DeviceStatePolicyImpl.java b/services/core/java/com/android/server/policy/DeviceStatePolicyImpl.java
index 154f9a4..7754944 100644
--- a/services/core/java/com/android/server/policy/DeviceStatePolicyImpl.java
+++ b/services/core/java/com/android/server/policy/DeviceStatePolicyImpl.java
@@ -27,12 +27,11 @@
  *
  * @see DeviceStateProviderImpl
  */
-public final class DeviceStatePolicyImpl implements DeviceStatePolicy {
-    private final Context mContext;
+public final class DeviceStatePolicyImpl extends DeviceStatePolicy {
     private final DeviceStateProvider mProvider;
 
-    public DeviceStatePolicyImpl(Context context) {
-        mContext = context;
+    public DeviceStatePolicyImpl(@NonNull Context context) {
+        super(context); // Sets the inherited mContext read on the next line.
         mProvider = DeviceStateProviderImpl.create(mContext);
     }
 
diff --git a/services/tests/servicestests/src/com/android/server/devicestate/DeviceStateManagerServiceTest.java b/services/tests/servicestests/src/com/android/server/devicestate/DeviceStateManagerServiceTest.java
index 03eba9b..d2cff0e 100644
--- a/services/tests/servicestests/src/com/android/server/devicestate/DeviceStateManagerServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/devicestate/DeviceStateManagerServiceTest.java
@@ -551,13 +551,14 @@
         Assert.assertTrue(Arrays.equals(expected, actual));
     }
 
-    private static final class TestDeviceStatePolicy implements DeviceStatePolicy {
+    private static final class TestDeviceStatePolicy extends DeviceStatePolicy {
         private final DeviceStateProvider mProvider;
         private int mLastDeviceStateRequestedToConfigure = INVALID_DEVICE_STATE;
         private boolean mConfigureBlocked = false;
         private Runnable mPendingConfigureCompleteRunnable;
 
         TestDeviceStatePolicy(DeviceStateProvider provider) {
+            super(InstrumentationRegistry.getContext()); // base class requires a Context
             mProvider = provider;
         }
 
diff --git a/services/tests/servicestests/src/com/android/server/devicestate/DeviceStatePolicyProviderTest.java b/services/tests/servicestests/src/com/android/server/devicestate/DeviceStatePolicyProviderTest.java
new file mode 100644
index 0000000..0bd81b7
--- /dev/null
+++ b/services/tests/servicestests/src/com/android/server/devicestate/DeviceStatePolicyProviderTest.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.devicestate;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertThrows;
+
+import android.content.Context;
+import android.content.res.Resources;
+import android.platform.test.annotations.Presubmit;
+
+import org.hamcrest.Matchers;
+import org.junit.Test;
+
+/**
+ * Unit tests for {@link DeviceStatePolicy.Provider}.
+ * <p/>
+ * Build/Install/Run:
+ *  <code>atest DeviceStatePolicyProviderTest</code>
+ */
+@Presubmit
+public class DeviceStatePolicyProviderTest {
+
+    @Test
+    public void test_emptyPolicyProvider() {
+        assertThat(DeviceStatePolicy.Provider.fromResources(resourcesWithProvider("")),
+                Matchers.instanceOf(DeviceStatePolicy.DefaultProvider.class));
+    }
+
+    @Test
+    public void test_nullPolicyProvider() {
+        assertThat(DeviceStatePolicy.Provider.fromResources(resourcesWithProvider(null)),
+                Matchers.instanceOf(DeviceStatePolicy.DefaultProvider.class));
+    }
+
+    @Test
+    public void test_customPolicyProvider() {
+        assertThat(DeviceStatePolicy.Provider.fromResources(resourcesWithProvider(
+                TestProvider.class.getName())),
+                Matchers.instanceOf(TestProvider.class));
+    }
+
+    @Test
+    public void test_badPolicyProvider_notImplementingProviderInterface() {
+        assertThrows(IllegalStateException.class, () -> {
+            DeviceStatePolicy.Provider.fromResources(resourcesWithProvider(
+                    Object.class.getName()));
+        });
+    }
+
+    @Test
+    public void test_badPolicyProvider_noZeroArgumentConstructor() {
+        assertThrows(IllegalStateException.class, () -> {
+            DeviceStatePolicy.Provider.fromResources(resourcesWithProvider(
+                    NoZeroArgumentConstructorProvider.class.getName()));
+        });
+    }
+
+    @Test
+    public void test_badPolicyProvider_doesntExist() {
+        assertThrows(IllegalStateException.class, () -> {
+            DeviceStatePolicy.Provider.fromResources(resourcesWithProvider(
+                    "com.android.devicestate.nonexistent.policy"));
+        });
+    }
+
+    private static Resources resourcesWithProvider(String provider) {
+        final Resources mockResources = mock(Resources.class);
+        when(mockResources.getString(
+                com.android.internal.R.string.config_deviceSpecificDeviceStatePolicyProvider))
+                .thenReturn(provider);
+        return mockResources;
+    }
+
+    /** Stub implementation of DeviceStatePolicy.Provider for testing. */
+    static class TestProvider implements DeviceStatePolicy.Provider {
+        @Override
+        public DeviceStatePolicy instantiate(Context context) {
+            throw new RuntimeException("test stub");
+        }
+    }
+
+    /** Provider that cannot be created reflectively: it has no zero-argument constructor. */
+    static class NoZeroArgumentConstructorProvider implements DeviceStatePolicy.Provider {
+        NoZeroArgumentConstructorProvider(int unused) {
+        }
+
+        @Override
+        public DeviceStatePolicy instantiate(Context context) {
+            throw new RuntimeException("test stub");
+        }
+    }
+}
