Netlink socket refactoring
- merge the two send() methods into one
- use an internal receive buffer instead of requiring the caller to supply one
- move setting the sequence number to the MessageFactory-based send() code
- allow send() to target any recipient, not only the Kernel
- move adding NLM_F_ACK to the callers
- getSocketPid -> getPid
- unsigned int -> unsigned
- fail send() on truncated (partial) writes
Refactoring receiveAck is left for a follow-up change (b/161389935).
Bug: 162032964
Test: canhalctrl up test virtual vcan3
Change-Id: Ie3d460dbc2ea1251469bf08504cfe2c6e80bbe75
diff --git a/automotive/can/1.0/default/libnetdevice/can.cpp b/automotive/can/1.0/default/libnetdevice/can.cpp
index b047bc9..ab107fd 100644
--- a/automotive/can/1.0/default/libnetdevice/can.cpp
+++ b/automotive/can/1.0/default/libnetdevice/can.cpp
@@ -70,7 +70,7 @@
struct can_bittiming bt = {};
bt.bitrate = bitrate;
- nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK, NLM_F_REQUEST);
+ nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK, NLM_F_REQUEST | NLM_F_ACK);
const auto ifidx = nametoindex(ifname);
if (ifidx == 0) {
diff --git a/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp b/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
index f7f5f4d..ed2a51e 100644
--- a/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
+++ b/automotive/can/1.0/default/libnetdevice/libnetdevice.cpp
@@ -63,7 +63,7 @@
bool add(std::string dev, std::string type) {
nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK,
- NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL);
+ NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK);
req.addattr(IFLA_IFNAME, dev);
{
@@ -76,7 +76,7 @@
}
bool del(std::string dev) {
- nl::MessageFactory<struct ifinfomsg> req(RTM_DELLINK, NLM_F_REQUEST);
+ nl::MessageFactory<struct ifinfomsg> req(RTM_DELLINK, NLM_F_REQUEST | NLM_F_ACK);
req.addattr(IFLA_IFNAME, dev);
nl::Socket sock(NETLINK_ROUTE);
diff --git a/automotive/can/1.0/default/libnetdevice/vlan.cpp b/automotive/can/1.0/default/libnetdevice/vlan.cpp
index 3e07f67..3f904f0 100644
--- a/automotive/can/1.0/default/libnetdevice/vlan.cpp
+++ b/automotive/can/1.0/default/libnetdevice/vlan.cpp
@@ -34,7 +34,7 @@
}
nl::MessageFactory<struct ifinfomsg> req(RTM_NEWLINK,
- NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL);
+ NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK);
req.addattr(IFLA_IFNAME, vlan);
req.addattr<uint32_t>(IFLA_LINK, ethidx);
diff --git a/automotive/can/1.0/default/libnl++/Socket.cpp b/automotive/can/1.0/default/libnl++/Socket.cpp
index aac6416..56e990c 100644
--- a/automotive/can/1.0/default/libnl++/Socket.cpp
+++ b/automotive/can/1.0/default/libnl++/Socket.cpp
@@ -27,7 +27,7 @@
*/
static constexpr bool kSuperVerbose = false;
-Socket::Socket(int protocol, unsigned int pid, uint32_t groups) : mProtocol(protocol) {
+Socket::Socket(int protocol, unsigned pid, uint32_t groups) : mProtocol(protocol) {
mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol));
if (!mFd.ok()) {
PLOG(ERROR) << "Can't open Netlink socket";
@@ -47,83 +47,60 @@
}
}
-bool Socket::send(nlmsghdr* nlmsg, size_t totalLen) {
- if constexpr (kSuperVerbose) {
- nlmsg->nlmsg_seq = mSeq;
- LOG(VERBOSE) << (mFailed ? "(not) " : "")
- << "sending Netlink message: " << toString({nlmsg, totalLen}, mProtocol);
- }
-
- if (mFailed) return false;
-
- nlmsg->nlmsg_pid = 0; // kernel
- nlmsg->nlmsg_seq = mSeq++;
- nlmsg->nlmsg_flags |= NLM_F_ACK;
-
- iovec iov = {nlmsg, nlmsg->nlmsg_len};
-
- sockaddr_nl sa = {};
- sa.nl_family = AF_NETLINK;
-
- msghdr msg = {};
- msg.msg_name = &sa;
- msg.msg_namelen = sizeof(sa);
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
-
- if (sendmsg(mFd.get(), &msg, 0) < 0) {
- PLOG(ERROR) << "Can't send Netlink message";
- return false;
- }
- return true;
-}
-
bool Socket::send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa) {
if constexpr (kSuperVerbose) {
- LOG(VERBOSE) << (mFailed ? "(not) " : "")
- << "sending Netlink message: " << toString(msg, mProtocol);
+ LOG(VERBOSE) << (mFailed ? "(not) " : "") << "sending Netlink message (" //
+ << msg->nlmsg_pid << " -> " << sa.nl_pid << "): " << toString(msg, mProtocol);
}
-
if (mFailed) return false;
+
+ mSeq = msg->nlmsg_seq; // record the sequence number of this send attempt (even if it fails)
const auto rawMsg = msg.getRaw();
const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0,
reinterpret_cast<const sockaddr*>(&sa), sizeof(sa));
if (bytesSent < 0) {
PLOG(ERROR) << "Can't send Netlink message";
return false;
+ } else if (size_t(bytesSent) != rawMsg.len()) { // partial write is treated as failure
+ LOG(ERROR) << "Can't send Netlink message: truncated message";
+ return false;
}
return true;
}
-std::optional<Buffer<nlmsghdr>> Socket::receive(void* buf, size_t bufLen) {
- sockaddr_nl sa = {};
- return receive(buf, bufLen, sa);
+std::optional<Buffer<nlmsghdr>> Socket::receive(size_t maxSize) {
+ return receiveFrom(maxSize).first; // sender address is discarded
}
-std::optional<Buffer<nlmsghdr>> Socket::receive(void* buf, size_t bufLen, sockaddr_nl& sa) {
- if (mFailed) return std::nullopt;
+std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> Socket::receiveFrom(size_t maxSize) {
+ if (mFailed) return {std::nullopt, {}};
- socklen_t saLen = sizeof(sa);
- if (bufLen == 0) {
- LOG(ERROR) << "Receive buffer has zero size!";
- return std::nullopt;
+ if (maxSize == 0) {
+ LOG(ERROR) << "Maximum receive size should not be zero";
+ return {std::nullopt, {}};
}
- const auto bytesReceived =
- recvfrom(mFd.get(), buf, bufLen, MSG_TRUNC, reinterpret_cast<sockaddr*>(&sa), &saLen);
+ if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize);
+
+ sockaddr_nl sa = {};
+ socklen_t saLen = sizeof(sa);
+ const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), maxSize, MSG_TRUNC,
+ reinterpret_cast<sockaddr*>(&sa), &saLen);
+
if (bytesReceived <= 0) {
PLOG(ERROR) << "Failed to receive Netlink message";
- return std::nullopt;
- } else if (unsigned(bytesReceived) > bufLen) {
- PLOG(ERROR) << "Received data larger than the receive buffer! " << bytesReceived << " > "
- << bufLen;
- return std::nullopt;
+ return {std::nullopt, {}};
+ } else if (size_t(bytesReceived) > maxSize) {
+ LOG(ERROR) << "Received data larger than maximum receive size: " // recvfrom succeeded: no errno
+ << bytesReceived << " > " << maxSize;
+ return {std::nullopt, {}};
}
- Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(buf), bytesReceived);
+ Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(mReceiveBuffer.data()), bytesReceived);
if constexpr (kSuperVerbose) {
- LOG(VERBOSE) << "received " << toString(msg, mProtocol);
+ LOG(VERBOSE) << "received (" << sa.nl_pid << " -> " << msg->nlmsg_pid << "): " //
+ << toString(msg, mProtocol);
}
- return msg;
+ return {msg, sa};
}
/* TODO(161389935): Migrate receiveAck to use nlmsg<> internally. Possibly reuse
@@ -179,11 +156,11 @@
return false;
}
-std::optional<unsigned int> Socket::getSocketPid() {
+std::optional<unsigned> Socket::getPid() {
sockaddr_nl sa = {};
socklen_t sasize = sizeof(sa);
if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
- PLOG(ERROR) << "Failed to getsockname() for netlink_fd!";
+ PLOG(ERROR) << "Failed to get PID of Netlink socket";
return std::nullopt;
}
return sa.nl_pid;
diff --git a/automotive/can/1.0/default/libnl++/include/libnl++/MessageFactory.h b/automotive/can/1.0/default/libnl++/include/libnl++/MessageFactory.h
index e00ca20..5272577 100644
--- a/automotive/can/1.0/default/libnl++/include/libnl++/MessageFactory.h
+++ b/automotive/can/1.0/default/libnl++/include/libnl++/MessageFactory.h
@@ -35,7 +35,6 @@
} // namespace impl
-// TODO(twasilczyk): rename to NetlinkMessage
/**
* Wrapper around NETLINK_ROUTE messages, to build them in C++ style.
*
diff --git a/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h b/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
index 7685733..bc6ad9d 100644
--- a/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
+++ b/automotive/can/1.0/default/libnl++/include/libnl++/Socket.h
@@ -24,6 +24,7 @@
#include <linux/netlink.h>
#include <optional>
+#include <vector>
namespace android::nl {
@@ -33,59 +34,88 @@
* This class is not thread safe to use a single instance between multiple threads, but it's fine to
* use multiple instances over multiple threads.
*/
-struct Socket {
+class Socket {
+ public:
+ static constexpr size_t defaultReceiveSize = 8192;
+
/**
* Socket constructor.
*
* \param protocol the Netlink protocol to use.
- * \param pid port id. Default value of 0 allows the kernel to assign us a unique pid. (NOTE:
- * this is NOT the same as process id!)
+ * \param pid port id. Default value of 0 allows the kernel to assign us a unique pid.
+ * (NOTE: this is NOT the same as process id).
* \param groups Netlink multicast groups to listen to. This is a 32-bit bitfield, where each
- * bit is a different group. Default value of 0 means no groups are selected. See man netlink.7
+ * bit is a different group. Default value of 0 means no groups are selected.
+ * See man netlink.7
* for more details.
*/
- Socket(int protocol, unsigned int pid = 0, uint32_t groups = 0);
+ Socket(int protocol, unsigned pid = 0, uint32_t groups = 0);
/**
- * Send Netlink message to Kernel. The sequence number will be automatically incremented, and
- * the NLM_F_ACK (request ACK) flag will be set.
+ * Send Netlink message with incremented sequence number to the Kernel.
*
- * \param msg Message to send.
- * \return true, if succeeded
+ * \param req Message to send. Its sequence number will be updated; flags are left unchanged.
+ * \return true, if succeeded.
*/
- template <class T, unsigned int BUFSIZE>
+ template <class T, unsigned BUFSIZE>
bool send(MessageFactory<T, BUFSIZE>& req) {
- if (!req.isGood()) return false;
- return send(req.header(), req.totalLength);
+ sockaddr_nl sa = {};
+ sa.nl_family = AF_NETLINK;
+ sa.nl_pid = 0; // Kernel
+ return send(req, sa);
}
/**
- * Send Netlink message. The message will be sent as is, without any modification.
+ * Send Netlink message with incremented sequence number.
+ *
+ * \param req Message to send. Its sequence number will be updated; flags are left unchanged.
+ * \param sa Destination address.
+ * \return true, if succeeded.
+ */
+ template <class T, unsigned BUFSIZE>
+ bool send(MessageFactory<T, BUFSIZE>& req, const sockaddr_nl& sa) {
+ if (!req.isGood()) return false;
+
+ const auto nlmsg = req.header();
+ nlmsg->nlmsg_seq = mSeq + 1;
+
+ // With MessageFactory<>, we trust nlmsg_len to be correct.
+ return send({nlmsg, nlmsg->nlmsg_len}, sa);
+ }
+
+ /**
+ * Send Netlink message as is. Its sequence number becomes the base for the next send(req).
*
* \param msg Message to send.
* \param sa Destination address.
- * \return true, if succeeded
+ * \return true, if the whole message was sent.
*/
bool send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa);
/**
- * Receive Netlink data.
+ * Receive one or multiple Netlink messages.
*
- * \param buf buffer to hold message data.
- * \param bufLen length of buf.
- * \return Buffer with message data, std::nullopt on error.
+ * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next
+ * receive call on this Socket or until deallocation of the Socket instance.
+ *
+ * \param maxSize Maximum total size of received messages; must not be zero.
+ * \return Buffer view with message data, std::nullopt on error.
*/
- std::optional<Buffer<nlmsghdr>> receive(void* buf, size_t bufLen);
+ std::optional<Buffer<nlmsghdr>> receive(size_t maxSize = defaultReceiveSize);
/**
- * Receive Netlink data with address info.
+ * Receive one or multiple Netlink messages and the sender's Netlink address.
*
- * \param buf buffer to hold message data.
- * \param bufLen length of buf.
- * \param sa Blank struct that recvfrom will populate with address info.
- * \return Buffer with message data, std::nullopt on error.
+ * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next
+ * receive call on this Socket or until deallocation of the Socket instance.
+ *
+ * \param maxSize Maximum total size of received messages; must not be zero.
+ * \return A pair (for use with structured binding) containing:
+ * - buffer view with message data, std::nullopt on error;
+ * - sender's Netlink address (zero-initialized on error).
*/
- std::optional<Buffer<nlmsghdr>> receive(void* buf, size_t bufLen, sockaddr_nl& sa);
+ std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> receiveFrom(
+ size_t maxSize = defaultReceiveSize);
/**
* Receive Netlink ACK message from Kernel.
@@ -95,11 +125,11 @@
bool receiveAck();
/**
- * Gets the PID assigned to mFd.
+ * Fetches the Netlink port ID (PID) the socket is bound to.
*
- * \return pid that mSocket is bound to.
+ * \return PID (not a process ID) that socket is bound to, std::nullopt on error.
*/
- std::optional<unsigned int> getSocketPid();
+ std::optional<unsigned> getPid();
private:
const int mProtocol;
@@ -107,8 +137,7 @@
uint32_t mSeq = 0;
base::unique_fd mFd;
bool mFailed = false;
-
- bool send(nlmsghdr* msg, size_t totalLen);
+ std::vector<uint8_t> mReceiveBuffer;
DISALLOW_COPY_AND_ASSIGN(Socket);
};