adb: implement zstd compression for file sync.

Bug: http://b/150827486
Test: test_device.py
Change-Id: I9fac4c760d9dbdce0b3b883db975cfa9b27a9e80
diff --git a/adb/Android.bp b/adb/Android.bp
index 9db151d..8747182 100644
--- a/adb/Android.bp
+++ b/adb/Android.bp
@@ -124,11 +124,12 @@
         "libadbd_core",
         "libadbconnection_server",
         "libasyncio",
+        "libbase",
         "libbrotli",
         "libcutils_sockets",
         "libdiagnose_usb",
         "libmdnssd",
-        "libbase",
+        "libzstd",
 
         "libadb_protos",
         "libapp_processes_protos_lite",
@@ -351,6 +352,7 @@
         "liblog",
         "libziparchive",
         "libz",
+        "libzstd",
     ],
 
     // Don't add anything here, we don't want additional shared dependencies
@@ -483,6 +485,7 @@
         "libbrotli",
         "libdiagnose_usb",
         "liblz4",
+        "libzstd",
     ],
 
     shared_libs: [
@@ -586,6 +589,7 @@
         "libdiagnose_usb",
         "liblz4",
         "libmdnssd",
+        "libzstd",
     ],
 
     visibility: [
diff --git a/adb/client/commandline.cpp b/adb/client/commandline.cpp
index eaa32e5..43772ba 100644
--- a/adb/client/commandline.cpp
+++ b/adb/client/commandline.cpp
@@ -1336,6 +1336,8 @@
         return CompressionType::Brotli;
     } else if (str == "lz4") {
         return CompressionType::LZ4;
+    } else if (str == "zstd") {
+        return CompressionType::Zstd;
     }
 
     error_exit("unexpected compression type %s", str.c_str());
diff --git a/adb/client/file_sync_client.cpp b/adb/client/file_sync_client.cpp
index 7185939..8bbe2a8 100644
--- a/adb/client/file_sync_client.cpp
+++ b/adb/client/file_sync_client.cpp
@@ -240,6 +240,7 @@
             have_sendrecv_v2_ = CanUseFeature(*features, kFeatureSendRecv2);
             have_sendrecv_v2_brotli_ = CanUseFeature(*features, kFeatureSendRecv2Brotli);
             have_sendrecv_v2_lz4_ = CanUseFeature(*features, kFeatureSendRecv2LZ4);
+            have_sendrecv_v2_zstd_ = CanUseFeature(*features, kFeatureSendRecv2Zstd);
             have_sendrecv_v2_dry_run_send_ = CanUseFeature(*features, kFeatureSendRecv2DryRunSend);
             std::string error;
             fd.reset(adb_connect("sync:", &error));
@@ -268,13 +269,16 @@
     bool HaveSendRecv2() const { return have_sendrecv_v2_; }
     bool HaveSendRecv2Brotli() const { return have_sendrecv_v2_brotli_; }
     bool HaveSendRecv2LZ4() const { return have_sendrecv_v2_lz4_; }
