Implement Socket::receive<T> and refactor Socket::receiveAck

Bug: 162032964
Bug: 161389935
Test: canhalctrl up test virtual vcan3
Change-Id: I8bd351cec0d484ee4be8a40908476194958afcb1
diff --git a/automotive/can/1.0/default/libnetdevice/can.cpp b/automotive/can/1.0/default/libnetdevice/can.cpp
index ab107fd..5a1105c 100644
--- a/automotive/can/1.0/default/libnetdevice/can.cpp
+++ b/automotive/can/1.0/default/libnetdevice/can.cpp
@@ -91,7 +91,7 @@
     }
 
     nl::Socket sock(NETLINK_ROUTE);
-    return sock.send(req) && sock.receiveAck();
+    return sock.send(req) && sock.receiveAck(req);
 }
 
 }  // namespace android::netdevice::can
diff --git a/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp b/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
index ed2a51e..e2ba2cb 100644
--- a/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
+++ b/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
@@ -72,7 +72,7 @@
     }
 
     nl::Socket sock(NETLINK_ROUTE);
-    return sock.send(req) && sock.receiveAck();
+    return sock.send(req) && sock.receiveAck(req);
 }
 
 bool del(std::string dev) {
@@ -80,7 +80,7 @@
     req.addattr(IFLA_IFNAME, dev);
 
     nl::Socket sock(NETLINK_ROUTE);
-    return sock.send(req) && sock.receiveAck();
+    return sock.send(req) && sock.receiveAck(req);
 }
 
 std::optional<hwaddr_t> getHwAddr(const std::string& ifname) {
diff --git a/automotive/can/1.0/default/libnetdevice/vlan.cpp b/automotive/can/1.0/default/libnetdevice/vlan.cpp
index 3f904f0..33dc029 100644
--- a/automotive/can/1.0/default/libnetdevice/vlan.cpp
+++ b/automotive/can/1.0/default/libnetdevice/vlan.cpp
@@ -49,7 +49,7 @@
     }
 
     nl::Socket sock(NETLINK_ROUTE);
-    return sock.send(req) && sock.receiveAck();
+    return sock.send(req) && sock.receiveAck(req);
 }
 
 }  // namespace android::netdevice::vlan
diff --git a/automotive/can/1.0/default/libnl++/Socket.cpp b/automotive/can/1.0/default/libnl++/Socket.cpp
index 56e990c..1a34df8 100644
--- a/automotive/can/1.0/default/libnl++/Socket.cpp
+++ b/automotive/can/1.0/default/libnl++/Socket.cpp
@@ -103,60 +103,45 @@
     return {msg, sa};
 }
 
