aboutsummaryrefslogblamecommitdiffstats
path: root/libaxl_connect.c
blob: ee4edd8e1186b498bf86b48f19f42967fbcd5589 (plain) (tree)






































































































































































































































































                                                                                                                               
/* See LICENSE file for copyright and license details. */
#include "common.h"

/*
 * Address-family codes as stored in the 16-bit big-endian family
 * field at the start of each Xauthority file entry.
 */
enum {
	FamilyInternet  = 0,
	FamilyDECnet    = 1,
	FamilyChaos     = 2,
	FamilyInternet6 = 6,
	FamilyLocal     = 256
};

/*
 * Build the path "$HOME/<filename>" in newly allocated memory.
 *
 * @param   filename  Path relative to the user's home directory
 * @return            The path, which the caller must free, or NULL
 *                    if liberror_malloc failed
 *
 * TODO: an unset or empty $HOME is not handled yet; the process aborts.
 */
static char *
path_in_home(const char *filename)
{
	const char *home = getenv("HOME");
	char *path, *end;

	if (home == NULL || *home == '\0')
		abort(); /* TODO */

	/* sizeof("/") accounts for both the separator and the terminating NUL */
	path = liberror_malloc(strlen(home) + sizeof("/") + strlen(filename));
	if (path != NULL) {
		end = stpcpy(path, home);
		*end++ = '/';
		stpcpy(end, filename);
	}
	return path;
}

/*
 * Select the Xauthority file: $XAUTHORITY if it is set and non-empty,
 * otherwise ~/.Xauthority.
 *
 * @param   freep  Set to 1 if the returned string was allocated here (the
 *                 caller must free it), 0 if it belongs to the environment
 * @return         The path, or NULL if building the fallback path failed
 */
static char *
get_auth_file(int *freep)
{
	char *env = getenv("XAUTHORITY");
	char *path;

	if (env != NULL && env[0] != '\0') {
		*freep = 0;
		return env;
	}

	path = path_in_home(".Xauthority");
	*freep = 1;
	return path;
}

/*
 * Read the next complete entry of an Xauthority file into the start of
 * *authbufp.
 *
 * An entry is a 16-bit family followed by four fields (address, display
 * number, authorisation name, authorisation data). Each field is preceded
 * by its length as a big-endian 16-bit integer.
 *
 * @param   authfd    File descriptor of the authority file
 * @param   authbufp  Buffer; may be reallocated. On entry its first *havep
 *                    bytes are data already read but not yet consumed
 * @param   bufsizep  Allocated size of *authbufp; updated on reallocation
 * @param   lenp      Output: length of the entry now at the start of the
 *                    buffer, or 0 if end of file was reached first
 * @param   havep     In: bytes already buffered; out: bytes buffered
 *                    after the returned entry (starting at offset *lenp)
 * @return            0 on success or end of file, -1 on error
 */
static int
next_auth(int authfd, char **authbufp, size_t *bufsizep, size_t *lenp, size_t *havep)
{
	ssize_t r;
	size_t got = *havep, need = 4;
	uint16_t fieldlen;
	int stage;
	void *new;

	*lenp = 0;

	/*
	 * Stage 0 needs the family and the address length. Each of stages
	 * 1-3 needs the previous field plus the next length. Stage 4 needs
	 * the authorisation data itself, which must also be read in full
	 * before the entry can be returned.
	 */
	for (stage = 0;; stage++) {
		while (got < need) {
			if (need > *bufsizep) {
				new = liberror_realloc(*authbufp, need);
				if (!new)
					return -1;
				*authbufp = new;
				*bufsizep = need;
			}
			r = read(authfd, &(*authbufp)[got], *bufsizep - got);
			if (r < 0) {
				liberror_save_backtrace(NULL);
				liberror_set_error_errno(strerror(errno), "read", errno);
				return -1;
			} else if (!r) {
				return 0;
			}
			got += (size_t)r;
		}
		if (stage == 4)
			break;
		/* The length field may be at an odd offset: copy it instead of
		 * dereferencing a cast pointer (misalignment / strict aliasing) */
		memcpy(&fieldlen, &(*authbufp)[need - 2], sizeof(fieldlen));
		/* The data field is the last one; it is not followed by a length */
		need += (size_t)ntohs(fieldlen) + (stage < 3 ? 2 : 0);
	}

	*lenp = need;
	*havep = got - need; /* got >= need is guaranteed by the final stage */

	return 0;
}

