diff options
Diffstat (limited to 'src/libmdsserver')
-rw-r--r-- | src/libmdsserver/mds-message.c | 186 | ||||
-rw-r--r-- | src/libmdsserver/mds-message.h | 26 |
2 files changed, 199 insertions, 13 deletions
diff --git a/src/libmdsserver/mds-message.c b/src/libmdsserver/mds-message.c index 94e9e63..ca5acb4 100644 --- a/src/libmdsserver/mds-message.c +++ b/src/libmdsserver/mds-message.c @@ -19,6 +19,8 @@ #include <stdlib.h> #include <string.h> +#include <errno.h> +#include <unistd.h> /** @@ -35,10 +37,12 @@ int mds_message_initialise(mds_message_t* this) this->header_count = 0; this->payload = NULL; this->payload_size = 0; + this->payload_ptr = 0; this->buffer = NULL; this->buffer_size = 128; this->buffer_ptr = 0; this->buffer = malloc(this->buffer_size * sizeof(char)); + this->stage = 0; if (this->buffer == NULL) return -1; return 0; @@ -75,11 +79,161 @@ void mds_message_destroy(mds_message_t* this) * * @param this Memory slot in which to store the new message * @param fd The file descriptor - * @return Non-zero on error, errno will be set accordingly. - * Destroy the message on error. + * @return Non-zero on error or interruption, errno will be + * set accordingly. Destroy the message on error, + * be aware that the reading could have been + * interrupted by a signal rather than canonical error. + * If -2 is returned errno will not have been set, + * -2 indicates that the message is malformated, + * which is a state that cannot be recovered from. */ int mds_message_read(mds_message_t* this, int fd) { + size_t header_commit_buffer = 0; + + if (this->stage == 2) + { + if (this->headers != NULL) + { + size_t i; + for (i = 0; i < this->header_count; i++) + if (this->headers[i] != NULL) + free(this->headers[i]); + free(this->headers); + } + this->header_count = 0; + + if (this->payload != NULL) + free(this->payload); + this->payload_size = 0; + this->payload_ptr = 0; + + this->stage = 0; + } + + for (;;) + { + size_t n, i; + ssize_t got; + char* p; + + if (this->stage == 0) + while ((p = memchr(this->buffer, '\n', this->buffer_ptr)) != NULL) + { + size_t len = (size_t)(p - this->buffer) + 1; + char* header; + + if (len == 0) + { + memmove(this->buffer, this->buffer + 1, this->buffer_ptr -= 1); + for (i = 0; i < this->header_count; i++) + if (strstr(this->headers[i], "Length: ") == this->headers[i]) + { + header = this->headers[i] + strlen("Length: "); + this->payload_size = (size_t)atoll(header); + for (; *header; header++) + if ((*header < '0') || ('9' < *header)) + return -2; /* Malformated value, enters unrecoverable state. */ + break; + } + this->stage = 1; + + if (this->payload_size > 0) + { + this->payload = malloc(this->payload_size * sizeof(char)); + if (this->payload == NULL) + return -1; + } + break; + } + + if (header_commit_buffer == 0) + { + header_commit_buffer = 8; + if (this->headers == NULL) + { + this->headers = malloc(header_commit_buffer * sizeof(char*)); + if (this->headers == NULL) + return -1; + } + else + { + char** old_headers = this->headers; + n = this->header_count + header_commit_buffer; + this->headers = realloc(this->headers, n * sizeof(char*)); + if (this->headers == NULL) + { + this->headers = old_headers; + return -1; + } + } + } + + header = malloc(len * sizeof(char)); + if (header == NULL) + return -1; + memcpy(header, this->buffer, len); + header[len - 1] = '\0'; + memmove(this->buffer, this->buffer + len, this->buffer_ptr -= len); + if ((p = memchr(header, ':', len)) == NULL) + { + /* Buck you, rawmemchr should not segfault the program. */ + free(header); + return -2; + } + if (p[1] != ' ') /* Also an invalid format. */ + { + free(header); + return -2; + } + this->headers[this->header_count++] = header; + header_commit_buffer -= 1; + } + + if ((this->stage == 1) && (this->payload_ptr > 0)) + { + size_t need = this->payload_size - this->payload_ptr; + if (this->buffer_ptr <= need) + memcpy(this->payload + this->payload_ptr, this->buffer, this->buffer_ptr); + else + { + memcpy(this->payload + this->payload_ptr, this->buffer, need); + memmove(this->buffer, this->buffer + need, this->buffer_ptr - need); + } + this->payload_ptr += this->buffer_ptr; + this->buffer_ptr = 0; + if (this->payload_ptr == this->payload_size) + { + this->stage = 3; + break; + } + } + + n = this->buffer_size - this->buffer_ptr; + + if (n < 128) + { + char* old_buffer = this->buffer; + this->buffer_size <<= 1; + this->buffer = realloc(this->buffer, this->buffer_size * sizeof(char)); + if (this->buffer == NULL) + { + this->buffer = old_buffer; + this->buffer_size >>= 1; + return -1; + } + n = this->buffer_size - this->buffer_ptr; + } + + errno = 0; + got = read(fd, this->buffer + this->buffer_ptr, n); + if (errno) + { + this->buffer_ptr += (size_t)(got < 0 ? 0 : got); + return -1; + } + } + return 0; } @@ -94,11 +248,14 @@ int mds_message_read(mds_message_t* this, int fd) */ size_t mds_message_marshal_size(mds_message_t* this, int include_buffer) { - size_t rc = (include_buffer ? 3 : 2) + this->header_count + this->payload_size; + size_t rc = this->header_count + this->payload_size; size_t i; for (i = 0; i < this->header_count; i++) rc += strlen(this->headers[i]); - return rc * sizeof(char); + rc *= sizeof(char); + rc += (include_buffer ? 4 : 2) * sizeof(size_t); + rc += (include_buffer ? 1 : 0) * sizeof(int); + return rc; } @@ -118,8 +275,17 @@ void mds_message_marshal(mds_message_t* this, char* data, int include_buffer) ((size_t*)data)[0] = this->header_count; ((size_t*)data)[1] = this->payload_size; if (include_buffer) - ((size_t*)data)[2] = this->buffer_ptr; - data += (include_buffer ? 3 : 2) * sizeof(size_t) / sizeof(char); + { + ((size_t*)data)[2] = this->payload_ptr; + ((size_t*)data)[3] = this->buffer_ptr; + } + data += (include_buffer ? 4 : 2) * sizeof(size_t) / sizeof(char); + + if (include_buffer) + { + ((int*)data)[0] = this->stage; + data += sizeof(int) / sizeof(char); + } for (i = 0; i < this->header_count; i++) { @@ -153,13 +319,17 @@ int mds_message_unmarshal(mds_message_t* this, char* data) header_count = ((size_t*)data)[0]; this->header_count = 0; this->payload_size = ((size_t*)data)[1]; - this->buffer_ptr = ((size_t*)data)[2]; + this->payload_ptr = ((size_t*)data)[2]; + this->buffer_ptr = ((size_t*)data)[3]; this->buffer_size = this->buffer_ptr; this->headers = NULL; this->payload = NULL; this->buffer = NULL; - data += 3 * sizeof(size_t) / sizeof(char); + data += 4 * sizeof(size_t) / sizeof(char); + + this->stage = ((int*)data)[0]; + data += sizeof(int) / sizeof(char); /* To 2-power-multiple of 128 bytes. */ this->buffer_size >>= 7; diff --git a/src/libmdsserver/mds-message.h b/src/libmdsserver/mds-message.h index f04a8c6..5cc1916 100644 --- a/src/libmdsserver/mds-message.h +++ b/src/libmdsserver/mds-message.h @@ -33,6 +33,7 @@ typedef struct mds_message * name and its associated value, joined by ": ". A header * cannot be `NULL` (unless its memory allocation failed,) * but `headers` itself is NULL if there are not headers. + * The "Length" should be included in this list. */ char** headers; @@ -52,20 +53,30 @@ typedef struct mds_message size_t payload_size; /** - * Internal buffer for the reading function + * How much of the payload that has been stored (internal data) + */ + size_t payload_ptr; + + /** + * Internal buffer for the reading function (internal data) */ char* buffer; /** - * The size allocated to `buffer` + * The size allocated to `buffer` (internal data) */ size_t buffer_size; /** - * The number of bytes used in `buffer` + * The number of bytes used in `buffer` (internal data) */ size_t buffer_ptr; + /** + * 0 while reading headers, 1 while reading payload, and 2 when done (internal data) + */ + int stage; + } mds_message_t; @@ -92,8 +103,13 @@ void mds_message_destroy(mds_message_t* this); * * @param this Memory slot in which to store the new message * @param fd The file descriptor - * @return Non-zero on error, errno will be set accordingly. - * Destroy the message on error. + * @return Non-zero on error or interruption, errno will be + * set accordingly. Destroy the message on error, + * be aware that the reading could have been + * interrupted by a signal rather than canonical error. + * If -2 is returned errno will not have been set, + * -2 indicates that the message is malformated, + * which is a state that cannot be recovered from. */ int mds_message_read(mds_message_t* this, int fd); |