-/* TODO(161389935): Migrate receiveAck to use nlmsg<> internally. Possibly reuse
- * Socket::receive(). */
-bool Socket::receiveAck() {
-    if (mFailed) return false;
+bool Socket::receiveAck(uint32_t seq) {
+    const auto nlerr = receive<nlmsgerr>({NLMSG_ERROR});
+    if (!nlerr.has_value()) return false;
 
-    char buf[8192];
-
-    sockaddr_nl sa;
-    iovec iov = {buf, sizeof(buf)};
-
-    msghdr msg = {};
-    msg.msg_name = &sa;
-    msg.msg_namelen = sizeof(sa);
-    msg.msg_iov = &iov;
-    msg.msg_iovlen = 1;
-
-    const ssize_t status = recvmsg(mFd.get(), &msg, 0);
-    if (status < 0) {
-        PLOG(ERROR) << "Failed to receive Netlink message";
-        return false;
-    }
-    size_t remainingLen = status;
-
-    if (msg.msg_flags & MSG_TRUNC) {
-        LOG(ERROR) << "Failed to receive Netlink message: truncated";
+    if (nlerr->data.msg.nlmsg_seq != seq) {
+        LOG(ERROR) << "Received ACK for a different message (" << nlerr->data.msg.nlmsg_seq
+                   << ", expected " << seq << "). Multi-message tracking is not implemented.";
         return false;
     }
 
-    for (auto nlmsg = reinterpret_cast<nlmsghdr*>(buf); NLMSG_OK(nlmsg, remainingLen);
-         nlmsg = NLMSG_NEXT(nlmsg, remainingLen)) {
-        if constexpr (kSuperVerbose) {
-            LOG(VERBOSE) << "received Netlink response: "
-                         << toString({nlmsg, nlmsg->nlmsg_len}, mProtocol);
-        }
+    if (nlerr->data.error == 0) return true;
 
-        // We're looking for error/ack message only, ignoring others.
-        if (nlmsg->nlmsg_type != NLMSG_ERROR) {
-            LOG(WARNING) << "Received unexpected Netlink message (ignored): " << nlmsg->nlmsg_type;
-            continue;
-        }
-
-        // Found error/ack message, return status.
-        const auto nlerr = reinterpret_cast<nlmsgerr*>(NLMSG_DATA(nlmsg));
-        if (nlerr->error != 0) {
-            LOG(ERROR) << "Received Netlink error message: " << strerror(-nlerr->error);
-            return false;
-        }
-        return true;
-    }
-    // Couldn't find any error/ack messages.
+    LOG(WARNING) << "Received Netlink error message: " << strerror(-nlerr->data.error);
     return false;
 }
 
+std::optional<Buffer<nlmsghdr>> Socket::receive(const std::set<nlmsgtype_t>& msgtypes,
+                                                size_t maxSize) {
+    while (!mFailed) {
+        const auto msgBuf = receive(maxSize);
+        if (!msgBuf.has_value()) return std::nullopt;
+
+        for (const auto rawMsg : *msgBuf) {
+            if (msgtypes.count(rawMsg->nlmsg_type) == 0) {
+                LOG(WARNING) << "Received (and ignored) unexpected Netlink message of type "
+                             << rawMsg->nlmsg_type;
+                continue;
+            }
+
+            return rawMsg;
+        }
+    }
+
+    return std::nullopt;
+}
+
 std::optional<unsigned> Socket::getPid() {
+    if (mFailed) return std::nullopt;
+
     sockaddr_nl sa = {};
     socklen_t sasize = sizeof(sa);
     if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
diff --git a/automotive/can/1.0/default/libnl++/include/libnl++/Message.h b/automotive/can/1.0/default/libnl++/include/libnl++/Message.h
index 2b84a86..50b3c4b 100644
--- a/automotive/can/1.0/default/libnl++/include/libnl++/Message.h
+++ b/automotive/can/1.0/default/libnl++/include/libnl++/Message.h
@@ -19,6 +19,8 @@
 #include <libnl++/Attributes.h>
 #include <libnl++/Buffer.h>
 
+#include <set>
+
 namespace android::nl {
 
 /**
@@ -60,7 +62,8 @@
      * \return Parsed message or nullopt, if the buffer data is invalid or message type
      *         doesn't match.
      */
-    static std::optional<Message<T>> parse(Buffer<nlmsghdr> buf, std::set<nlmsgtype_t> msgtypes) {
+    static std::optional<Message<T>> parse(Buffer<nlmsghdr> buf,
+                                           const std::set<nlmsgtype_t>& msgtypes) {
         const auto& [nlOk, nlHeader] = buf.getFirst();  // we're doing it twice, but it's fine
         if (!nlOk) return std::nullopt;
 
diff --git a/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h b/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
index bc6ad9d..16b63f5 100644
--- a/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
+++ b/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
@@ -19,11 +19,13 @@
 #include <android-base/macros.h>
 #include <android-base/unique_fd.h>
 #include <libnl++/Buffer.h>
+#include <libnl++/Message.h>
 #include <libnl++/MessageFactory.h>
 
 #include <linux/netlink.h>
 
 #include <optional>
+#include <set>
 #include <vector>
 
 namespace android::nl {
@@ -57,7 +59,7 @@
      * \param msg Message to send. Its sequence number will be updated.
      * \return true, if succeeded.
      */
-    template <class T, unsigned BUFSIZE>
+    template <typename T, unsigned BUFSIZE>
     bool send(MessageFactory<T, BUFSIZE>& req) {
         sockaddr_nl sa = {};
         sa.nl_family = AF_NETLINK;
@@ -72,7 +74,7 @@
      * \param sa Destination address.
      * \return true, if succeeded.
      */
-    template <class T, unsigned BUFSIZE>
+    template <typename T, unsigned BUFSIZE>
     bool send(MessageFactory<T, BUFSIZE>& req, const sockaddr_nl& sa) {
         if (!req.isGood()) return false;
 
@@ -109,7 +111,7 @@
      * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next
      * call to the read function or until deallocation of Socket instance.
      *
-     * \param maxSize Maximum total size of received messages
+     * \param maxSize Maximum total size of received messages.
      * \return A pair (for use with structured binding) containing:
      *         - buffer view with message data, std::nullopt on error;
      *         - sender process address.
@@ -118,27 +120,70 @@
             size_t maxSize = defaultReceiveSize);
 
     /**
-     * Receive Netlink ACK message from Kernel.
+     * Receive matching Netlink message of a given payload type.
      *
-     * \return true if received ACK message, false in case of error
+     * This method should be used if the caller expects exactly one incoming message of exactly
+     * given type (such as ACK). If there is a use case to handle multiple types of messages,
+     * please use receive(size_t) directly and iterate through potential multipart messages.
+     *
+     * If this method is used in such an environment, it will only return the first matching message
+     * from multipart packet and will issue warnings on messages that do not match.
+     *
+     * \param msgtypes Expected message types (such as NLMSG_ERROR).
+     * \param maxSize Maximum total size of received messages.
+     * \return Parsed message or std::nullopt in case of error.
      */
-    bool receiveAck();
+    template <typename T>
+    std::optional<Message<T>> receive(const std::set<nlmsgtype_t>& msgtypes,
+                                      size_t maxSize = defaultReceiveSize) {
+        const auto msg = receive(msgtypes, maxSize);
+        if (!msg.has_value()) return std::nullopt;
+
+        const auto parsed = Message<T>::parse(*msg);
+        if (!parsed.has_value()) {
+            LOG(WARNING) << "Received matching Netlink message, but couldn't parse it";
+            return std::nullopt;
+        }
+
+        return parsed;
+    }
+
+    /**
+     * Receive Netlink ACK message.
+     *
+     * \param req Message to match sequence number against.
+     * \return true if received ACK message, false in case of error.
+     */
+    template <typename T, unsigned BUFSIZE>
+    bool receiveAck(MessageFactory<T, BUFSIZE>& req) {
+        return receiveAck(req.header()->nlmsg_seq);
+    }
+
+    /**
+     * Receive Netlink ACK message.
+     *
+     * \param seq Sequence number of message to ACK.
+     * \return true if received ACK message, false in case of error.
+     */
+    bool receiveAck(uint32_t seq);
 
     /**
      * Fetches the socket PID.
      *
-     * \return PID that socket is bound to.
+     * \return PID that socket is bound to or std::nullopt.
      */
     std::optional<unsigned> getPid();
 
   private:
     const int mProtocol;
-
-    uint32_t mSeq = 0;
     base::unique_fd mFd;
-    bool mFailed = false;
     std::vector<uint8_t> mReceiveBuffer;
 
+    bool mFailed = false;
+    uint32_t mSeq = 0;
+
+    std::optional<Buffer<nlmsghdr>> receive(const std::set<nlmsgtype_t>& msgtypes, size_t maxSize);
+
     DISALLOW_COPY_AND_ASSIGN(Socket);
 };