[dovecot-cvs] dovecot/src/login-common ssl-proxy-openssl.c,1.9,1.10

cras at procontrol.fi cras at procontrol.fi
Sat Mar 29 10:56:48 EET 2003


Update of /home/cvs/dovecot/src/login-common
In directory danu:/tmp/cvs-serv25009

Modified Files:
	ssl-proxy-openssl.c 
Log Message:
Rewrite of the SSL proxy I/O handling; maybe it works properly now.



Index: ssl-proxy-openssl.c
===================================================================
RCS file: /home/cvs/dovecot/src/login-common/ssl-proxy-openssl.c,v
retrieving revision 1.9
retrieving revision 1.10
diff -u -d -r1.9 -r1.10
--- ssl-proxy-openssl.c	4 Mar 2003 02:18:09 -0000	1.9
+++ ssl-proxy-openssl.c	29 Mar 2003 08:56:45 -0000	1.10
@@ -14,10 +14,11 @@
 #include <openssl/ssl.h>
 #include <openssl/err.h>
 
-enum ssl_state {
-	SSL_STATE_HANDSHAKE,
-	SSL_STATE_READ,
-	SSL_STATE_WRITE
+enum ssl_io_action {
+	SSL_ADD_INPUT,
+	SSL_REMOVE_INPUT,
+	SSL_ADD_OUTPUT,
+	SSL_REMOVE_OUTPUT
 };
 
 struct ssl_proxy {
@@ -25,112 +26,127 @@
 
 	SSL *ssl;
 	struct ip_addr ip;
-        enum ssl_state state;
+	int handshaked;
 
 	int fd_ssl, fd_plain;
-	struct io *io_ssl, *io_plain_read, *io_plain_write;
-	int io_ssl_dir;
+	struct io *io_ssl_read, *io_ssl_write, *io_plain_read, *io_plain_write;
 
 	unsigned char plainout_buf[1024];
-	unsigned int plainout_pos, plainout_size;
+	unsigned int plainout_size;
 
 	unsigned char sslout_buf[1024];
-	unsigned int sslout_pos, sslout_size;
+	unsigned int sslout_size;
 };
 
 static SSL_CTX *ssl_ctx;
 static struct hash_table *ssl_proxies;
 
-static void plain_read(struct ssl_proxy *proxy);
-static void plain_write(struct ssl_proxy *proxy);
-
+static void plain_read(void *context);
+static void plain_write(void *context);
+static void ssl_write(struct ssl_proxy *proxy);
+static void ssl_step(void *context);
 static int ssl_proxy_destroy(struct ssl_proxy *proxy);
-static void ssl_set_direction(struct ssl_proxy *proxy, int dir);
 
-static void plain_block_input(struct ssl_proxy *proxy, int block)
+static void ssl_set_io(struct ssl_proxy *proxy, enum ssl_io_action action)
 {
-	if (block) {
-		if (proxy->io_plain_read != NULL) {
-			io_remove(proxy->io_plain_read);
-			proxy->io_plain_read = NULL;
+	switch (action) {
+	case SSL_ADD_INPUT:
+		if (proxy->io_ssl_read != NULL)
+			break;
+		proxy->io_ssl_read = io_add(proxy->fd_ssl, IO_READ,
+					    ssl_step, proxy);
+		break;
+	case SSL_REMOVE_INPUT:
+		if (proxy->io_ssl_read != NULL) {
+			io_remove(proxy->io_ssl_read);
+			proxy->io_ssl_read = NULL;
 		}
-	} else {
-		if (proxy->io_plain_read == NULL) {
-			proxy->io_plain_read =
-				io_add(proxy->fd_plain, IO_READ,
-				       (io_callback_t *)plain_read, proxy);
+		break;
+	case SSL_ADD_OUTPUT:
+		if (proxy->io_ssl_write != NULL)
+			break;
+		proxy->io_ssl_write = io_add(proxy->fd_ssl, IO_WRITE,
+					    ssl_step, proxy);
+		break;
+	case SSL_REMOVE_OUTPUT:
+		if (proxy->io_ssl_write != NULL) {
+			io_remove(proxy->io_ssl_write);
+			proxy->io_ssl_write = NULL;
 		}
+		break;
 	}
 }
 
-static void ssl_block(struct ssl_proxy *proxy, int block)
+static void plain_block_input(struct ssl_proxy *proxy, int block)
 {
-	i_assert(proxy->state == SSL_STATE_READ);
-
 	if (block) {
-		if (proxy->io_ssl != NULL) {
-			io_remove(proxy->io_ssl);
-			proxy->io_ssl = NULL;
+		if (proxy->io_plain_read != NULL) {
+			io_remove(proxy->io_plain_read);
+			proxy->io_plain_read = NULL;
 		}
-
-		proxy->io_ssl_dir = -2;
 	} else {
-		proxy->io_ssl_dir = -1;
-		ssl_set_direction(proxy, IO_READ);
+		if (proxy->io_plain_read == NULL) {
+			proxy->io_plain_read = io_add(proxy->fd_plain, IO_READ,
+						      plain_read, proxy);
+		}
 	}
 }
 
-static void plain_read(struct ssl_proxy *proxy)
+static void plain_read(void *context)
 {
+	struct ssl_proxy *proxy = context;
 	ssize_t ret;
 
-	i_assert(proxy->sslout_size == 0);
-
-	ret = net_receive(proxy->fd_plain, proxy->sslout_buf,
-			  sizeof(proxy->sslout_buf));
-	if (ret < 0)
-		ssl_proxy_destroy(proxy);
-	else if (ret > 0) {
-		proxy->sslout_size = ret;
-		proxy->sslout_pos = 0;
-
-		proxy->state = SSL_STATE_WRITE;
-		ssl_set_direction(proxy, IO_WRITE);
-
+	if (proxy->sslout_size == sizeof(proxy->sslout_buf)) {
+		/* buffer full, block input until it's written */
 		plain_block_input(proxy, TRUE);
+		return;
+	}
+
+	while (proxy->sslout_size < sizeof(proxy->sslout_buf)) {
+		ret = net_receive(proxy->fd_plain,
+				  proxy->sslout_buf + proxy->sslout_size,
+				  sizeof(proxy->sslout_buf) -
+				  proxy->sslout_size);
+		if (ret <= 0) {
+			if (ret < 0)
+				ssl_proxy_destroy(proxy);
+			break;
+		} else {
+			proxy->sslout_size += ret;
+			ssl_write(proxy);
+		}
 	}
 }
 
-static void plain_write(struct ssl_proxy *proxy)
+static void plain_write(void *context)
 {
+	struct ssl_proxy *proxy = context;
 	ssize_t ret;
 
-	ret = net_transmit(proxy->fd_plain,
-			   proxy->plainout_buf + proxy->plainout_pos,
+	ret = net_transmit(proxy->fd_plain, proxy->plainout_buf,
 			   proxy->plainout_size);
 	if (ret < 0)
 		ssl_proxy_destroy(proxy);
 	else {
 		proxy->plainout_size -= ret;
-		proxy->plainout_pos += ret;
+		memmove(proxy->plainout_buf, proxy->plainout_buf + ret,
+			proxy->plainout_size);
 
 		if (proxy->plainout_size > 0) {
-			ssl_block(proxy, TRUE);
 			if (proxy->io_plain_write == NULL) {
 				proxy->io_plain_write =
 					io_add(proxy->fd_plain, IO_WRITE,
-					       (io_callback_t *)plain_write,
-					       proxy);
+					       plain_write, proxy);
 			}
 		} else {
-			proxy->plainout_pos = 0;
-			ssl_block(proxy, FALSE);
-
 			if (proxy->io_plain_write != NULL) {
 				io_remove(proxy->io_plain_write);
                                 proxy->io_plain_write = NULL;
 			}
 		}
+
+		ssl_set_io(proxy, SSL_ADD_INPUT);
 	}
 
 }
@@ -160,10 +176,10 @@
 
 	switch (err) {
 	case SSL_ERROR_WANT_READ:
-		ssl_set_direction(proxy, IO_READ);
+		ssl_set_io(proxy, SSL_ADD_INPUT);
 		break;
 	case SSL_ERROR_WANT_WRITE:
-		ssl_set_direction(proxy, IO_WRITE);
+		ssl_set_io(proxy, SSL_ADD_OUTPUT);
 		break;
 	case SSL_ERROR_SYSCALL:
 		/* eat up the error queue */
@@ -201,94 +217,77 @@
 	}
 }
 
-static void ssl_handshake_step(struct ssl_proxy *proxy)
+static void ssl_handshake(struct ssl_proxy *proxy)
 {
 	int ret;
 
 	ret = SSL_accept(proxy->ssl);
-	if (ret != 1) {
-		plain_block_input(proxy, TRUE);
+	if (ret != 1)
 		ssl_handle_error(proxy, ret, "SSL_accept()");
-	} else {
+	else {
+		proxy->handshaked = TRUE;
+
+		ssl_set_io(proxy, SSL_ADD_INPUT);
 		plain_block_input(proxy, FALSE);
-		ssl_set_direction(proxy, IO_READ);
-		proxy->state = SSL_STATE_READ;
 	}
 }
 
-static void ssl_read_step(struct ssl_proxy *proxy)
+static void ssl_read(struct ssl_proxy *proxy)
 {
 	int ret;
 
-	i_assert(proxy->plainout_size == 0);
-
-	ret = SSL_read(proxy->ssl, proxy->plainout_buf,
-		       sizeof(proxy->plainout_buf));
-	if (ret <= 0)
-		ssl_handle_error(proxy, ret, "SSL_read()");
-	else {
-		plain_block_input(proxy, FALSE);
-		ssl_set_direction(proxy, IO_READ);
-
-		proxy->plainout_pos = 0;
-		proxy->plainout_size = ret;
-		plain_write(proxy);
+	while (proxy->plainout_size < sizeof(proxy->plainout_buf)) {
+		ret = SSL_read(proxy->ssl,
+			       proxy->plainout_buf + proxy->plainout_size,
+			       sizeof(proxy->plainout_buf) -
+			       proxy->plainout_size);
+		if (ret <= 0) {
+			ssl_handle_error(proxy, ret, "SSL_read()");
+			break;
+		} else {
+			proxy->plainout_size += ret;
+			plain_write(proxy);
+		}
 	}
 }
 
-static void ssl_write_step(struct ssl_proxy *proxy)
+static void ssl_write(struct ssl_proxy *proxy)
 {
 	int ret;
 
-	ret = SSL_write(proxy->ssl, proxy->sslout_buf + proxy->sslout_pos,
-			proxy->sslout_size);
+	ret = SSL_write(proxy->ssl, proxy->sslout_buf, proxy->sslout_size);
 	if (ret <= 0)
 		ssl_handle_error(proxy, ret, "SSL_write()");
 	else {
 		proxy->sslout_size -= ret;
-		proxy->sslout_pos += ret;
+		memmove(proxy->sslout_buf, proxy->sslout_buf + ret,
+			proxy->sslout_size);
 
-		if (proxy->sslout_size > 0) {
-			plain_block_input(proxy, TRUE);
-			ssl_set_direction(proxy, IO_WRITE);
-			proxy->state = SSL_STATE_WRITE;
-		} else {
-			plain_block_input(proxy, FALSE);
-			ssl_set_direction(proxy, IO_READ);
-			proxy->state = SSL_STATE_READ;
-			proxy->sslout_pos = 0;
-		}
+		ssl_set_io(proxy, proxy->sslout_size > 0 ?
+			   SSL_ADD_OUTPUT : SSL_REMOVE_OUTPUT);
+		plain_block_input(proxy, FALSE);
 	}
 }
 
 static void ssl_step(void *context)
 {
-        struct ssl_proxy *proxy = context;
+	struct ssl_proxy *proxy = context;
 
-	switch (proxy->state) {
-	case SSL_STATE_HANDSHAKE:
-		ssl_handshake_step(proxy);
-		break;
-	case SSL_STATE_READ:
-		ssl_read_step(proxy);
-		break;
-	case SSL_STATE_WRITE:
-		ssl_write_step(proxy);
-		break;
+	if (!proxy->handshaked) {
+		ssl_handshake(proxy);
+		if (!proxy->handshaked)
+			return;
 	}
-}
-
-static void ssl_set_direction(struct ssl_proxy *proxy, int dir)
-{
-	i_assert(proxy->io_ssl_dir != -2);
 
-	if (proxy->io_ssl_dir == dir)
-		return;
+	if (proxy->plainout_size == sizeof(proxy->plainout_buf))
+		ssl_set_io(proxy, SSL_REMOVE_INPUT);
+	else
+		ssl_read(proxy);
 
-	if (proxy->io_ssl != NULL)
-		io_remove(proxy->io_ssl);
-	proxy->io_ssl = io_add(proxy->fd_ssl, dir, ssl_step, proxy);
-        proxy->io_ssl_dir = dir;
+	if (proxy->sslout_size == 0)
+		ssl_set_io(proxy, SSL_REMOVE_OUTPUT);
+	else
+		ssl_write(proxy);
 }
 
 int ssl_proxy_new(int fd, struct ip_addr *ip)
@@ -329,11 +328,8 @@
 	proxy->fd_plain = sfd[0];
 	proxy->ip = *ip;
 
-	proxy->state = SSL_STATE_HANDSHAKE;
-	ssl_set_direction(proxy, IO_READ);
-
 	proxy->refcount++;
-	ssl_handshake_step(proxy);
+	ssl_handshake(proxy);
 	if (!ssl_proxy_destroy(proxy)) {
 		/* handshake failed. return the disconnected socket anyway
 		   so the caller doesn't try to use the old closed fd */
@@ -357,8 +353,10 @@
 	(void)net_disconnect(proxy->fd_ssl);
 	(void)net_disconnect(proxy->fd_plain);
 
-	if (proxy->io_ssl != NULL)
-		io_remove(proxy->io_ssl);
+	if (proxy->io_ssl_read != NULL)
+		io_remove(proxy->io_ssl_read);
+	if (proxy->io_ssl_write != NULL)
+		io_remove(proxy->io_ssl_write);
 	if (proxy->io_plain_read != NULL)
 		io_remove(proxy->io_plain_read);
 	if (proxy->io_plain_write != NULL)




More information about the dovecot-cvs mailing list