+    bool HaveSendRecv2Zstd() const { return have_sendrecv_v2_zstd_; }
     bool HaveSendRecv2DryRunSend() const { return have_sendrecv_v2_dry_run_send_; }
 
     // Resolve a compression type which might be CompressionType::Any to a specific compression
     // algorithm.
     CompressionType ResolveCompressionType(CompressionType compression) const {
         if (compression == CompressionType::Any) {
-            if (HaveSendRecv2LZ4()) {
+            if (HaveSendRecv2Zstd()) {
+                return CompressionType::Zstd;
+            } else if (HaveSendRecv2LZ4()) {
                 return CompressionType::LZ4;
             } else if (HaveSendRecv2Brotli()) {
                 return CompressionType::Brotli;
@@ -374,6 +378,10 @@
                 msg.send_v2_setup.flags = kSyncFlagLZ4;
                 break;
 
+            case CompressionType::Zstd:
+                msg.send_v2_setup.flags = kSyncFlagZstd;
+                break;
+
             case CompressionType::Any:
                 LOG(FATAL) << "unexpected CompressionType::Any";
         }
@@ -421,6 +429,10 @@
                 msg.recv_v2_setup.flags |= kSyncFlagLZ4;
                 break;
 
+            case CompressionType::Zstd:
+                msg.recv_v2_setup.flags |= kSyncFlagZstd;
+                break;
+
             case CompressionType::Any:
                 LOG(FATAL) << "unexpected CompressionType::Any";
         }
@@ -631,7 +643,8 @@
         syncsendbuf sbuf;
         sbuf.id = ID_DATA;
 
-        std::variant<std::monostate, NullEncoder, BrotliEncoder, LZ4Encoder> encoder_storage;
+        std::variant<std::monostate, NullEncoder, BrotliEncoder, LZ4Encoder, ZstdEncoder>
+                encoder_storage;
         Encoder* encoder = nullptr;
         switch (compression) {
             case CompressionType::None:
@@ -646,6 +659,10 @@
                 encoder = &encoder_storage.emplace<LZ4Encoder>(SYNC_DATA_MAX);
                 break;
 
+            case CompressionType::Zstd:
+                encoder = &encoder_storage.emplace<ZstdEncoder>(SYNC_DATA_MAX);
+                break;
+
             case CompressionType::Any:
                 LOG(FATAL) << "unexpected CompressionType::Any";
         }
@@ -928,6 +945,7 @@
     bool have_sendrecv_v2_;
     bool have_sendrecv_v2_brotli_;
     bool have_sendrecv_v2_lz4_;
+    bool have_sendrecv_v2_zstd_;
     bool have_sendrecv_v2_dry_run_send_;
 
     TransferLedger global_ledger_;
@@ -1133,7 +1151,8 @@
     uint64_t bytes_copied = 0;
 
     Block buffer(SYNC_DATA_MAX);
-    std::variant<std::monostate, NullDecoder, BrotliDecoder, LZ4Decoder> decoder_storage;
+    std::variant<std::monostate, NullDecoder, BrotliDecoder, LZ4Decoder, ZstdDecoder>
+            decoder_storage;
     Decoder* decoder = nullptr;
 
     std::span buffer_span(buffer.data(), buffer.size());
@@ -1150,6 +1169,10 @@
             decoder = &decoder_storage.emplace<LZ4Decoder>(buffer_span);
             break;
 
+        case CompressionType::Zstd:
+            decoder = &decoder_storage.emplace<ZstdDecoder>(buffer_span);
+            break;
+
         case CompressionType::Any:
             LOG(FATAL) << "unexpected CompressionType::Any";
     }
diff --git a/adb/compression_utils.h b/adb/compression_utils.h
index a0c48a2..a747108 100644
--- a/adb/compression_utils.h
+++ b/adb/compression_utils.h
@@ -25,6 +25,7 @@
 #include <brotli/decode.h>
 #include <brotli/encode.h>
 #include <lz4frame.h>
+#include <zstd.h>
 
 #include "types.h"
 
@@ -381,3 +382,105 @@
     std::unique_ptr<LZ4F_cctx, LZ4F_errorCode_t (*)(LZ4F_cctx*)> encoder_;
     IOVector output_buffer_;
 };
