/* See LICENSE file for copyright and license details. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "arg.h" #define STYPE_MAX(T) (long long int)((1ULL << (8 * sizeof(T) - 1)) - 1) #define eprintf(...) (weprintf(__VA_ARGS__), exit(1)) enum blocking_mode { BLOCKING_QUEUE, BLOCKING_DISCARD, BLOCKING_BLOCK, BLOCKING_ERROR }; enum order { ORDER_QUEUE = 0, ORDER_STACK = 1, ORDER_RANDOM_QUEUE = 2, ORDER_RANDOM_STACK = 3, }; #define ORDER_RANDOM 2 struct client { int fd; enum blocking_mode soft_blocking_mode; enum blocking_mode hard_blocking_mode; enum order order; char **subs; size_t nsubs; size_t subs_siz; struct client *prev; struct client *next; }; char *argv0; static struct client head; static struct client tail; static int epfd; static int had_client = 0; static struct sockaddr_un addr; static uid_t *users; static size_t nusers; static const char *pidfile = "/run/sbus.pid"; static void usage(void) { fprintf(stderr, "usage: %s [-a address] [-f | -p pidfile] [-u user] ... [-cgor]\n", argv0); exit(1); } static void weprintf(const char *fmt, ...) { va_list args; va_start(args, fmt); vfprintf(stderr, fmt, args); if (strchr(fmt, '\0')[-1] == ':') { fputc(' ', stderr); perror(NULL); } va_end(args); } static void sigexit(int signo) { if (*addr.sun_path) if (unlink(addr.sun_path)) weprintf("unlink %s:", addr.sun_path); if (pidfile) if (unlink(pidfile)) weprintf("unlink %s:", pidfile); exit(0); (void) signo; } static struct client * add_client(int fd) { struct client *cl; cl = malloc(sizeof(*cl)); if (!cl) return NULL; cl->fd = fd; cl->soft_blocking_mode = BLOCKING_QUEUE; cl->hard_blocking_mode = BLOCKING_DISCARD; cl->order = ORDER_RANDOM_QUEUE; cl->subs = NULL; cl->nsubs = 0; cl->subs_siz = 0; cl->next = &tail; cl->prev = tail.prev; tail.prev->next = cl; tail.prev = cl; return cl; } static void remove_client(struct client *cl) { close(cl->fd); cl->prev->next = cl->next; cl->next->prev = cl->prev; while (cl->nsubs--) free(cl->subs[cl->nsubs]); free(cl->subs); free(cl); } static void accept_client(int fd) { struct ucred cred; struct epoll_event ev; size_t i; if (fd < 0) { weprintf("accept :"); return; } if (nusers) { if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &cred, &(socklen_t){sizeof(cred)}) < 0) { weprintf("getsockopt SOL_SOCKET SO_PEERCRED:"); close(fd); return; } for (i = nusers; i--;) if (users[i] == cred.uid) goto cred_ok; weprintf("rejected connection from user %li\n", (long int)cred.uid); close(fd); return; } cred_ok: ev.events = EPOLLIN | EPOLLRDHUP; ev.data.ptr = add_client(fd); if (!ev.data.ptr) { close(fd); } else if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev)) { weprintf("epoll_ctl EPOLL_CTL_ADD :"); remove_client((void *)ev.data.ptr); } else { had_client = 1; } } static int is_subscription_match(const char *sub, const char *key) { const char *sub_start = sub; for (;;) { while (*sub && *sub == *key) { sub++; key++; } if (!*key) return !*sub; if (!*sub) return sub == sub_start || sub[-1] == '/'; if (*sub == '*') { sub++; while (*key && *key != '/') key++; continue; } return 0; } } static int is_subscribed(const struct client *cl, const char *key) { size_t i = cl->nsubs; while (i--) if (is_subscription_match(cl->subs[i], key)) return 1; return 0; } static int is_subscription_acceptable(struct client *cl, const char *key) { struct ucred cred; long long int tmp; const char *p; if (!strncmp(key, "!/cred/", sizeof("!/cred/") - 1)) { if (getsockopt(cl->fd, SOL_SOCKET, SO_PEERCRED, &cred, &(socklen_t){sizeof(cred)}) < 0) { weprintf("getsockopt SOL_SOCKET SO_PEERCRED:"); return -1; } errno = 0; p = &key[sizeof("!/cred/") - 1]; #define TEST_CRED(ID)\ if (!*p) {\ return 0;\ } else if (*p++ != '/') {\ if (!isdigit(*p))\ return 0;\ tmp = strtoll(p, (void *)&p, 10);\ if (errno || (*p && *p != '/') || (ID##_t)tmp != cred.ID)\ return 0;\ } TEST_CRED(gid); TEST_CRED(uid); TEST_CRED(pid); #undef TEST_CRED } return 1; } static void add_subscription(struct client *cl, const char *key) { size_t n; char **new, *k; switch (is_subscription_acceptable(cl, key)) { case -1: remove_client(cl); return; case 0: weprintf("client subscribed unacceptable routing key\n"); remove_client(cl); return; default: break; } if (cl->subs_siz == cl->nsubs) { n = cl->subs_siz ? (cl->subs_siz << 1) : 1; new = realloc(cl->subs, n * sizeof(char *)); if (!new) { weprintf("realloc:"); remove_client(cl); return; } cl->subs = new; cl->subs_siz = n; } k = strdup(key); if (!k) { weprintf("strdup:"); remove_client(cl); return; } cl->subs[cl->nsubs++] = k; } static void remove_subscription(struct client *cl, const char *key) { size_t i = cl->nsubs; char **new; while (i--) { if (!strcmp(key, cl->subs[i])) { free(cl->subs[i]); memmove(&cl->subs[i], &cl->subs[i + 1], --(cl->nsubs) - i); if (cl->subs_siz >= 4 * cl->nsubs) { new = realloc(cl->subs, cl->nsubs * sizeof(char *)); if (new) { cl->subs_siz = cl->nsubs; cl->subs = new; } } break; } } } static int send_packet(struct client *cl, const char *buf, size_t n) { /* TODO honour cl->soft_blocking_mode, cl->hard_blocking_mode, and cl->order */ return -(send(cl->fd, buf, n, 0) < 0); } static void handle_cmsg(struct client *cl, const char *msg, size_t n) { if (!strcmp(msg, "CMSG !/cred/prefix")) { n = sizeof("CMSG !/cred/prefix"); } else if (!strcmp(msg, "CMSG blocking/soft/queue")) { cl->soft_blocking_mode = BLOCKING_QUEUE; } else if (!strcmp(msg, "CMSG blocking/soft/discard")) { cl->soft_blocking_mode = BLOCKING_DISCARD; } else if (!strcmp(msg, "CMSG blocking/soft/block")) { cl->soft_blocking_mode = BLOCKING_BLOCK; } else if (!strcmp(msg, "CMSG blocking/soft/error")) { cl->soft_blocking_mode = BLOCKING_ERROR; } else if (!strcmp(msg, "CMSG blocking/hard/discard")) { cl->hard_blocking_mode = BLOCKING_DISCARD; } else if (!strcmp(msg, "CMSG blocking/hard/block")) { cl->hard_blocking_mode = BLOCKING_BLOCK; } else if (!strcmp(msg, "CMSG blocking/hard/error")) { cl->hard_blocking_mode = BLOCKING_ERROR; } else if (!strcmp(msg, "CMSG order/queue")) { cl->order = ORDER_QUEUE; } else if (!strcmp(msg, "CMSG order/stack")) { cl->order = ORDER_STACK; } else if (!strcmp(msg, "CMSG order/random")) { cl->order |= ORDER_RANDOM; } else { return; } if (send_packet(cl, msg, n)) { weprintf("send :"); remove_client(cl); } } static void broadcast(const char *msg, size_t n) { struct client *cl = head.next, *tmp; for (; cl->next; cl = cl->next) { if (!is_subscribed(cl, &msg[4])) continue; if (send_packet(cl, msg, n)) { cl = (tmp = cl)->prev; weprintf("send :"); remove_client(tmp); } } } static void handle_message(struct client *cl) { static char buf[3 << 17]; int fd = cl->fd; ssize_t r; r = recv(fd, buf, sizeof(buf) - 1, 0); if (r < 0) { weprintf("recv :"); remove_client(cl); return; } buf[r] = '\0'; if (!strncmp(buf, "MSG ", 4)) { broadcast(buf, r); } else if (!strncmp(buf, "UNSUB ", 6)) { remove_subscription(cl, &buf[6]); } else if (!strncmp(buf, "SUB ", 4)) { add_subscription(cl, &buf[4]); } else if (!strncmp(buf, "CMSG ", 5)) { handle_cmsg(cl, buf, r); } else { weprintf("received bad message\n"); remove_client(cl); } } static void randomise(void *buf, size_t n) { char *p = buf; while (n--) *p++ = rand(); } static void print_address(void) { char buf[2 * sizeof(addr.sun_path) + 1]; char *p = buf; const unsigned char *a = (const unsigned char *)addr.sun_path; size_t n = sizeof(addr.sun_path); for (; n--; p += 2, a += 1) { p[0] = "0123456789abcdef"[(int)*a >> 4]; p[1] = "0123456789abcdef"[(int)*a & 15]; } *p = '\0'; printf("/dev/unix/abstract/%s\n", buf); if (fflush(stdout) || ferror(stdout)) eprintf("failed print generated address:"); } static int make_socket(const char *address, int reuse, mode_t mode) { int fd = -1, randaddr = 0, hi, lo, listening = 0; long int tmp; size_t n; const char *p, *q; char *a; memset(&addr, 0, sizeof(addr)); addr.sun_family = AF_UNIX; if (strstr(address, "/dev/fd/") == address) { p = &address[sizeof("/dev/fd/") - 1]; if (!isdigit(*p)) goto def; errno = 0; tmp = strtol(p, &a, 10); if (errno || *a || tmp < 0) { errno = 0; goto def; } if (tmp > INT_MAX) { errno = EBADF; goto bad_address; } fd = (int)tmp; reuse = 0; } else if (!strcmp(address, "/dev/unix/abstract")) { randaddr = 1; reuse = 0; } else if (strstr(address, "/dev/unix/abstract/") == address) { p = &address[sizeof("/dev/unix/abstract/") - 1]; n = strlen(p); if (n & 1) goto def; for (q = p; *q; q++) if (!isxdigit(*q)) goto def; if (n > sizeof(addr.sun_path) * 2) { errno = ENAMETOOLONG; goto bad_address; } a = addr.sun_path; for (; *p; p += 2) { hi = (p[0] & 15) + 9 * !isdigit(p[0]); lo = (p[1] & 15) + 9 * !isdigit(p[1]); *a++ = (hi << 4) | lo; } reuse = 0; } else { def: if (strlen(address) >= sizeof(addr.sun_path)) { errno = ENAMETOOLONG; goto bad_address; } strcpy(addr.sun_path, address); } if (reuse) unlink(addr.sun_path); if (fd < 0) { fd = socket(PF_UNIX, SOCK_SEQPACKET, 0); if (fd < 0) eprintf("socket PF_UNIX SOCK_SEQPACKET:"); if (fchmod(fd, mode)) eprintf("fchmod %o:", mode); if (randaddr) { srand((unsigned)time(NULL)); for (;;) { randomise(&addr.sun_path[1], sizeof(addr.sun_path) - 1); if (!bind(fd, (void *)&addr, sizeof(addr))) break; else if (errno != EADDRINUSE) eprintf("bind :"); } print_address(); } else { if (bind(fd, (void *)&addr, sizeof(addr))) { if (*addr.sun_path) eprintf("bind %s:", addr.sun_path); else eprintf("bind :", &address[sizeof("/dev/unix/abstract/") - 1]); } } } else { if (mode & 0070) weprintf("ignoring -g due to using passed down socket\n"); if (mode & 0007) weprintf("ignoring -o due to using passed down socket\n"); if (getsockopt(fd, SOL_SOCKET, SO_ACCEPTCONN, &listening, &(socklen_t){sizeof(listening)})) eprintf("getsockopt SOL_SOCKET SO_ACCEPTCONN:"); } if (!listening && listen(fd, SOMAXCONN)) eprintf("listen:"); return fd; bad_address: eprintf("bad unix socket address:"); exit(1); } static void daemonise(void) { pid_t pid; int rw[2], status = 0, fd; FILE *fp; if (pipe(rw)) eprintf("pipe:"); switch ((pid = fork())) { case -1: eprintf("fork:"); case 0: close(rw[0]); setsid(); switch (fork()) { case -1: eprintf("fork:"); case 0: if (signal(SIGHUP, SIG_IGN) == SIG_ERR) weprintf("signal SIGHUP SIG_IGN:"); if (signal(SIGINT, sigexit) == SIG_ERR) weprintf("signal SIGINT :"); if (pidfile) { pid = getpid(); fd = open(pidfile, O_WRONLY | O_CREAT | O_EXCL, 0644); if (fd < 0) eprintf("open %s O_WRONLY O_CREAT O_EXCL:", pidfile); fp = fdopen(fd, "w"); fprintf(fp, "%li\n", (long int)pid); if (fflush(fp) || ferror(fp)) eprintf("fprintf %s:", pidfile); fclose(fp); } if (chdir("/")) eprintf("chdir /:"); close(STDIN_FILENO); close(STDOUT_FILENO); if (isatty(STDERR_FILENO)) { fd = open("/dev/null", O_WRONLY); if (fd) eprintf("open /dev/null O_WRONLY:"); if (dup2(fd, STDERR_FILENO) != STDERR_FILENO) eprintf("dup2 /dev/null /dev/stderr:"); close(fd); } if (write(rw[1], &status, 1) < 1) eprintf("write :"); close(rw[1]); break; default: exit(0); } break; default: close(rw[1]); if (waitpid(pid, &status, 0) != pid) eprintf("waitpid:"); if (status) exit(1); switch (read(rw[0], &status, 1)) { case -1: eprintf("read :"); case 0: exit(1); default: exit(0); } } } int main(int argc, char *argv[]) { struct epoll_event evs[32]; const char *address = "/run/sbus.socket"; int auto_close = 0; int foreground = 0; mode_t mode = 0700; int reuse_address = 0; struct passwd *user; int server, n; long long int tmp; char *arg; users = alloca(argc * sizeof(*users)); ARGBEGIN { case 'a': address = EARGF(); break; case 'c': auto_close = 1; break; case 'f': foreground = 1; break; case 'g': mode |= 0070; break; case 'o': mode |= 0007; break; case 'p': pidfile = EARGF(); break; case 'r': reuse_address = 1; break; case 'u': arg = EARGF(); if (!isdigit(*arg)) goto user_by_name; errno = 0; tmp = strtoll(arg, &arg, 10); if (errno || *arg || tmp < 0 || tmp > STYPE_MAX(uid_t)) goto user_by_name; users[nusers++] = (uid_t)tmp; user_by_name: user = getpwnam(arg); if (!user) eprintf("getpwnam %s:", arg); users[nusers++] = user->pw_uid; break; default: usage(); } ARGEND; if (argc) usage(); umask(0); server = make_socket(address, reuse_address, mode); if (foreground) { close(STDIN_FILENO); close(STDOUT_FILENO); if (signal(SIGHUP, sigexit) == SIG_ERR) weprintf("signal SIGHUP :"); if (signal(SIGINT, sigexit) == SIG_ERR) weprintf("signal SIGINT :"); pidfile = NULL; } else { if (!strcmp(pidfile, "/dev/null")) pidfile = NULL; daemonise(); } if (nusers) users[nusers++] = getuid(); head.next = &tail; tail.prev = &head; epfd = epoll_create1(0); if (epfd < 0) eprintf("epoll_create1:"); evs->events = EPOLLIN; evs->data.ptr = NULL; if (epoll_ctl(epfd, EPOLL_CTL_ADD, server, evs)) eprintf("epoll_ctl EPOLL_CTL_ADD :"); while (!auto_close || !had_client || head.next->next) { n = epoll_wait(epfd, evs, sizeof(evs) / sizeof(*evs), -1); if (n < 0) eprintf("epoll_wait:"); while (n--) { if (!evs[n].data.ptr) accept_client(accept(server, NULL, NULL)); else if (evs[n].events & (EPOLLRDHUP | EPOLLHUP)) remove_client((void *)evs[n].data.ptr); else handle_message((void *)evs[n].data.ptr); } } sigexit(0); }