Page MenuHomeFreeBSD

D48544.diff
No OneTemporary

D48544.diff

diff --git a/sys/kern/uipc_domain.c b/sys/kern/uipc_domain.c
--- a/sys/kern/uipc_domain.c
+++ b/sys/kern/uipc_domain.c
@@ -109,7 +109,7 @@
return (EOPNOTSUPP);
}
-static int
+int
pr_listen_notsupp(struct socket *so, int backlog, struct thread *td)
{
return (EOPNOTSUPP);
diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h
--- a/sys/netinet/in_pcb.h
+++ b/sys/netinet/in_pcb.h
@@ -167,7 +167,10 @@
struct m_snd_tag;
struct inpcb {
/* Cache line #1 (amd64) */
- CK_LIST_ENTRY(inpcb) inp_hash_exact; /* hash table linkage */
+ union {
+ CK_LIST_ENTRY(inpcb) inp_hash_exact; /* hash table linkage */
+ LIST_ENTRY(inpcb) inp_lbgroup_list; /* lb group list */
+ };
CK_LIST_ENTRY(inpcb) inp_hash_wild; /* hash table linkage */
struct rwlock inp_lock;
/* Cache line #2 (amd64) */
@@ -428,6 +431,7 @@
*/
struct inpcblbgroup {
CK_LIST_ENTRY(inpcblbgroup) il_list;
+ LIST_HEAD(, inpcb) il_pending; /* PCBs waiting for listen() */
struct epoch_context il_epoch_ctx;
struct ucred *il_cred;
uint16_t il_lport; /* (c) */
@@ -671,6 +675,7 @@
int in_pcbladdr(struct inpcb *, struct in_addr *, struct in_addr *,
struct ucred *);
int in_pcblbgroup_numa(struct inpcb *, int arg);
+void in_pcblisten(struct inpcb *);
struct inpcb *
in_pcblookup(struct inpcbinfo *, struct in_addr, u_int,
struct in_addr, u_int, int, struct ifnet *);
diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c
--- a/sys/netinet/in_pcb.c
+++ b/sys/netinet/in_pcb.c
@@ -263,6 +263,7 @@
grp = malloc(bytes, M_PCB, M_ZERO | M_NOWAIT);
if (grp == NULL)
return (NULL);
+ LIST_INIT(&grp->il_pending);
grp->il_cred = crhold(cred);
grp->il_vflag = vflag;
grp->il_lport = port;
@@ -285,11 +286,45 @@
static void
in_pcblbgroup_free(struct inpcblbgroup *grp)
{
+ KASSERT(LIST_EMPTY(&grp->il_pending),
+ ("local group %p still has pending inps", grp));
CK_LIST_REMOVE(grp, il_list);
NET_EPOCH_CALL(in_pcblbgroup_free_deferred, &grp->il_epoch_ctx);
}
+/*
+ * Find the load-balancing group that references "inp", either through its
+ * il_inp[] array or through its il_pending list.  Only the hash chain for
+ * inp's local port is searched.  Asserts that the inpcb lock and the pcbinfo
+ * hash lock are held and that INP_INLBGROUP is set on inp.  Returns NULL if
+ * no group on the chain references inp.
+ */
+static struct inpcblbgroup *
+in_pcblbgroup_find(struct inpcb *inp)
+{
+	struct inpcbinfo *pcbinfo;
+	struct inpcblbgrouphead *hdr;
+	struct inpcblbgroup *grp;
+	struct inpcb *pending;
+
+	INP_LOCK_ASSERT(inp);
+
+	pcbinfo = inp->inp_pcbinfo;
+	INP_HASH_LOCK_ASSERT(pcbinfo);
+	KASSERT((inp->inp_flags & INP_INLBGROUP) != 0,
+	    ("inpcb %p is not in a load balance group", inp));
+
+	hdr = &pcbinfo->ipi_lbgrouphashbase[
+	    INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask)];
+	CK_LIST_FOREACH(grp, hdr, il_list) {
+		/* Members that are visible to lbgroup lookups. */
+		for (unsigned int i = 0; i < grp->il_inpcnt; i++) {
+			if (grp->il_inp[i] == inp)
+				return (grp);
+		}
+		/* Members still waiting on the pending list. */
+		LIST_FOREACH(pending, &grp->il_pending, inp_lbgroup_list) {
+			if (pending == inp)
+				return (grp);
+		}
+	}
+	return (NULL);
+}
+
static void
in_pcblbgroup_insert(struct inpcblbgroup *grp, struct inpcb *inp)
{
@@ -298,14 +333,24 @@
grp->il_inpcnt));
INP_WLOCK_ASSERT(inp);
- inp->inp_flags |= INP_INLBGROUP;
- grp->il_inp[grp->il_inpcnt] = inp;
+ if (inp->inp_socket->so_proto->pr_listen != pr_listen_notsupp &&
+ !SOLISTENING(inp->inp_socket)) {
+ /*
+ * If this is a TCP socket, it should not be visible to lbgroup
+ * lookups until listen() has been called.
+ */
+ LIST_INSERT_HEAD(&grp->il_pending, inp, inp_lbgroup_list);
+ } else {
+ grp->il_inp[grp->il_inpcnt] = inp;
- /*
- * Synchronize with in_pcblookup_lbgroup(): make sure that we don't
- * expose a null slot to the lookup path.
- */
- atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1);
+ /*
+ * Synchronize with in_pcblookup_lbgroup(): make sure that we
+ * don't expose a null slot to the lookup path.
+ */
+ atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1);
+ }
+
+ inp->inp_flags |= INP_INLBGROUP;
}
static struct inpcblbgroup *
@@ -329,6 +374,8 @@
grp->il_inp[i] = old_grp->il_inp[i];
grp->il_inpcnt = old_grp->il_inpcnt;
CK_LIST_INSERT_HEAD(hdr, grp, il_list);
+ LIST_SWAP(&old_grp->il_pending, &grp->il_pending, inpcb,
+ inp_lbgroup_list);
in_pcblbgroup_free(old_grp);
return (grp);
}
@@ -412,6 +459,7 @@
struct inpcbinfo *pcbinfo;
struct inpcblbgrouphead *hdr;
struct inpcblbgroup *grp;
+ struct inpcb *inp1;
int i;
pcbinfo = inp->inp_pcbinfo;
@@ -427,13 +475,11 @@
if (grp->il_inp[i] != inp)
continue;
- if (grp->il_inpcnt == 1) {
+ if (grp->il_inpcnt == 1 &&
+ LIST_EMPTY(&grp->il_pending)) {
/* We are the last, free this local group. */
in_pcblbgroup_free(grp);
} else {
- KASSERT(grp->il_inpcnt >= 2,
- ("invalid local group count %d",
- grp->il_inpcnt));
grp->il_inp[i] =
grp->il_inp[grp->il_inpcnt - 1];
@@ -446,17 +492,22 @@
inp->inp_flags &= ~INP_INLBGROUP;
return;
}
+ LIST_FOREACH(inp1, &grp->il_pending, inp_lbgroup_list) {
+ if (inp == inp1) {
+ LIST_REMOVE(inp, inp_lbgroup_list);
+ inp->inp_flags &= ~INP_INLBGROUP;
+ return;
+ }
+ }
}
- KASSERT(0, ("%s: did not find %p", __func__, inp));
+ __assert_unreachable();
}
int
in_pcblbgroup_numa(struct inpcb *inp, int arg)
{
struct inpcbinfo *pcbinfo;
- struct inpcblbgrouphead *hdr;
- struct inpcblbgroup *grp;
- int err, i;
+ int error;
uint8_t numa_domain;
switch (arg) {
@@ -472,33 +523,20 @@
numa_domain = arg;
}
- err = 0;
pcbinfo = inp->inp_pcbinfo;
INP_WLOCK_ASSERT(inp);
INP_HASH_WLOCK(pcbinfo);
- hdr = &pcbinfo->ipi_lbgrouphashbase[
- INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask)];
- CK_LIST_FOREACH(grp, hdr, il_list) {
- for (i = 0; i < grp->il_inpcnt; ++i) {
- if (grp->il_inp[i] != inp)
- continue;
-
- if (grp->il_numa_domain == numa_domain) {
- goto abort_with_hash_wlock;
- }
-
- /* Remove it from the old group. */
- in_pcbremlbgrouphash(inp);
-
- /* Add it to the new group based on numa domain. */
- in_pcbinslbgrouphash(inp, numa_domain);
- goto abort_with_hash_wlock;
- }
+ if (in_pcblbgroup_find(inp) != NULL) {
+ /* Remove it from the old group. */
+ in_pcbremlbgrouphash(inp);
+ /* Add it to the new group based on numa domain. */
+ in_pcbinslbgrouphash(inp, numa_domain);
+ error = 0;
+ } else {
+ error = ENOENT;
}
- err = ENOENT;
-abort_with_hash_wlock:
INP_HASH_WUNLOCK(pcbinfo);
- return (err);
+ return (error);
}
/* Make sure it is safe to use hashinit(9) on CK_LIST. */
@@ -1437,6 +1475,25 @@
}
#endif /* INET */
+/*
+ * Make a load-balanced inpcb visible to lbgroup lookups once its socket is
+ * listening.  tcp_usrreq.c calls this when the listen request completed with
+ * error == 0.  If the inpcb has INP_INLBGROUP set and is still linked on its
+ * group's il_pending list, it is unlinked and handed to
+ * in_pcblbgroup_insert() again.  The inpcb write lock is asserted; the
+ * pcbinfo hash write lock is acquired and released here.
+ */
+void
+in_pcblisten(struct inpcb *inp)
+{
+	struct inpcbinfo *pcbinfo;
+	struct inpcblbgroup *grp;
+	struct inpcb *inp1;
+
+	INP_WLOCK_ASSERT(inp);
+
+	if ((inp->inp_flags & INP_INLBGROUP) == 0)
+		return;
+
+	pcbinfo = inp->inp_pcbinfo;
+	INP_HASH_WLOCK(pcbinfo);
+	grp = in_pcblbgroup_find(inp);
+	KASSERT(grp != NULL,
+	    ("%s: inpcb %p is not in any load balance group", __func__, inp));
+
+	/*
+	 * listen(2) may be issued again on a socket that is already
+	 * listening.  In that case the inpcb already sits in il_inp[] and is
+	 * not linked on il_pending, so unlinking it would follow stale list
+	 * linkage and re-inserting it would add it to il_inp[] twice.  Only
+	 * promote the inpcb when it really is pending.
+	 */
+	LIST_FOREACH(inp1, &grp->il_pending, inp_lbgroup_list) {
+		if (inp1 == inp)
+			break;
+	}
+	if (inp1 != NULL) {
+		LIST_REMOVE(inp, inp_lbgroup_list);
+		/*
+		 * NOTE(review): pending inpcbs are not counted in il_inpcnt;
+		 * confirm that il_inp[] is guaranteed to have a free slot
+		 * when a pending inpcb is promoted here.
+		 */
+		in_pcblbgroup_insert(grp, inp);
+	}
+	INP_HASH_WUNLOCK(pcbinfo);
+}
+
/*
* inpcb hash lookups are protected by SMR section.
*
diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
--- a/sys/netinet/tcp_usrreq.c
+++ b/sys/netinet/tcp_usrreq.c
@@ -391,6 +391,8 @@
}
SOCK_UNLOCK(so);
+ if (error == 0)
+ in_pcblisten(inp);
if (tp->t_flags & TF_FASTOPEN)
tp->t_tfo_pending = tcp_fastopen_alloc_counter();
@@ -448,6 +450,8 @@
}
SOCK_UNLOCK(so);
+ if (error == 0)
+ in_pcblisten(inp);
if (tp->t_flags & TF_FASTOPEN)
tp->t_tfo_pending = tcp_fastopen_alloc_counter();
diff --git a/sys/sys/socketvar.h b/sys/sys/socketvar.h
--- a/sys/sys/socketvar.h
+++ b/sys/sys/socketvar.h
@@ -596,6 +596,8 @@
int accept_filt_generic_mod_event(module_t mod, int event, void *data);
#endif
+int pr_listen_notsupp(struct socket *so, int backlog, struct thread *td);
+
#endif /* _KERNEL */
/*
diff --git a/tests/sys/netinet/Makefile b/tests/sys/netinet/Makefile
--- a/tests/sys/netinet/Makefile
+++ b/tests/sys/netinet/Makefile
@@ -27,6 +27,8 @@
ATF_TESTS_PYTEST+= carp.py
ATF_TESTS_PYTEST+= igmp.py
+LIBADD.so_reuseport_lb_test= pthread
+
# Some of the arp tests look for log messages in the dmesg buffer, so run them
# serially to avoid problems with interleaved output.
TEST_METADATA.arp+= is_exclusive="true"
diff --git a/tests/sys/netinet/so_reuseport_lb_test.c b/tests/sys/netinet/so_reuseport_lb_test.c
--- a/tests/sys/netinet/so_reuseport_lb_test.c
+++ b/tests/sys/netinet/so_reuseport_lb_test.c
@@ -28,12 +28,16 @@
*/
#include <sys/param.h>
+#include <sys/event.h>
#include <sys/socket.h>
#include <netinet/in.h>
+#include <netinet/tcp.h>
#include <err.h>
#include <errno.h>
+#include <pthread.h>
+#include <stdatomic.h>
#include <stdlib.h>
#include <unistd.h>
@@ -235,10 +239,149 @@
}
}
+/* State shared between the concurrent_add test body and its worker threads. */
+struct concurrent_add_softc {
+ struct sockaddr_storage ss; /* address the connector threads connect to */
+ int socks[128]; /* sockets created by the test body for the LB group */
+ int kq; /* kqueue the listener thread waits on */
+};
+
+/*
+ * Listener thread.  Waits on the shared kqueue for an event, accepts one
+ * connection from the socket the event identifies, writes the single byte
+ * 'M' to it and closes the connection.  Loops forever; an unexpected failure
+ * of any step fails the test.
+ */
+static void *
+listener(void *arg)
+{
+	struct concurrent_add_softc *sc = arg;
+
+	for (;;) {
+		struct kevent kev;
+		ssize_t n;
+		int count, cs, error;
+		uint8_t b = 'M';
+
+		/* Block until exactly one event is available. */
+		count = kevent(sc->kq, NULL, 0, &kev, 1, NULL);
+		ATF_REQUIRE_MSG(count == 1,
+		    "kevent() failed: %s", strerror(errno));
+
+		/* The event identifier is the listening socket. */
+		cs = accept((int)kev.ident, NULL, NULL);
+		ATF_REQUIRE_MSG(cs >= 0,
+		    "accept() failed: %s", strerror(errno));
+
+		n = write(cs, &b, sizeof(b));
+		ATF_REQUIRE_MSG(n >= 0, "write() failed: %s", strerror(errno));
+		ATF_REQUIRE(n == 1);
+
+		/* ECONNRESET from close() is tolerated. */
+		error = close(cs);
+		ATF_REQUIRE_MSG(error == 0 || errno == ECONNRESET,
+		    "close() failed: %s", strerror(errno));
+	}
+}
+
+/*
+ * Connector thread.  Repeatedly creates a stream socket, connects it to the
+ * address stored in sc->ss, reads the single byte sent by the listener,
+ * checks that it is 'M' and closes the socket.  Loops forever; any failure
+ * fails the test.
+ */
+static void *
+connector(void *arg)
+{
+	struct concurrent_add_softc *sc = arg;
+
+	for (;;) {
+		ssize_t n;
+		int error, s;
+		uint8_t b;
+
+		s = socket(sc->ss.ss_family, SOCK_STREAM, 0);
+		ATF_REQUIRE_MSG(s >= 0, "socket() failed: %s", strerror(errno));
+
+		/* Check the result instead of letting connect() overwrite it. */
+		error = setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (int[]){1},
+		    sizeof(int));
+		ATF_REQUIRE_MSG(error == 0,
+		    "setsockopt(SO_REUSEADDR) failed: %s", strerror(errno));
+
+		error = connect(s, (struct sockaddr *)&sc->ss, sc->ss.ss_len);
+		ATF_REQUIRE_MSG(error == 0, "connect() failed: %s",
+		    strerror(errno));
+
+		n = read(s, &b, sizeof(b));
+		ATF_REQUIRE_MSG(n >= 0, "read() failed: %s",
+		    strerror(errno));
+		ATF_REQUIRE(n == 1);
+		ATF_REQUIRE(b == 'M');
+
+		error = close(s);
+		ATF_REQUIRE_MSG(error == 0,
+		    "close() failed: %s", strerror(errno));
+	}
+}
+
+/*
+ * Exercise concurrent additions to a SO_REUSEPORT_LB group. One thread
+ * accepts connections from the listening sockets via a kqueue, three threads
+ * repeatedly make connections, and the test body slowly adds listening
+ * sockets to the LB group. This is meant to help flush out race conditions.
+ */
+ATF_TC_WITHOUT_HEAD(concurrent_add);
+ATF_TC_BODY(concurrent_add, tc)
+{
+ struct concurrent_add_softc sc;
+ struct sockaddr_in *sin;
+ pthread_t threads[4];
+ int error;
+
+ sc.kq = kqueue();
+ ATF_REQUIRE_MSG(sc.kq >= 0, "kqueue() failed: %s", strerror(errno));
+
+ /* threads[0] is the listener; it is started before any socket exists. */
+ error = pthread_create(&threads[0], NULL, listener, &sc);
+ ATF_REQUIRE_MSG(error == 0, "pthread_create() failed: %s",
+ strerror(error));
+
+ /*
+ * sin aliases sc.ss. It starts out as 127.0.0.1 with port 0; once
+ * getsockname() below overwrites sc.ss with the first socket's bound
+ * address, every later bind() through sin uses that same port.
+ */
+ sin = (struct sockaddr_in *)&sc.ss;
+ memset(sin, 0, sizeof(*sin));
+ sin->sin_len = sizeof(*sin);
+ sin->sin_family = AF_INET;
+ sin->sin_port = htons(0);
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+
+ for (size_t i = 0; i < nitems(sc.socks); i++) {
+ struct kevent kev;
+ int s;
+
+ sc.socks[i] = s = socket(AF_INET, SOCK_STREAM, 0);
+ ATF_REQUIRE_MSG(s >= 0, "socket() failed: %s", strerror(errno));
+
+ error = setsockopt(s, SOL_SOCKET, SO_REUSEPORT_LB, (int[]){1},
+ sizeof(int));
+ ATF_REQUIRE_MSG(error == 0,
+ "setsockopt(SO_REUSEPORT_LB) failed: %s", strerror(errno));
+
+ error = bind(s, (struct sockaddr *)sin, sizeof(*sin));
+ ATF_REQUIRE_MSG(error == 0, "bind() failed: %s",
+ strerror(errno));
+
+ error = listen(s, 5);
+ ATF_REQUIRE_MSG(error == 0, "listen() failed: %s",
+ strerror(errno));
+
+ /* Register the new listening socket with the listener's kqueue. */
+ EV_SET(&kev, s, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, 0);
+ error = kevent(sc.kq, &kev, 1, NULL, 0, NULL);
+ ATF_REQUIRE_MSG(error == 0, "kevent() failed: %s",
+ strerror(errno));
+
+ /*
+ * After the first socket is listening, record its bound address in
+ * sc.ss and start the connector threads (threads[1..3]).
+ */
+ if (i == 0) {
+ socklen_t slen = sizeof(sc.ss);
+
+ error = getsockname(sc.socks[i],
+ (struct sockaddr *)&sc.ss, &slen);
+ ATF_REQUIRE_MSG(error == 0, "getsockname() failed: %s",
+ strerror(errno));
+ ATF_REQUIRE(sc.ss.ss_family == AF_INET);
+
+ for (size_t j = 1; j < nitems(threads); j++) {
+ error = pthread_create(&threads[j], NULL,
+ connector, &sc);
+ ATF_REQUIRE_MSG(error == 0,
+ "pthread_create() failed: %s",
+ strerror(error));
+ }
+ }
+
+ /* Sleep 20ms between additions so sockets are added slowly. */
+ usleep(20000);
+ }
+}
+
ATF_TP_ADD_TCS(tp)
{
ATF_TP_ADD_TC(tp, basic_ipv4);
ATF_TP_ADD_TC(tp, basic_ipv6);
+ ATF_TP_ADD_TC(tp, concurrent_add);
return (atf_no_error());
}

File Metadata

Mime Type
text/plain
Expires
Mon, Feb 3, 7:50 PM (22 h, 8 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
16441485
Default Alt Text
D48544.diff (12 KB)

Event Timeline