+
+struct ZstdDecoder final : public Decoder {  // Decoder backed by a zstd decompression stream.
+    explicit ZstdDecoder(std::span<char> output_buffer)  // output_buffer is forwarded to Decoder.
+        : Decoder(output_buffer), decoder_(ZSTD_createDStream(), ZSTD_freeDStream) {
+        if (!decoder_) {  // ZSTD_createDStream() returned null.
+            LOG(FATAL) << "failed to initialize Zstd decompression context";
+        }
+    }
+
+    DecodeResult Decode(std::span<char>* output) final {  // One decompression step; sets *output.
+        ZSTD_inBuffer in;
+        in.src = input_buffer_.front_data();  // Only the front of input_buffer_ is offered per call.
+        in.size = input_buffer_.front_size();
+        in.pos = 0;
+
+        ZSTD_outBuffer out;
+        out.dst = output_buffer_.data();  // Decompress into the span given at construction.
+        // The standard specifies size() as returning size_t, but our current version of
+        // libc++ returns a signed value instead.
+        out.size = static_cast<size_t>(output_buffer_.size());
+        out.pos = 0;
+
+        size_t rc = ZSTD_decompressStream(decoder_.get(), &out, &in);  // Advances in.pos and out.pos.
+        if (ZSTD_isError(rc)) {
+            LOG(ERROR) << "ZSTD_decompressStream failed: " << ZSTD_getErrorName(rc);
+            return DecodeResult::Error;
+        }
+
+        input_buffer_.drop_front(in.pos);  // Discard the bytes zstd consumed.
+        if (rc == 0) {  // Per the zstd API docs, 0 means the frame is fully decoded and flushed.
+            if (!input_buffer_.empty()) {  // Input left over after the end of the frame.
+                LOG(ERROR) << "Zstd stream hit end before reading all data";
+                return DecodeResult::Error;
+            }
+            zstd_done_ = true;
+        }
+
+        *output = std::span<char>(output_buffer_.data(), out.pos);  // Bytes produced by this call.
+
+        if (finished_) {  // Inherited flag; presumably set once all input was supplied (confirm).
+            return input_buffer_.empty() && zstd_done_ ? DecodeResult::Done
+                                                       : DecodeResult::MoreOutput;
+        }
+        return DecodeResult::NeedInput;  // Returned whenever finished_ is false, even with input left.
+    }
+
+  private:
+    bool zstd_done_ = false;  // Set once ZSTD_decompressStream has returned 0.
+    std::unique_ptr<ZSTD_DStream, size_t (*)(ZSTD_DStream*)> decoder_;  // Freed by ZSTD_freeDStream.
+};
+
+struct ZstdEncoder final : public Encoder {  // Encoder backed by a zstd compression stream.
+    explicit ZstdEncoder(size_t output_block_size)  // output_block_size is forwarded to Encoder.
+        : Encoder(output_block_size), encoder_(ZSTD_createCStream(), ZSTD_freeCStream) {
+        if (!encoder_) {  // ZSTD_createCStream() returned null.
+            LOG(FATAL) << "failed to initialize Zstd compression context";
+        }
+        ZSTD_CCtx_setParameter(encoder_.get(), ZSTD_c_compressionLevel, 1);  // Result is not checked.
+    }
+
+    EncodeResult Encode(Block* output) final {  // One compression step; fills and resizes *output.
+        ZSTD_inBuffer in;
+        in.src = input_buffer_.front_data();  // Only the front of input_buffer_ is offered per call.
+        in.size = input_buffer_.front_size();
+        in.pos = 0;
+
+        output->resize(output_block_size_);  // Full capacity for now; trimmed to out.pos below.
+
+        ZSTD_outBuffer out;
+        out.dst = output->data();
+        out.size = static_cast<size_t>(output->size());
+        out.pos = 0;
+
+        ZSTD_EndDirective end_directive = finished_ ? ZSTD_e_end : ZSTD_e_continue;
+        size_t rc = ZSTD_compressStream2(encoder_.get(), &out, &in, end_directive);
+        if (ZSTD_isError(rc)) {
+            LOG(ERROR) << "ZSTD_compressStream2 failed: " << ZSTD_getErrorName(rc);
+            return EncodeResult::Error;
+        }
+
+        input_buffer_.drop_front(in.pos);  // Discard the bytes zstd consumed.
+        output->resize(out.pos);  // Keep only the bytes zstd actually wrote.
+
+        if (rc == 0) {
+            // Zstd finished flushing its data.
+            if (finished_) {
+                if (!input_buffer_.empty()) {  // ZSTD_e_end completed but input is still buffered.
+                    LOG(ERROR) << "ZSTD_compressStream2 finished early";
+                    return EncodeResult::Error;
+                }
+                return EncodeResult::Done;
+            } else {
+                return input_buffer_.empty() ? EncodeResult::NeedInput : EncodeResult::MoreOutput;
+            }
+        } else {
+            return EncodeResult::MoreOutput;  // rc > 0: per the zstd API docs, data is left to flush.
+        }
+    }
+
+  private:
+    std::unique_ptr<ZSTD_CStream, size_t (*)(ZSTD_CStream*)> encoder_;  // Freed by ZSTD_freeCStream.
+};
diff --git a/adb/daemon/file_sync_service.cpp b/adb/daemon/file_sync_service.cpp
index d58131e..513b8dd 100644
--- a/adb/daemon/file_sync_service.cpp
+++ b/adb/daemon/file_sync_service.cpp
@@ -272,7 +272,8 @@
     syncmsg msg;
     Block buffer(SYNC_DATA_MAX);
     std::span<char> buffer_span(buffer.data(), buffer.size());
-    std::variant<std::monostate, NullDecoder, BrotliDecoder, LZ4Decoder> decoder_storage;
+    std::variant<std::monostate, NullDecoder, BrotliDecoder, LZ4Decoder, ZstdDecoder>
+            decoder_storage;
     Decoder* decoder = nullptr;
 
     switch (compression) {
@@ -288,6 +289,10 @@
             decoder = &decoder_storage.emplace<LZ4Decoder>(buffer_span);
             break;
 
+        case CompressionType::Zstd:
+            decoder = &decoder_storage.emplace<ZstdDecoder>(buffer_span);
+            break;
+
         case CompressionType::Any:
             LOG(FATAL) << "unexpected CompressionType::Any";
     }
@@ -590,6 +595,15 @@
         }
         compression = CompressionType::LZ4;
     }
