Implement Socket::receive<T> and refactor Socket::receiveAck
Bug: 162032964
Bug: 161389935
Test: canhalctrl up test virtual vcan3
Change-Id: I8bd351cec0d484ee4be8a40908476194958afcb1
diff --git a/automotive/can/1.0/default/libnetdevice/can.cpp b/automotive/can/1.0/default/libnetdevice/can.cpp
index ab107fd..5a1105c 100644
--- a/automotive/can/1.0/default/libnetdevice/can.cpp
+++ b/automotive/can/1.0/default/libnetdevice/can.cpp
@@ -91,7 +91,7 @@
}
nl::Socket sock(NETLINK_ROUTE);
- return sock.send(req) && sock.receiveAck();
+ return sock.send(req) && sock.receiveAck(req);
}
} // namespace android::netdevice::can
diff --git a/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp b/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
index ed2a51e..e2ba2cb 100644
--- a/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
+++ b/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
@@ -72,7 +72,7 @@
}
nl::Socket sock(NETLINK_ROUTE);
- return sock.send(req) && sock.receiveAck();
+ return sock.send(req) && sock.receiveAck(req);
}
bool del(std::string dev) {
@@ -80,7 +80,7 @@
req.addattr(IFLA_IFNAME, dev);
nl::Socket sock(NETLINK_ROUTE);
- return sock.send(req) && sock.receiveAck();
+ return sock.send(req) && sock.receiveAck(req);
}
std::optional<hwaddr_t> getHwAddr(const std::string& ifname) {
diff --git a/automotive/can/1.0/default/libnetdevice/vlan.cpp b/automotive/can/1.0/default/libnetdevice/vlan.cpp
index 3f904f0..33dc029 100644
--- a/automotive/can/1.0/default/libnetdevice/vlan.cpp
+++ b/automotive/can/1.0/default/libnetdevice/vlan.cpp
@@ -49,7 +49,7 @@
}
nl::Socket sock(NETLINK_ROUTE);
- return sock.send(req) && sock.receiveAck();
+ return sock.send(req) && sock.receiveAck(req);
}
} // namespace android::netdevice::vlan
diff --git a/automotive/can/1.0/default/libnl++/Socket.cpp b/automotive/can/1.0/default/libnl++/Socket.cpp
index 56e990c..1a34df8 100644
--- a/automotive/can/1.0/default/libnl++/Socket.cpp
+++ b/automotive/can/1.0/default/libnl++/Socket.cpp
@@ -103,60 +103,45 @@
return {msg, sa};
}
-/* TODO(161389935): Migrate receiveAck to use nlmsg<> internally. Possibly reuse
- * Socket::receive(). */
-bool Socket::receiveAck() {
- if (mFailed) return false;
+bool Socket::receiveAck(uint32_t seq) {
+ const auto nlerr = receive<nlmsgerr>({NLMSG_ERROR});
+ if (!nlerr.has_value()) return false;
- char buf[8192];
-
- sockaddr_nl sa;
- iovec iov = {buf, sizeof(buf)};
-
- msghdr msg = {};
- msg.msg_name = &sa;
- msg.msg_namelen = sizeof(sa);
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
-
- const ssize_t status = recvmsg(mFd.get(), &msg, 0);
- if (status < 0) {
- PLOG(ERROR) << "Failed to receive Netlink message";
- return false;
- }
- size_t remainingLen = status;
-
- if (msg.msg_flags & MSG_TRUNC) {
- LOG(ERROR) << "Failed to receive Netlink message: truncated";
+ if (nlerr->data.msg.nlmsg_seq != seq) {
+ LOG(ERROR) << "Received ACK for a different message (" << nlerr->data.msg.nlmsg_seq
+ << ", expected " << seq << "). Multi-message tracking is not implemented.";
return false;
}
- for (auto nlmsg = reinterpret_cast<nlmsghdr*>(buf); NLMSG_OK(nlmsg, remainingLen);
- nlmsg = NLMSG_NEXT(nlmsg, remainingLen)) {
- if constexpr (kSuperVerbose) {
- LOG(VERBOSE) << "received Netlink response: "
- << toString({nlmsg, nlmsg->nlmsg_len}, mProtocol);
- }
+ if (nlerr->data.error == 0) return true;
- // We're looking for error/ack message only, ignoring others.
- if (nlmsg->nlmsg_type != NLMSG_ERROR) {
- LOG(WARNING) << "Received unexpected Netlink message (ignored): " << nlmsg->nlmsg_type;
- continue;
- }
-
- // Found error/ack message, return status.
- const auto nlerr = reinterpret_cast<nlmsgerr*>(NLMSG_DATA(nlmsg));
- if (nlerr->error != 0) {
- LOG(ERROR) << "Received Netlink error message: " << strerror(-nlerr->error);
- return false;
- }
- return true;
- }
- // Couldn't find any error/ack messages.
+ LOG(WARNING) << "Received Netlink error message: " << strerror(-nlerr->data.error);
return false;
}
+std::optional<Buffer<nlmsghdr>> Socket::receive(const std::set<nlmsgtype_t>& msgtypes,
+ size_t maxSize) {
+ while (!mFailed) {
+ const auto msgBuf = receive(maxSize);
+ if (!msgBuf.has_value()) return std::nullopt;
+
+ for (const auto rawMsg : *msgBuf) {
+ if (msgtypes.count(rawMsg->nlmsg_type) == 0) {
+ LOG(WARNING) << "Received (and ignored) unexpected Netlink message of type "
+ << rawMsg->nlmsg_type;
+ continue;
+ }
+
+ return rawMsg;
+ }
+ }
+
+ return std::nullopt;
+}
+
std::optional<unsigned> Socket::getPid() {
+ if (mFailed) return std::nullopt;
+
sockaddr_nl sa = {};
socklen_t sasize = sizeof(sa);
if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
diff --git a/automotive/can/1.0/default/libnl++/include/libnl++/Message.h b/automotive/can/1.0/default/libnl++/include/libnl++/Message.h
index 2b84a86..50b3c4b 100644
--- a/automotive/can/1.0/default/libnl++/include/libnl++/Message.h
+++ b/automotive/can/1.0/default/libnl++/include/libnl++/Message.h
@@ -19,6 +19,8 @@
#include <libnl++/Attributes.h>
#include <libnl++/Buffer.h>
+#include <set>
+
namespace android::nl {
/**
@@ -60,7 +62,8 @@
* \return Parsed message or nullopt, if the buffer data is invalid or message type
* doesn't match.
*/
- static std::optional<Message<T>> parse(Buffer<nlmsghdr> buf, std::set<nlmsgtype_t> msgtypes) {
+ static std::optional<Message<T>> parse(Buffer<nlmsghdr> buf,
+ const std::set<nlmsgtype_t>& msgtypes) {
const auto& [nlOk, nlHeader] = buf.getFirst(); // we're doing it twice, but it's fine
if (!nlOk) return std::nullopt;
diff --git a/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h b/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
index bc6ad9d..16b63f5 100644
--- a/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
+++ b/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
@@ -19,11 +19,13 @@
#include <android-base/macros.h>
#include <android-base/unique_fd.h>
#include <libnl++/Buffer.h>
+#include <libnl++/Message.h>
#include <libnl++/MessageFactory.h>
#include <linux/netlink.h>
#include <optional>
+#include <set>
#include <vector>
namespace android::nl {
@@ -57,7 +59,7 @@
* \param msg Message to send. Its sequence number will be updated.
* \return true, if succeeded.
*/
- template <class T, unsigned BUFSIZE>
+ template <typename T, unsigned BUFSIZE>
bool send(MessageFactory<T, BUFSIZE>& req) {
sockaddr_nl sa = {};
sa.nl_family = AF_NETLINK;
@@ -72,7 +74,7 @@
* \param sa Destination address.
* \return true, if succeeded.
*/
- template <class T, unsigned BUFSIZE>
+ template <typename T, unsigned BUFSIZE>
bool send(MessageFactory<T, BUFSIZE>& req, const sockaddr_nl& sa) {
if (!req.isGood()) return false;
@@ -109,7 +111,7 @@
* WARNING: the underlying buffer is owned by Socket class and the data is valid until the next
* call to the read function or until deallocation of Socket instance.
*
- * \param maxSize Maximum total size of received messages
+ * \param maxSize Maximum total size of received messages.
* \return A pair (for use with structured binding) containing:
* - buffer view with message data, std::nullopt on error;
* - sender process address.
@@ -118,27 +120,70 @@
size_t maxSize = defaultReceiveSize);
/**
- * Receive Netlink ACK message from Kernel.
+ * Receive matching Netlink message of a given payload type.
*
- * \return true if received ACK message, false in case of error
+ * This method should be used if the caller expects exactly one incoming message of exactly
+ * given type (such as ACK). If there is a use case to handle multiple types of messages,
+ * please use receive(size_t) directly and iterate through potential multipart messages.
+ *
+ * If this method is used in such an environment, it will only return the first matching message
+ * from multipart packet and will issue warnings on messages that do not match.
+ *
+ * \param msgtypes Expected message types (such as NLMSG_ERROR).
+ * \param maxSize Maximum total size of received messages.
+ * \return Parsed message or std::nullopt in case of error.
*/
- bool receiveAck();
+ template <typename T>
+ std::optional<Message<T>> receive(const std::set<nlmsgtype_t>& msgtypes,
+ size_t maxSize = defaultReceiveSize) {
+ const auto msg = receive(msgtypes, maxSize);
+ if (!msg.has_value()) return std::nullopt;
+
+ const auto parsed = Message<T>::parse(*msg);
+ if (!parsed.has_value()) {
+ LOG(WARNING) << "Received matching Netlink message, but couldn't parse it";
+ return std::nullopt;
+ }
+
+ return parsed;
+ }
+
+ /**
+ * Receive Netlink ACK message.
+ *
+ * \param req Message to match sequence number against.
+ * \return true if received ACK message, false in case of error.
+ */
+ template <typename T, unsigned BUFSIZE>
+ bool receiveAck(MessageFactory<T, BUFSIZE>& req) {
+ return receiveAck(req.header()->nlmsg_seq);
+ }
+
+ /**
+ * Receive Netlink ACK message.
+ *
+ * \param seq Sequence number of message to ACK.
+ * \return true if received ACK message, false in case of error.
+ */
+ bool receiveAck(uint32_t seq);
/**
* Fetches the socket PID.
*
- * \return PID that socket is bound to.
+ * \return PID that socket is bound to or std::nullopt.
*/
std::optional<unsigned> getPid();
private:
const int mProtocol;
-
- uint32_t mSeq = 0;
base::unique_fd mFd;
- bool mFailed = false;
std::vector<uint8_t> mReceiveBuffer;
+ bool mFailed = false;
+ uint32_t mSeq = 0;
+
+ std::optional<Buffer<nlmsghdr>> receive(const std::set<nlmsgtype_t>& msgtypes, size_t maxSize);
+
DISALLOW_COPY_AND_ASSIGN(Socket);
};