Merge "Fix SntpClient 2036 issue (2/2)"
diff --git a/core/java/android/net/SntpClient.java b/core/java/android/net/SntpClient.java
index aea11fa..0eb4cf3 100644
--- a/core/java/android/net/SntpClient.java
+++ b/core/java/android/net/SntpClient.java
@@ -17,8 +17,11 @@
package android.net;
import android.compat.annotation.UnsupportedAppUsage;
+import android.net.sntp.Duration64;
+import android.net.sntp.Timestamp64;
import android.os.SystemClock;
import android.util.Log;
+import android.util.Slog;
import com.android.internal.annotations.VisibleForTesting;
import com.android.internal.util.TrafficStatsConstants;
@@ -27,10 +30,12 @@
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.UnknownHostException;
+import java.security.NoSuchAlgorithmException;
+import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
-import java.util.Arrays;
import java.util.Objects;
+import java.util.Random;
import java.util.function.Supplier;
/**
@@ -65,13 +70,11 @@
private static final int NTP_STRATUM_DEATH = 0;
private static final int NTP_STRATUM_MAX = 15;
- // Number of seconds between Jan 1, 1900 and Jan 1, 1970
- // 70 years plus 17 leap days
- private static final long OFFSET_1900_TO_1970 = ((365L * 70L) + 17L) * 24L * 60L * 60L;
-
// The source of the current system clock time, replaceable for testing.
private final Supplier<Instant> mSystemTimeSupplier;
+ private final Random mRandom;
+
// The last offset calculated from an NTP server response
private long mClockOffset;
@@ -92,12 +95,13 @@
@UnsupportedAppUsage
public SntpClient() {
- this(Instant::now);
+ this(Instant::now, defaultRandom());
}
@VisibleForTesting
- public SntpClient(Supplier<Instant> systemTimeSupplier) {
+ public SntpClient(Supplier<Instant> systemTimeSupplier, Random random) {
mSystemTimeSupplier = Objects.requireNonNull(systemTimeSupplier);
+ mRandom = Objects.requireNonNull(random);
}
/**
@@ -144,10 +148,12 @@
// get current time and write it to the request packet
final Instant requestTime = mSystemTimeSupplier.get();
- final long requestTimestamp = requestTime.toEpochMilli();
+ final Timestamp64 requestTimestamp = Timestamp64.fromInstant(requestTime);
+ final Timestamp64 randomizedRequestTimestamp =
+ requestTimestamp.randomizeSubMillis(mRandom);
final long requestTicks = SystemClock.elapsedRealtime();
- writeTimeStamp(buffer, TRANSMIT_TIME_OFFSET, requestTimestamp);
+ writeTimeStamp(buffer, TRANSMIT_TIME_OFFSET, randomizedRequestTimestamp);
socket.send(request);
@@ -156,23 +162,25 @@
socket.receive(response);
final long responseTicks = SystemClock.elapsedRealtime();
final Instant responseTime = requestTime.plusMillis(responseTicks - requestTicks);
- final long responseTimestamp = responseTime.toEpochMilli();
+ final Timestamp64 responseTimestamp = Timestamp64.fromInstant(responseTime);
// extract the results
final byte leap = (byte) ((buffer[0] >> 6) & 0x3);
final byte mode = (byte) (buffer[0] & 0x7);
final int stratum = (int) (buffer[1] & 0xff);
- final long originateTimestamp = readTimeStamp(buffer, ORIGINATE_TIME_OFFSET);
- final long receiveTimestamp = readTimeStamp(buffer, RECEIVE_TIME_OFFSET);
- final long transmitTimestamp = readTimeStamp(buffer, TRANSMIT_TIME_OFFSET);
- final long referenceTimestamp = readTimeStamp(buffer, REFERENCE_TIME_OFFSET);
+ final Timestamp64 referenceTimestamp = readTimeStamp(buffer, REFERENCE_TIME_OFFSET);
+ final Timestamp64 originateTimestamp = readTimeStamp(buffer, ORIGINATE_TIME_OFFSET);
+ final Timestamp64 receiveTimestamp = readTimeStamp(buffer, RECEIVE_TIME_OFFSET);
+ final Timestamp64 transmitTimestamp = readTimeStamp(buffer, TRANSMIT_TIME_OFFSET);
/* Do validation according to RFC */
- // TODO: validate originateTime == requestTime.
- checkValidServerReply(leap, mode, stratum, transmitTimestamp, referenceTimestamp);
+ checkValidServerReply(leap, mode, stratum, transmitTimestamp, referenceTimestamp,
+ randomizedRequestTimestamp, originateTimestamp);
- long roundTripTimeMillis = responseTicks - requestTicks
- - (transmitTimestamp - receiveTimestamp);
+ long totalTransactionDurationMillis = responseTicks - requestTicks;
+ long serverDurationMillis =
+ Duration64.between(receiveTimestamp, transmitTimestamp).toDuration().toMillis();
+ long roundTripTimeMillis = totalTransactionDurationMillis - serverDurationMillis;
Duration clockOffsetDuration = calculateClockOffset(requestTimestamp,
receiveTimestamp, transmitTimestamp, responseTimestamp);
@@ -207,20 +215,24 @@
/** Performs the NTP clock offset calculation. */
@VisibleForTesting
- public static Duration calculateClockOffset(long clientRequestTimestamp,
- long serverReceiveTimestamp, long serverTransmitTimestamp,
- long clientResponseTimestamp) {
- // receiveTime = originateTime + transit + skew
- // responseTime = transmitTime + transit - skew
- // clockOffset = ((receiveTime - originateTime) + (transmitTime - responseTime))/2
- // = ((originateTime + transit + skew - originateTime) +
- // (transmitTime - (transmitTime + transit - skew)))/2
- // = ((transit + skew) + (transmitTime - transmitTime - transit + skew))/2
- // = (transit + skew - transit + skew)/2
- // = (2 * skew)/2 = skew
- long clockOffsetMillis = ((serverReceiveTimestamp - clientRequestTimestamp)
- + (serverTransmitTimestamp - clientResponseTimestamp)) / 2;
- return Duration.ofMillis(clockOffsetMillis);
+ public static Duration calculateClockOffset(Timestamp64 clientRequestTimestamp,
+ Timestamp64 serverReceiveTimestamp, Timestamp64 serverTransmitTimestamp,
+ Timestamp64 clientResponseTimestamp) {
+ // According to RFC4330:
+ // t is the system clock offset (the adjustment we are trying to find)
+ // t = ((T2 - T1) + (T3 - T4)) / 2
+ //
+ // Which is:
+ // t = (([server]receiveTimestamp - [client]requestTimestamp)
+ // + ([server]transmitTimestamp - [client]responseTimestamp)) / 2
+ //
+ // See the NTP spec and tests: the numeric types used are deliberate:
+ // + Duration64.between() uses 64-bit arithmetic (32-bit for the seconds).
+ // + plus() / dividedBy() use Duration, which isn't the double precision floating point
+ // used in NTPv4, but is good enough.
+ return Duration64.between(clientRequestTimestamp, serverReceiveTimestamp)
+ .plus(Duration64.between(clientResponseTimestamp, serverTransmitTimestamp))
+ .dividedBy(2);
}
@Deprecated
@@ -270,8 +282,9 @@
}
private static void checkValidServerReply(
- byte leap, byte mode, int stratum, long transmitTime, long referenceTime)
- throws InvalidServerReplyException {
+ byte leap, byte mode, int stratum, Timestamp64 transmitTimestamp,
+ Timestamp64 referenceTimestamp, Timestamp64 randomizedRequestTimestamp,
+ Timestamp64 originateTimestamp) throws InvalidServerReplyException {
if (leap == NTP_LEAP_NOSYNC) {
throw new InvalidServerReplyException("unsynchronized server");
}
@@ -281,73 +294,68 @@
if ((stratum == NTP_STRATUM_DEATH) || (stratum > NTP_STRATUM_MAX)) {
throw new InvalidServerReplyException("untrusted stratum: " + stratum);
}
- if (transmitTime == 0) {
- throw new InvalidServerReplyException("zero transmitTime");
+ if (!randomizedRequestTimestamp.equals(originateTimestamp)) {
+ throw new InvalidServerReplyException(
+ "originateTimestamp != randomizedRequestTimestamp");
}
- if (referenceTime == 0) {
- throw new InvalidServerReplyException("zero reference timestamp");
+ if (transmitTimestamp.equals(Timestamp64.ZERO)) {
+ throw new InvalidServerReplyException("zero transmitTimestamp");
+ }
+ if (referenceTimestamp.equals(Timestamp64.ZERO)) {
+ throw new InvalidServerReplyException("zero referenceTimestamp");
}
}
/**
* Reads an unsigned 32 bit big endian number from the given offset in the buffer.
*/
- private long read32(byte[] buffer, int offset) {
- byte b0 = buffer[offset];
- byte b1 = buffer[offset+1];
- byte b2 = buffer[offset+2];
- byte b3 = buffer[offset+3];
+ private long readUnsigned32(byte[] buffer, int offset) {
+ int i0 = buffer[offset++] & 0xFF;
+ int i1 = buffer[offset++] & 0xFF;
+ int i2 = buffer[offset++] & 0xFF;
+ int i3 = buffer[offset] & 0xFF;
- // convert signed bytes to unsigned values
- int i0 = ((b0 & 0x80) == 0x80 ? (b0 & 0x7F) + 0x80 : b0);
- int i1 = ((b1 & 0x80) == 0x80 ? (b1 & 0x7F) + 0x80 : b1);
- int i2 = ((b2 & 0x80) == 0x80 ? (b2 & 0x7F) + 0x80 : b2);
- int i3 = ((b3 & 0x80) == 0x80 ? (b3 & 0x7F) + 0x80 : b3);
-
- return ((long)i0 << 24) + ((long)i1 << 16) + ((long)i2 << 8) + (long)i3;
+ int bits = (i0 << 24) | (i1 << 16) | (i2 << 8) | i3;
+ return bits & 0xFFFF_FFFFL;
}
/**
- * Reads the NTP time stamp at the given offset in the buffer and returns
- * it as a system time (milliseconds since January 1, 1970).
+ * Reads the NTP time stamp from the given offset in the buffer.
*/
- private long readTimeStamp(byte[] buffer, int offset) {
- long seconds = read32(buffer, offset);
- long fraction = read32(buffer, offset + 4);
- // Special case: zero means zero.
- if (seconds == 0 && fraction == 0) {
- return 0;
- }
- return ((seconds - OFFSET_1900_TO_1970) * 1000) + ((fraction * 1000L) / 0x100000000L);
+ private Timestamp64 readTimeStamp(byte[] buffer, int offset) {
+ long seconds = readUnsigned32(buffer, offset);
+ int fractionBits = (int) readUnsigned32(buffer, offset + 4);
+ return Timestamp64.fromComponents(seconds, fractionBits);
}
/**
- * Writes system time (milliseconds since January 1, 1970) as an NTP time stamp
- * at the given offset in the buffer.
+ * Writes the NTP time stamp at the given offset in the buffer.
*/
- private void writeTimeStamp(byte[] buffer, int offset, long time) {
- // Special case: zero means zero.
- if (time == 0) {
- Arrays.fill(buffer, offset, offset + 8, (byte) 0x00);
- return;
- }
-
- long seconds = time / 1000L;
- long milliseconds = time - seconds * 1000L;
- seconds += OFFSET_1900_TO_1970;
-
+ private void writeTimeStamp(byte[] buffer, int offset, Timestamp64 timestamp) {
+ long seconds = timestamp.getEraSeconds();
// write seconds in big endian format
- buffer[offset++] = (byte)(seconds >> 24);
- buffer[offset++] = (byte)(seconds >> 16);
- buffer[offset++] = (byte)(seconds >> 8);
- buffer[offset++] = (byte)(seconds >> 0);
+ buffer[offset++] = (byte) (seconds >>> 24);
+ buffer[offset++] = (byte) (seconds >>> 16);
+ buffer[offset++] = (byte) (seconds >>> 8);
+ buffer[offset++] = (byte) (seconds);
- long fraction = milliseconds * 0x100000000L / 1000L;
+ int fractionBits = timestamp.getFractionBits();
// write fraction in big endian format
- buffer[offset++] = (byte)(fraction >> 24);
- buffer[offset++] = (byte)(fraction >> 16);
- buffer[offset++] = (byte)(fraction >> 8);
- // low order bits should be random data
- buffer[offset++] = (byte)(Math.random() * 255.0);
+ buffer[offset++] = (byte) (fractionBits >>> 24);
+ buffer[offset++] = (byte) (fractionBits >>> 16);
+ buffer[offset++] = (byte) (fractionBits >>> 8);
+ buffer[offset] = (byte) (fractionBits);
+ }
+
+ private static Random defaultRandom() {
+ Random random;
+ try {
+ random = SecureRandom.getInstanceStrong();
+ } catch (NoSuchAlgorithmException e) {
+ // This should never happen.
+ Slog.wtf(TAG, "Unable to access SecureRandom", e);
+ random = new Random(System.currentTimeMillis());
+ }
+ return random;
}
}
diff --git a/core/tests/coretests/src/android/net/SntpClientTest.java b/core/tests/coretests/src/android/net/SntpClientTest.java
index 178cd02..b400b9b 100644
--- a/core/tests/coretests/src/android/net/SntpClientTest.java
+++ b/core/tests/coretests/src/android/net/SntpClientTest.java
@@ -46,6 +46,7 @@
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.Arrays;
+import java.util.Random;
import java.util.function.Supplier;
@RunWith(AndroidJUnit4.class)
@@ -134,6 +135,7 @@
private SntpClient mClient;
private Network mNetwork;
private Supplier<Instant> mSystemTimeSupplier;
+ private Random mRandom;
@SuppressWarnings("unchecked")
@Before
@@ -143,9 +145,13 @@
// A mock network has NETID_UNSET, which allows the test to run, with a loopback server,
// even w/o external networking.
mNetwork = mock(Network.class, CALLS_REAL_METHODS);
+ mRandom = mock(Random.class);
mSystemTimeSupplier = mock(Supplier.class);
- mClient = new SntpClient(mSystemTimeSupplier);
+ // Returning zero means the "randomized" bottom bits of the client's transmit timestamp /
+ // server's originate timestamp will be zeros.
+ when(mRandom.nextInt()).thenReturn(0);
+ mClient = new SntpClient(mSystemTimeSupplier, mRandom);
}
/** Tests when the client and server are in ERA0. b/199481251. */
@@ -258,14 +264,14 @@
long simulatedClientElapsedTimeMillis = totalElapsedTimeMillis;
// Create some symmetrical timestamps.
- long clientRequestTimestamp =
- clientTime.minusMillis(simulatedClientElapsedTimeMillis / 2).toEpochMilli();
- long clientResponseTimestamp =
- clientTime.plusMillis(simulatedClientElapsedTimeMillis / 2).toEpochMilli();
- long serverReceiveTimestamp =
- serverTime.minusMillis(simulatedServerElapsedTimeMillis / 2).toEpochMilli();
- long serverTransmitTimestamp =
- serverTime.plusMillis(simulatedServerElapsedTimeMillis / 2).toEpochMilli();
+ Timestamp64 clientRequestTimestamp = Timestamp64.fromInstant(
+ clientTime.minusMillis(simulatedClientElapsedTimeMillis / 2));
+ Timestamp64 clientResponseTimestamp = Timestamp64.fromInstant(
+ clientTime.plusMillis(simulatedClientElapsedTimeMillis / 2));
+ Timestamp64 serverReceiveTimestamp = Timestamp64.fromInstant(
+ serverTime.minusMillis(simulatedServerElapsedTimeMillis / 2));
+ Timestamp64 serverTransmitTimestamp = Timestamp64.fromInstant(
+ serverTime.plusMillis(simulatedServerElapsedTimeMillis / 2));
Duration actualOffset = SntpClient.calculateClockOffset(
clientRequestTimestamp, serverReceiveTimestamp,
diff --git a/core/tests/coretests/src/android/net/sntp/PredictableRandom.java b/core/tests/coretests/src/android/net/sntp/PredictableRandom.java
new file mode 100644
index 0000000..bb2922b
--- /dev/null
+++ b/core/tests/coretests/src/android/net/sntp/PredictableRandom.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.sntp;
+
+import java.util.Random;
+
+class PredictableRandom extends Random {
+ private int[] mIntSequence = new int[] { 1 };
+ private int mIntPos = 0;
+
+ public void setIntSequence(int[] intSequence) {
+ this.mIntSequence = intSequence;
+ }
+
+ @Override
+ public int nextInt() {
+ int value = mIntSequence[mIntPos++];
+ mIntPos %= mIntSequence.length;
+ return value;
+ }
+}
diff --git a/core/tests/coretests/src/android/net/sntp/Timestamp64Test.java b/core/tests/coretests/src/android/net/sntp/Timestamp64Test.java
index 7e945e5..c923812 100644
--- a/core/tests/coretests/src/android/net/sntp/Timestamp64Test.java
+++ b/core/tests/coretests/src/android/net/sntp/Timestamp64Test.java
@@ -24,6 +24,9 @@
import org.junit.Test;
import java.time.Instant;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Set;
public class Timestamp64Test {
@@ -205,6 +208,96 @@
actualNanos == expectedNanos || actualNanos == expectedNanos - 1);
}
+ @Test
+ public void testMillisRandomizationConstant() {
+ // Mathematically, we can say that to represent 1000 different values, we need 10 binary
+ // digits (2^10 = 1024). The same is true whether we're dealing with integers or fractions.
+ // Unfortunately, for fractions those 1024 values do not correspond to discrete decimal
+ // values. Discrete millisecond values as fractions (e.g. 0.001 - 0.999) cannot be
+ // represented exactly except where the value can also be represented as some combination of
+        // negative powers of 2. When we convert back and forth, we truncate, so millisecond decimal
+ // fraction N represented as a binary fraction will always be equal to or lower than N. If
+ // we are truncating correctly it will never be as low as (N-0.001). N -> [N-0.001, N].
+
+ // We need to keep 10 bits to hold millis (inaccurately, since there are numbers that
+ // cannot be represented exactly), leaving us able to randomize the remaining 22 bits of the
+ // fraction part without significantly affecting the number represented.
+ assertEquals(22, Timestamp64.SUB_MILLIS_BITS_TO_RANDOMIZE);
+
+ // Brute force proof that randomization logic will keep the timestamp within the range
+        // [N-0.001, N] where N is in milliseconds.
+ int smallFractionRandomizedLow = 0;
+ int smallFractionRandomizedHigh = 0b00000000_00111111_11111111_11111111;
+ int largeFractionRandomizedLow = 0b11111111_11000000_00000000_00000000;
+ int largeFractionRandomizedHigh = 0b11111111_11111111_11111111_11111111;
+
+ long smallLowNanos = Timestamp64.fromComponents(
+ 0, smallFractionRandomizedLow).toInstant(0).getNano();
+ long smallHighNanos = Timestamp64.fromComponents(
+ 0, smallFractionRandomizedHigh).toInstant(0).getNano();
+ long smallDelta = smallHighNanos - smallLowNanos;
+ long millisInNanos = 1_000_000_000 / 1_000;
+ assertTrue(smallDelta >= 0 && smallDelta < millisInNanos);
+
+ long largeLowNanos = Timestamp64.fromComponents(
+ 0, largeFractionRandomizedLow).toInstant(0).getNano();
+ long largeHighNanos = Timestamp64.fromComponents(
+ 0, largeFractionRandomizedHigh).toInstant(0).getNano();
+ long largeDelta = largeHighNanos - largeLowNanos;
+ assertTrue(largeDelta >= 0 && largeDelta < millisInNanos);
+
+ PredictableRandom random = new PredictableRandom();
+ random.setIntSequence(new int[] { 0xFFFF_FFFF });
+ Timestamp64 zero = Timestamp64.fromComponents(0, 0);
+ Timestamp64 zeroWithFractionRandomized = zero.randomizeSubMillis(random);
+ assertEquals(zero.getEraSeconds(), zeroWithFractionRandomized.getEraSeconds());
+ assertEquals(smallFractionRandomizedHigh, zeroWithFractionRandomized.getFractionBits());
+ }
+
+ @Test
+ public void testRandomizeLowestBits() {
+ Random random = new Random(1);
+ {
+ int fractionBits = 0;
+ expectIllegalArgumentException(
+ () -> Timestamp64.randomizeLowestBits(random, fractionBits, -1));
+ expectIllegalArgumentException(
+ () -> Timestamp64.randomizeLowestBits(random, fractionBits, 0));
+ expectIllegalArgumentException(
+ () -> Timestamp64.randomizeLowestBits(random, fractionBits, Integer.SIZE));
+ expectIllegalArgumentException(
+ () -> Timestamp64.randomizeLowestBits(random, fractionBits, Integer.SIZE + 1));
+ }
+
+ // Check the behavior looks correct from a probabilistic point of view.
+ for (int input : new int[] { 0, 0xFFFFFFFF }) {
+ for (int bitCount = 1; bitCount < Integer.SIZE; bitCount++) {
+ int upperBitMask = 0xFFFFFFFF << bitCount;
+ int expectedUpperBits = input & upperBitMask;
+
+ Set<Integer> values = new HashSet<>();
+ values.add(input);
+
+ int trials = 100;
+ for (int i = 0; i < trials; i++) {
+ int outputFractionBits =
+ Timestamp64.randomizeLowestBits(random, input, bitCount);
+
+ // Record the output value for later analysis.
+ values.add(outputFractionBits);
+
+ // Check upper bits did not change.
+ assertEquals(expectedUpperBits, outputFractionBits & upperBitMask);
+ }
+
+ // It's possible to be more rigorous here, perhaps with a histogram. As bitCount
+                // rises, values.size() quickly trends towards the value of trials + 1. For now,
+                // this mostly just guards against a no-op implementation.
+ assertTrue(bitCount + ":" + values.size(), values.size() > 1);
+ }
+ }
+ }
+
private static void expectIllegalArgumentException(Runnable r) {
try {
r.run();