+    if (msg.send_v2_setup.flags & kSyncFlagZstd) {
+        msg.send_v2_setup.flags &= ~kSyncFlagZstd;
+        if (compression) {
+            SendSyncFail(s, android::base::StringPrintf("multiple compression flags received: %d",
+                                                        orig_flags));
+            return false;
+        }
+        compression = CompressionType::Zstd;
+    }
     if (msg.send_v2_setup.flags & kSyncFlagDryRun) {
         msg.send_v2_setup.flags &= ~kSyncFlagDryRun;
         dry_run = true;
@@ -623,7 +637,8 @@
     syncmsg msg;
     msg.data.id = ID_DATA;
 
-    std::variant<std::monostate, NullEncoder, BrotliEncoder, LZ4Encoder> encoder_storage;
+    std::variant<std::monostate, NullEncoder, BrotliEncoder, LZ4Encoder, ZstdEncoder>
+            encoder_storage;
     Encoder* encoder;
 
     switch (compression) {
@@ -639,6 +654,10 @@
             encoder = &encoder_storage.emplace<LZ4Encoder>(SYNC_DATA_MAX);
             break;
 
+        case CompressionType::Zstd:
+            encoder = &encoder_storage.emplace<ZstdEncoder>(SYNC_DATA_MAX);
+            break;
+
         case CompressionType::Any:
             LOG(FATAL) << "unexpected CompressionType::Any";
     }
@@ -726,6 +745,15 @@
         }
         compression = CompressionType::LZ4;
     }
+    if (msg.recv_v2_setup.flags & kSyncFlagZstd) {
+        msg.recv_v2_setup.flags &= ~kSyncFlagZstd;
+        if (compression) {
+            SendSyncFail(s, android::base::StringPrintf("multiple compression flags received: %d",
+                                                        orig_flags));
+            return false;
+        }
+        compression = CompressionType::Zstd;
+    }
 
     if (msg.recv_v2_setup.flags) {
         SendSyncFail(s, android::base::StringPrintf("unknown flags: %d", msg.recv_v2_setup.flags));
diff --git a/adb/file_sync_protocol.h b/adb/file_sync_protocol.h
index 8f8f85f..5234c20 100644
--- a/adb/file_sync_protocol.h
+++ b/adb/file_sync_protocol.h
@@ -93,6 +93,7 @@
     kSyncFlagNone = 0,
     kSyncFlagBrotli = 1,
     kSyncFlagLZ4 = 2,
+    kSyncFlagZstd = 4,
     kSyncFlagDryRun = 0x8000'0000U,
 };
 
@@ -101,6 +102,7 @@
     Any,
     Brotli,
     LZ4,
+    Zstd,
 };
 
 // send_v1 sent the path in a buffer, followed by a comma and the mode as a string.
diff --git a/adb/test_device.py b/adb/test_device.py
index 9f1f403..c1caafc 100755
--- a/adb/test_device.py
+++ b/adb/test_device.py
@@ -1362,6 +1362,10 @@
     compression = "lz4"
 
 
+class FileOperationsTestZstd(FileOperationsTest.Base):  # Base subclass; only overrides compression.
+    compression = "zstd"
+
+
 class DeviceOfflineTest(DeviceTest):
     def _get_device_state(self, serialno):
         output = subprocess.check_output(self.device.adb_cmd + ['devices'])
diff --git a/adb/transport.cpp b/adb/transport.cpp
index 1667011..b6b6984 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -85,6 +85,7 @@
 const char* const kFeatureSendRecv2 = "sendrecv_v2";
 const char* const kFeatureSendRecv2Brotli = "sendrecv_v2_brotli";
 const char* const kFeatureSendRecv2LZ4 = "sendrecv_v2_lz4";
+const char* const kFeatureSendRecv2Zstd = "sendrecv_v2_zstd";
 const char* const kFeatureSendRecv2DryRunSend = "sendrecv_v2_dry_run_send";
 
 namespace {
@@ -1189,6 +1190,7 @@
                 kFeatureSendRecv2,
                 kFeatureSendRecv2Brotli,
                 kFeatureSendRecv2LZ4,
+                kFeatureSendRecv2Zstd,
                 kFeatureSendRecv2DryRunSend,
                 // Increment ADB_SERVER_VERSION when adding a feature that adbd needs
                 // to know about. Otherwise, the client can be stuck running an old
diff --git a/adb/transport.h b/adb/transport.h
index 2ac21cf..b1f2744 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -93,6 +93,8 @@
 extern const char* const kFeatureSendRecv2Brotli;
 // adbd supports LZ4 for send/recv v2.
 extern const char* const kFeatureSendRecv2LZ4;
+// adbd supports Zstd for send/recv v2.
+extern const char* const kFeatureSendRecv2Zstd;
 // adbd supports dry-run send for send/recv v2.
 extern const char* const kFeatureSendRecv2DryRunSend;