/* Read a big-endian 16-bit integer from a possibly unaligned position */
static size_t
xauth_read_be16(const char *p)
{
	uint16_t v;
	memcpy(&v, p, sizeof(v));
	return (size_t)ntohs(v);
}

/*
 * Look up X authorisation credentials for the connection on sockfd in the
 * authority file xauthfile.
 *
 * Only peers connected over AF_LOCAL are looked up. For AF_INET, AF_INET6
 * and other families no credentials are returned (TODO). An entry matches
 * if its family, address (this machine's host name) and display number
 * all match.
 *
 * @param   xauthfile     Path of the authority file; a missing file means
 *                        no credentials
 * @param   sockfd        Connected socket to the display server
 * @param   host          Unused
 * @param   protocol      Unused
 * @param   display       Display number
 * @param   authnamep     Output: authorisation name (not NUL-terminated),
 *                        pointing into *authbufp, or NULL if none found
 * @param   authnamelenp  Output: length of *authnamep
 * @param   authdatap     Output: authorisation data, pointing into
 *                        *authbufp, or NULL if none found
 * @param   authdatalenp  Output: length of *authdatap
 * @param   authbufp      Output: buffer backing *authnamep and *authdatap.
 *                        The caller must free it whatever the return value
 *                        (it may be NULL)
 * @return                0 on success (including no match), -1 on error
 */
static int
get_auth(const char *xauthfile, int sockfd, const char *host, const char *protocol, int display,
         char **authnamep, size_t *authnamelenp, char **authdatap, size_t *authdatalenp, char **authbufp)
{
	int authfd, family, saved_errno;
	char hostname[HOST_NAME_MAX + 1], number[2 + 3 * sizeof(int)];
	struct sockaddr_storage sockaddr;
	socklen_t sockaddrlen = (socklen_t)sizeof(sockaddr);
	size_t bufsize = 128, len, have = 0, off, numberlen, hostnamelen;
	char *buf;

	(void) host;
	(void) protocol;

	*authnamep = *authdatap = *authbufp = NULL;
	*authnamelenp = *authdatalenp = 0;

	numberlen = (size_t)sprintf(number, "%i", display);

	/* POSIX leaves a truncated host name unterminated, so always terminate it */
	if (gethostname(hostname, sizeof(hostname)))
		stpcpy(hostname, "localhost");
	hostname[sizeof(hostname) - 1] = '\0';
	hostnamelen = strlen(hostname);

	if (getpeername(sockfd, (void *)&sockaddr, &sockaddrlen)) {
		liberror_save_backtrace(NULL);
		liberror_set_error_errno(strerror(errno), "getpeername", errno);
		return -1;
	}

	switch (sockaddr.ss_family) {
	case AF_LOCAL:
		family = FamilyLocal;
		break;
	case AF_INET:
		family = FamilyInternet; /* TODO */
		return 0;
	case AF_INET6:
		family = FamilyInternet6; /* TODO */
		return 0;
	default:
		return 0;
	}

	*authbufp = liberror_malloc(bufsize);
	if (!*authbufp)
		return -1;

	authfd = open(xauthfile, O_RDONLY);
	if (authfd < 0 && errno != ENOENT) {
		liberror_save_backtrace(NULL);
		liberror_set_error_errno(strerror(errno), "open", errno);
		return -1;
	} else if (authfd < 0) {
		return 0;
	}

	/* After each rejected entry, move the bytes read past it to the front */
	for (;; memmove(*authbufp, &(*authbufp)[len], have)) {
		if (next_auth(authfd, authbufp, &bufsize, &len, &have)) {
			liberror_save_backtrace(NULL);
			liberror_set_error_errno(strerror(errno), "read", errno);
			saved_errno = errno;
			close(authfd);
			errno = saved_errno;
			return -1;
		} else if (!len) {
			break;
		}
		/* next_auth may have reallocated the buffer */
		buf = *authbufp;

		if (xauth_read_be16(&buf[0]) != (size_t)family)
			continue;
		if (xauth_read_be16(&buf[2]) != hostnamelen)
			continue;
		if (memcmp(&buf[4], hostname, hostnamelen))
			continue;
		off = 4 + hostnamelen;
		if (xauth_read_be16(&buf[off]) != numberlen)
			continue;
		off += 2;
		if (memcmp(&buf[off], number, numberlen))
			continue;
		off += numberlen;

		*authnamelenp = xauth_read_be16(&buf[off]);
		off += 2;
		*authnamep = &buf[off];
		off += *authnamelenp;
		*authdatalenp = xauth_read_be16(&buf[off]);
		off += 2;
		*authdatap = &buf[off];

		break;
	}

	close(authfd);
	return 0;
}

