From be2253d7dd6a5e71719858ddc1479e57a79834a8 Mon Sep 17 00:00:00 2001
From: Mattias Andrée <maandree@kth.se>
Date: Wed, 3 Aug 2016 21:48:21 +0200
Subject: Fix bugs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Mattias Andrée <maandree@kth.se>
---
 src/servers/coopgamma.c |   2 +-
 src/servers/crtc.c      |   2 +-
 src/servers/master.c    | 108 +++++++++++++++++++++++++++++++++---------------
 3 files changed, 77 insertions(+), 35 deletions(-)

(limited to 'src/servers')

diff --git a/src/servers/coopgamma.c b/src/servers/coopgamma.c
index 935bed0..bef0cf9 100644
--- a/src/servers/coopgamma.c
+++ b/src/servers/coopgamma.c
@@ -321,7 +321,7 @@ int handle_get_gamma(size_t conn, const char* restrict message_id, const char* r
 	n += strlen(output->table_filters[i].class) + 1;
     }
   
-  MAKE_MESSAGE(&buf, &n, 0,
+  MAKE_MESSAGE(&buf, &n, n,
 	       "In response to: %s\n"
 	       "Depth: %s\n"
 	       "Red size: %zu\n"
diff --git a/src/servers/crtc.c b/src/servers/crtc.c
index 5facade..bc3b51b 100644
--- a/src/servers/crtc.c
+++ b/src/servers/crtc.c
@@ -43,7 +43,7 @@ int handle_enumerate_crtcs(size_t conn, const char* restrict message_id)
   for (i = 0; i < outputs_n; i++)
     n += strlen(outputs[i].name) + 1;
   
-  MAKE_MESSAGE(&buf, &n, 0,
+  MAKE_MESSAGE(&buf, &n, n,
 	       "Command: crtc-enumeration\n"
 	       "In response to: %s\n"
 	       "Length: %zu\n"
diff --git a/src/servers/master.c b/src/servers/master.c
index f655e37..65012b9 100644
--- a/src/servers/master.c
+++ b/src/servers/master.c
@@ -23,10 +23,10 @@
 #include "../communication.h"
 #include "../state.h"
 
-#include <sys/select.h>
 #include <sys/socket.h>
 #include <errno.h>
 #include <fcntl.h>
+#include <poll.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
@@ -34,6 +34,12 @@
 
 
 
+/**
+ * All poll(3p) events that are not for writing
+ */
+#define NON_WR_POLL_EVENTS  (POLLIN | POLLRDNORM | POLLRDBAND | POLLPRI | POLLERR | POLLHUP | POLLNVAL)
+
+
 /**
  * Extract headers from an inbound message and pass
  * them on to appropriate message handling function
@@ -115,24 +121,43 @@ static int dispatch_message(size_t conn, struct message* restrict msg)
  * Sets the file descriptor set that includes
  * the server socket and all connections
  * 
- * @param   fds  The file descritor set
- * @return       The highest set file descritor plus 1
+ * The file descriptor will be ordered as in
+ * the array `connections`, `socketfd` will
+ * be last.
+ * 
+ * @param   fds        Reference parameter for the array of file descriptors
+ * @param   fdn        Output parameter for the number of file descriptors
+ * @parma   fds_alloc  Reference parameter for the allocation size of `fds`, in elements
+ * @return             Zero on success, -1 on error
  */
-GCC_ONLY(__attribute__((nonnull)))
-static int update_fdset(fd_set* restrict fds)
+static int update_fdset(struct pollfd** restrict fds, nfds_t* restrict fdn, nfds_t* restrict fds_alloc)
 {
-  int fdmax = socketfd;
   size_t i;
-  FD_ZERO(fds);
-  FD_SET(socketfd, fds);
+  nfds_t j = 0;
+  
+  if (connections_used + 1 > *fds_alloc)
+    {
+      void* new = realloc(*fds, (connections_used + 1) * sizeof(**fds));
+      if (new == NULL)
+	return -1;
+      *fds = new;
+      *fds_alloc = connections_used + 1;
+    }
+  
   for (i = 0; i < connections_used; i++)
     if (connections[i] >= 0)
       {
-	FD_SET(connections[i], fds);
-	if (fdmax < connections[i])
-	  fdmax = connections[i];
+	(*fds)[j].fd = connections[i];
+	(*fds)[j].events = NON_WR_POLL_EVENTS;
+	j++;
       }
-  return fdmax + 1;
+  
+  (*fds)[j].fd = socketfd;
+  (*fds)[j].events = NON_WR_POLL_EVENTS;
+  j++;
+  
+  *fdn = j;
+  return 0;
 }
 
 
@@ -245,7 +270,7 @@ static int handle_connection(size_t conn)
       connections[conn] = -1;
       if (conn < connections_ptr)
 	connections_ptr = conn;
-      if (conn == connections_used)
+      while ((connections_used > 0) && (connections[connections_used - 1] < 0))
 	connections_used -= 1;
       message_destroy(msg);
       if (connection_closed(fd) < 0)
@@ -283,60 +308,77 @@ void disconnect_all(void)
  */
 int main_loop(void)
 {
-  fd_set fds_orig, fds_rd, fds_wr, fds_ex;
-  int i, r, update, fdn = update_fdset(&fds_orig);
+  struct pollfd* fds = NULL;
+  nfds_t i, fdn = 0, fds_alloc = 0;
+  int r, update, saved_errno;
   size_t j;
   
+  if (update_fdset(&fds, &fdn, &fds_alloc) < 0)
+    goto fail;
+  
   while (!reexec && !terminate)
     {
       if (connection)
 	{
 	  if ((connection == 1 ? disconnect() : reconnect()) < 0)
-	    return connection = 0, -1;
+	    {
+	      connection = 0;
+	      goto fail;
+	    }
 	  connection = 0;
 	}
       
-      memcpy(&fds_rd, &fds_orig, sizeof(fd_set));
-      memcpy(&fds_ex, &fds_orig, sizeof(fd_set));
-      
-      FD_ZERO(&fds_wr);
-      for (j = 0; j < connections_used; j++)
-	if ((connections[j] >= 0) && ring_have_more(outbound + j))
-	  FD_SET(connections[j], &fds_wr);
+      for (j = 0, i = 0; j < connections_used; j++)
+	if (connections[j] >= 0)
+	  {
+	    if (ring_have_more(outbound + j))
+	      fds[(size_t)i++ + j].events |= POLLOUT;
+	    else
+	      fds[(size_t)i++ + j].events &= ~POLLOUT;
+	  }
       
-      if (select(fdn, &fds_rd, &fds_wr, &fds_ex, NULL) < 0)
+      if (poll(fds, fdn, -1) < 0)
 	{
-	  if (errno == EINTR)
-	    continue;
-	  return -1;
+	  if (errno == EAGAIN)
+	    perror(argv0);
+	  else if (errno != EINTR)
+	    goto fail;
 	}
       
       update = 0;
       for (i = 0; i < fdn; i++)
 	{
-	  int do_read  = FD_ISSET(i, &fds_rd) || FD_ISSET(i, &fds_ex);
-	  int do_write = FD_ISSET(i, &fds_wr);
+	  int do_read  = fds[i].revents & NON_WR_POLL_EVENTS;
+	  int do_write = fds[i].revents & POLLOUT;
+	  int fd = fds[i].fd;
 	  if (!do_read && !do_write)
 	    continue;
 	  
-	  if (i == socketfd)
+	  if (fd == socketfd)
 	    r = handle_server();
 	  else
 	    {
-	      for (j = 0; connections[j] != i; j++);
+	      for (j = 0; connections[j] != fd; j++);
 	      r = do_read ? handle_connection(j) : 0;
 	    }
 	  
 	  if ((r >= 0) && do_write)
 	    r |= continue_send(j);
 	  if (r < 0)
-	    return -1;
+	    goto fail;
 	  update |= (r > 0);
 	}
       if (update)
-	update_fdset(&fds_orig);
+	if (update_fdset(&fds, &fdn, &fds_alloc) < 0)
+	  goto fail;
     }
   
+  free(fds);
   return 0;
+ fail:
+  saved_errno = errno;
+  free(fds);
+  errno = saved_errno;
+  return -1;
 }
 
-- 
cgit v1.2.3-70-g09d2