aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/libmdsserver/mds-message.c186
-rw-r--r--src/libmdsserver/mds-message.h26
2 files changed, 199 insertions, 13 deletions
diff --git a/src/libmdsserver/mds-message.c b/src/libmdsserver/mds-message.c
index 94e9e63..ca5acb4 100644
--- a/src/libmdsserver/mds-message.c
+++ b/src/libmdsserver/mds-message.c
@@ -19,6 +19,8 @@
#include <stdlib.h>
#include <string.h>
+#include <errno.h>
+#include <unistd.h>
/**
@@ -35,10 +37,12 @@ int mds_message_initialise(mds_message_t* this)
this->header_count = 0;
this->payload = NULL;
this->payload_size = 0;
+ this->payload_ptr = 0;
this->buffer = NULL;
this->buffer_size = 128;
this->buffer_ptr = 0;
this->buffer = malloc(this->buffer_size * sizeof(char));
+ this->stage = 0;
if (this->buffer == NULL)
return -1;
return 0;
@@ -75,11 +79,161 @@ void mds_message_destroy(mds_message_t* this)
*
* @param this Memory slot in which to store the new message
* @param fd The file descriptor
- * @return Non-zero on error, errno will be set accordingly.
- * Destroy the message on error.
+ * @return Non-zero on error or interruption, errno will be
+ * set accordingly. Destroy the message on error,
+ * be aware that the reading could have been
+ * interrupted by a signal rather than canonical error.
+ * If -2 is returned errno will not have been set,
+ * -2 indicates that the message is malformated,
+ * which is a state that cannot be recovered from.
*/
int mds_message_read(mds_message_t* this, int fd)
{
+ size_t header_commit_buffer = 0;
+
+ if (this->stage == 2)
+ {
+ if (this->headers != NULL)
+ {
+ size_t i;
+ for (i = 0; i < this->header_count; i++)
+ if (this->headers[i] != NULL)
+ free(this->headers[i]);
+ free(this->headers);
+ }
+ this->header_count = 0;
+
+ if (this->payload != NULL)
+ free(this->payload);
+ this->payload_size = 0;
+ this->payload_ptr = 0;
+
+ this->stage = 0;
+ }
+
+ for (;;)
+ {
+ size_t n, i;
+ ssize_t got;
+ char* p;
+
+ if (this->stage == 0)
+ while ((p = memchr(this->buffer, '\n', this->buffer_ptr)) != NULL)
+ {
+ size_t len = (size_t)(p - this->buffer) + 1;
+ char* header;
+
+ if (len == 0)
+ {
+ memmove(this->buffer, this->buffer + 1, this->buffer_ptr -= 1);
+ for (i = 0; i < this->header_count; i++)
+ if (strstr(this->headers[i], "Length: ") == this->headers[i])
+ {
+ header = this->headers[i] + strlen("Length: ");
+ this->payload_size = (size_t)atoll(header);
+ for (; *header; header++)
+ if ((*header < '0') || ('9' < *header))
+ return -2; /* Malformated value, enters unrecoverable state. */
+ break;
+ }
+ this->stage = 1;
+
+ if (this->payload_size > 0)
+ {
+ this->payload = malloc(this->payload_size * sizeof(char));
+ if (this->payload == NULL)
+ return -1;
+ }
+ break;
+ }
+
+ if (header_commit_buffer == 0)
+ {
+ header_commit_buffer = 8;
+ if (this->headers == NULL)
+ {
+ this->headers = malloc(header_commit_buffer * sizeof(char*));
+ if (this->headers == NULL)
+ return -1;
+ }
+ else
+ {
+ char** old_headers = this->headers;
+ n = this->header_count + header_commit_buffer;
+ this->headers = realloc(this->headers, n * sizeof(char*));
+ if (this->headers == NULL)
+ {
+ this->headers = old_headers;
+ return -1;
+ }
+ }
+ }
+
+ header = malloc(len * sizeof(char));
+ if (header == NULL)
+ return -1;
+ memcpy(header, this->buffer, len);
+ header[len - 1] = '\0';
+ memmove(this->buffer, this->buffer + len, this->buffer_ptr -= len);
+ if ((p = memchr(header, ':', len)) == NULL)
+ {
+ /* Buck you, rawmemchr should not segfault the program. */
+ free(header);
+ return -2;
+ }
+ if (p[1] != ' ') /* Also an invalid format. */
+ {
+ free(header);
+ return -2;
+ }
+ this->headers[this->header_count++] = header;
+ header_commit_buffer -= 1;
+ }
+
+ if ((this->stage == 1) && (this->payload_ptr > 0))
+ {
+ size_t need = this->payload_size - this->payload_ptr;
+ if (this->buffer_ptr <= need)
+ memcpy(this->payload + this->payload_ptr, this->buffer, this->buffer_ptr);
+ else
+ {
+ memcpy(this->payload + this->payload_ptr, this->buffer, need);
+ memmove(this->buffer, this->buffer + need, this->buffer_ptr - need);
+ }
+ this->payload_ptr += this->buffer_ptr;
+ this->buffer_ptr = 0;
+ if (this->payload_ptr == this->payload_size)
+ {
+ this->stage = 3;
+ break;
+ }
+ }
+
+ n = this->buffer_size - this->buffer_ptr;
+
+ if (n < 128)
+ {
+ char* old_buffer = this->buffer;
+ this->buffer_size <<= 1;
+ this->buffer = realloc(this->buffer, this->buffer_size * sizeof(char));
+ if (this->buffer == NULL)
+ {
+ this->buffer = old_buffer;
+ this->buffer_size >>= 1;
+ return -1;
+ }
+ n = this->buffer_size - this->buffer_ptr;
+ }
+
+ errno = 0;
+ got = read(fd, this->buffer + this->buffer_ptr, n);
+ if (errno)
+ {
+ this->buffer_ptr += (size_t)(got < 0 ? 0 : got);
+ return -1;
+ }
+ }
+
return 0;
}
@@ -94,11 +248,14 @@ int mds_message_read(mds_message_t* this, int fd)
*/
size_t mds_message_marshal_size(mds_message_t* this, int include_buffer)
{
- size_t rc = (include_buffer ? 3 : 2) + this->header_count + this->payload_size;
+ size_t rc = this->header_count + this->payload_size;
size_t i;
for (i = 0; i < this->header_count; i++)
rc += strlen(this->headers[i]);
- return rc * sizeof(char);
+ rc *= sizeof(char);
+ rc += (include_buffer ? 4 : 2) * sizeof(size_t);
+ rc += (include_buffer ? 1 : 0) * sizeof(int);
+ return rc;
}
@@ -118,8 +275,17 @@ void mds_message_marshal(mds_message_t* this, char* data, int include_buffer)
((size_t*)data)[0] = this->header_count;
((size_t*)data)[1] = this->payload_size;
if (include_buffer)
- ((size_t*)data)[2] = this->buffer_ptr;
- data += (include_buffer ? 3 : 2) * sizeof(size_t) / sizeof(char);
+ {
+ ((size_t*)data)[2] = this->payload_ptr;
+ ((size_t*)data)[3] = this->buffer_ptr;
+ }
+ data += (include_buffer ? 4 : 2) * sizeof(size_t) / sizeof(char);
+
+ if (include_buffer)
+ {
+ ((int*)data)[0] = this->stage;
+ data += sizeof(int) / sizeof(char);
+ }
for (i = 0; i < this->header_count; i++)
{
@@ -153,13 +319,17 @@ int mds_message_unmarshal(mds_message_t* this, char* data)
header_count = ((size_t*)data)[0];
this->header_count = 0;
this->payload_size = ((size_t*)data)[1];
- this->buffer_ptr = ((size_t*)data)[2];
+ this->payload_ptr = ((size_t*)data)[2];
+ this->buffer_ptr = ((size_t*)data)[3];
this->buffer_size = this->buffer_ptr;
this->headers = NULL;
this->payload = NULL;
this->buffer = NULL;
- data += 3 * sizeof(size_t) / sizeof(char);
+ data += 4 * sizeof(size_t) / sizeof(char);
+
+ this->stage = ((int*)data)[0];
+ data += sizeof(int) / sizeof(char);
/* To 2-power-multiple of 128 bytes. */
this->buffer_size >>= 7;
diff --git a/src/libmdsserver/mds-message.h b/src/libmdsserver/mds-message.h
index f04a8c6..5cc1916 100644
--- a/src/libmdsserver/mds-message.h
+++ b/src/libmdsserver/mds-message.h
@@ -33,6 +33,7 @@ typedef struct mds_message
* name and its associated value, joined by ": ". A header
* cannot be `NULL` (unless its memory allocation failed,)
* but `headers` itself is NULL if there are not headers.
+ * The "Length" should be included in this list.
*/
char** headers;
@@ -52,20 +53,30 @@ typedef struct mds_message
size_t payload_size;
/**
- * Internal buffer for the reading function
+ * How much of the payload that has been stored (internal data)
+ */
+ size_t payload_ptr;
+
+ /**
+ * Internal buffer for the reading function (internal data)
*/
char* buffer;
/**
- * The size allocated to `buffer`
+ * The size allocated to `buffer` (internal data)
*/
size_t buffer_size;
/**
- * The number of bytes used in `buffer`
+ * The number of bytes used in `buffer` (internal data)
*/
size_t buffer_ptr;
+ /**
+ * 0 while reading headers, 1 while reading payload, and 2 when done (internal data)
+ */
+ int stage;
+
} mds_message_t;
@@ -92,8 +103,13 @@ void mds_message_destroy(mds_message_t* this);
*
* @param this Memory slot in which to store the new message
* @param fd The file descriptor
- * @return Non-zero on error, errno will be set accordingly.
- * Destroy the message on error.
+ * @return Non-zero on error or interruption, errno will be
+ * set accordingly. Destroy the message on error,
+ * be aware that the reading could have been
+ * interrupted by a signal rather than canonical error.
+ * If -2 is returned errno will not have been set,
+ * -2 indicates that the message is malformated,
+ * which is a state that cannot be recovered from.
*/
int mds_message_read(mds_message_t* this, int fd);