/*
 * Connect to an X display and perform the connection handshake,
 * using credentials from the Xauthority file when available.
 *
 * @param   display  Display name, parsed with libaxl_parse_display
 * @param   reasonp  If not NULL, *reasonp is set to NULL and then passed to
 *                   libaxl_receive_handshake (presumably receives the
 *                   server's reason text on refusal — TODO confirm)
 * @return           The connection on success, NULL on failure
 *
 * TODO: a refused handshake (LIBAXL_HANDSHAKE_FAILED) or a request for
 * further authentication (LIBAXL_HANDSHAKE_AUTHENTICATE) currently abort().
 */
LIBAXL_CONNECTION *
libaxl_connect(const char *restrict display, char **restrict reasonp)
{
	struct liberror_state error_state;
	LIBAXL_CONNECTION *conn = NULL;
	LIBAXL_CONTEXT *ctx = NULL;
	int xauthfile_free = 0, dispnum, screen, major, minor, r;
	char *xauthfile = NULL, *host = NULL, *protocol = NULL;
	char *authname = NULL, *authdata = NULL, *authbuf = NULL;
	size_t authnamelen, authdatalen;
	int saved_errno = errno;

	if (reasonp)
		*reasonp = NULL;

	r = libaxl_parse_display(display, &host, &protocol, &dispnum, &screen);
	if (r)
		goto fail;

	xauthfile = get_auth_file(&xauthfile_free);
	if (!xauthfile)
		goto fail;

	conn = libaxl_connect_without_handshake(host, protocol, dispnum, screen);
	if (!conn)
		goto fail;

	ctx = libaxl_context_create(conn);
	if (!ctx)
		goto fail;

	if (get_auth(xauthfile, conn->fd, host, protocol, dispnum, &authname, &authnamelen, &authdata, &authdatalen, &authbuf))
		goto fail;

	r = libaxl_send_handshake(ctx, authname, authnamelen, authdata, authdatalen, MSG_NOSIGNAL);
	if (r) {
		/* Keep flushing while the send reports EINPROGRESS, discarding
		 * those intermediate errors, until it completes or fails otherwise */
		if (r == LIBAXL_ERROR_SYSTEM && errno == EINPROGRESS) {
			for (;;) {
				r = libaxl_flush(conn, MSG_NOSIGNAL);
				if (r != LIBAXL_ERROR_SYSTEM || errno != EINPROGRESS)
					break;
				liberror_pop_error();
			}
		}
		if (r)
			goto fail;
		/* The flush succeeded: drop the error from libaxl_send_handshake */
		liberror_pop_error();
	}

	errno = saved_errno;

	r = libaxl_receive_handshake(ctx, &major, &minor, reasonp, MSG_NOSIGNAL);
	switch (r) {
	case LIBAXL_HANDSHAKE_FAILED: /* TODO */
	case LIBAXL_HANDSHAKE_AUTHENTICATE: /* TODO */
		abort();
		break;

	case LIBAXL_HANDSHAKE_SUCCESS:
		break;

	default:
		goto fail;
	}

	/*
	 * The connection does not own these (the failure path frees them next
	 * to libaxl_close), so release them here too instead of leaking them.
	 */
	libaxl_context_free(ctx);
	free(authbuf);
	if (xauthfile_free)
		free(xauthfile);
	free(host);
	free(protocol);
	return conn;

fail:
	/* Cleanup must not disturb the error being reported */
	liberror_start(&error_state);
	if (xauthfile_free)
		free(xauthfile);
	free(authbuf);
	free(host);
	free(protocol);
	libaxl_context_free(ctx);
	libaxl_close(conn);
	liberror_end(&error_state);
	return NULL;
}