Procházet zdrojové kódy

Fixed SSLPlugin handling

z3APA3A před 5 roky
rodič
revize
05bc297ea7
3 změnil soubory, kde provedl 57 přidání a 43 odebrání
  1. 14 14
      src/plugins/SSLPlugin/my_ssl.c
  2. 36 26
      src/plugins/SSLPlugin/ssl_plugin.c
  3. 7 3
      src/proxy.c

+ 14 - 14
src/plugins/SSLPlugin/my_ssl.c

@@ -52,14 +52,14 @@ static size_t bin2hex (const unsigned char* bin, size_t bin_length, char* str, s
 	char *p;
 	size_t i;
 	
-	if ( str_length < ( bin_length+1) ) 
+	if ( str_length < ( (bin_length*2)+1) ) 
 		return 0; 
 
 	p = str; 
 	for ( i=0; i < bin_length; ++i )  
 	{ 
-		*p++ = hexMap[*bin >> 4];  
-		*p++ = hexMap[*bin & 0xf]; 
+		*p++ = hexMap[(*(unsigned char *)bin) >> 4];  
+		*p++ = hexMap[(*(unsigned char *)bin) & 0xf]; 
 		++bin;
 	} 
 	
@@ -115,10 +115,18 @@ SSL_CERT ssl_copy_cert(SSL_CERT cert)
 	unsigned char p2[] = "3proxy";
 	unsigned char p3[] = "3proxy CA";
 
-	char hash_name_sha1[sizeof(src_cert->sha1_hash)*2 + 1];
-	char cache_name[200];
+	int hash_size = 20;
+	char hash_sha1[20];
+	char hash_name_sha1[(20*2) + 1];
+	char cache_name[256];
 
-	bin2hex(src_cert->sha1_hash, sizeof(src_cert->sha1_hash), hash_name_sha1, sizeof(hash_name_sha1));
+	err = X509_digest(src_cert, EVP_sha1(), hash_sha1, NULL);
+	if(!err){
+		X509_free(dst_cert);
+		return NULL;
+	}
+
+	bin2hex(hash_sha1, 20, hash_name_sha1, sizeof(hash_name_sha1));
 	sprintf(cache_name, "%s%s.pem", cert_path, hash_name_sha1);
 	/* check if certificate is already cached */
 	fcache = fopen(cache_name, "rb");
@@ -153,19 +161,11 @@ SSL_CERT ssl_copy_cert(SSL_CERT cert)
 	}
 
 
-	/* Its self signed so set the issuer name to be the same as the
- 	 * subject.
-	 */
 	err = X509_set_issuer_name(dst_cert, name);
 	if(!err){
 		X509_free(dst_cert);
 		return NULL;
 	}
-	err = X509_digest(dst_cert, EVP_sha1(), dst_cert->sha1_hash, NULL);
-	if(!err){
-		X509_free(dst_cert);
-		return NULL;
-	}
 	err = X509_sign(dst_cert, CA_key, EVP_sha256());
 	if(!err){
 		X509_free(dst_cert);

+ 36 - 26
src/plugins/SSLPlugin/ssl_plugin.c

@@ -56,6 +56,7 @@ struct SSLqueue {
 */
 static struct SSLqueue *searchSSL(SOCKET s){
 	struct SSLqueue *sslq = NULL;
+
 	pthread_mutex_lock(&ssl_mutex);
 	for(sslq = SSLq; sslq; sslq = sslq->next)
 		if(sslq->s == s) break;
@@ -65,19 +66,21 @@ static struct SSLqueue *searchSSL(SOCKET s){
 
 static void addSSL(SOCKET s, SSL_CERT cert, SSL_CONN conn, struct clientparam* param){
 	struct SSLqueue *sslq;
+
 	sslq = (struct SSLqueue *) malloc(sizeof(struct SSLqueue));
 	sslq->s = s;
 	sslq->cert = cert;
 	sslq->conn = conn;
+	sslq->param = param;
 	pthread_mutex_lock(&ssl_mutex);
 	sslq->next = SSLq;
-	sslq->param = param;
 	SSLq = sslq;
 	pthread_mutex_unlock(&ssl_mutex);
 }
 
 int delSSL(SOCKET s){
 	struct SSLqueue *sqi, *sqt = NULL;
+
 	if(!SSLq) return 0;
 	pthread_mutex_lock(&ssl_mutex);
 	if(SSLq){
@@ -113,13 +116,15 @@ static int ssl_send(SOCKET s, const void *msg, size_t len, int flags){
 	struct SSLqueue *sslq;
 
 	if ((sslq = searchSSL(s))){
-		int i=0, res, err;
-		do {
-			if((res = ssl_write(sslq->conn, (void *)msg, len)) < 0) {
-					err = SSL_get_error((SSL *)((ssl_conn*)sslq->conn)->ssl, res);
-					usleep(10*SLEEPTIME);
+		int res, err;
+		if((res = ssl_write(sslq->conn, (void *)msg, len)) <= 0){
+			err = SSL_get_error((SSL *)((ssl_conn*)sslq->conn)->ssl, res);
+			if (err == SSL_ERROR_WANT_WRITE){
+				_set_errno(EAGAIN);
+				return -1;
 			}
-		} while (res < 0 && (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) && ++i < 100); 
+			else _set_errno(err);
+		}
 		return res;
 	}
 
@@ -135,13 +140,15 @@ static int ssl_sendto(SOCKET s, const void *msg, size_t len, int flags, const st
 	struct SSLqueue *sslq;
 
 	if ((sslq = searchSSL(s))){
-		int i=0, res, err;
-		do {
-			if((res = ssl_write(sslq->conn, (void *)msg, len)) < 0) {
-					err = SSL_get_error((SSL *)((ssl_conn*)sslq->conn)->ssl, res);
-					usleep(10*SLEEPTIME);
+		int res, err;
+		if((res = ssl_write(sslq->conn, (void *)msg, len)) <= 0) {
+			err = SSL_get_error((SSL *)((ssl_conn*)sslq->conn)->ssl, res);
+			if (err == SSL_ERROR_WANT_WRITE){
+				_set_errno(EAGAIN);
+				return -1;
 			}
-		} while (res < 0 && (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) && ++i < 100); 
+			else _set_errno(err);
+		}
 		return res;
 	}
 
@@ -156,16 +163,17 @@ static int ssl_recvfrom(SOCKET s, void *msg, size_t len, int flags, struct socka
 	struct SSLqueue *sslq;
 
 	if ((sslq = searchSSL(s))){
-		int i=0, res, err;
-		do {
-			if((res = ssl_read(sslq->conn, (void *)msg, len)) < 0) {
-					err = SSL_get_error((SSL *)((ssl_conn*)sslq->conn)->ssl, res);
-					usleep(10*SLEEPTIME);
+		int res, err;
+		if((res = ssl_read(sslq->conn, (void *)msg, len)) <= 0) {
+			err = SSL_get_error((SSL *)((ssl_conn*)sslq->conn)->ssl, res);
+			if (err == SSL_ERROR_WANT_READ) {
+				_set_errno(EAGAIN);
+				return -1;
 			}
-		} while (res < 0 && (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) && ++i < 100); 
+			else _set_errno(err);
+		}
 		return res;
 	}
-
 	return sso._recvfrom(s, msg, len, flags, from, fromlen);
 }
 
@@ -177,13 +185,15 @@ static int WINAPI ssl_recv(SOCKET s, void *msg, size_t len, int flags){
 	struct SSLqueue *sslq;
 
 	if ((sslq = searchSSL(s))){
-		int i=0, res, err;
-		do {
-			if((res = ssl_read(sslq->conn, (void *)msg, len)) < 0) {
-					err = SSL_get_error((SSL *)((ssl_conn*)sslq->conn)->ssl, res);
-					usleep(10*SLEEPTIME);
+		int res, err;
+		if((res = ssl_read(sslq->conn, (void *)msg, len)) <= 0) {
+			err = SSL_get_error((SSL *)((ssl_conn*)sslq->conn)->ssl, res);
+			if (err == SSL_ERROR_WANT_READ) {
+				_set_errno(EAGAIN);
+				return -1;
 			}
-		} while (res < 0 && (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) && ++i < 100); 
+			else _set_errno(err);
+		}
 		return res;
 	}
 

+ 7 - 3
src/proxy.c

@@ -816,12 +816,16 @@ for(;;){
  if(isconnect && param->redirtype != R_HTTP) {
 	socksend(param->clisock, (unsigned char *)proxy_stringtable[8], (int)strlen(proxy_stringtable[8]), conf.timeouts[STRING_S]);
 	if(param->redirectfunc) {
-		 if(req)myfree(req);
-		 if(buf)myfree(buf);
-
+		if(req)myfree(req);
+		if(buf)myfree(buf);
 		return (*param->redirectfunc)(param);
 	}
 	param->res =  mapsocket(param, conf.timeouts[CONNECTION_L]);
+	if(param->redirectfunc) {
+		if(req)myfree(req);
+		if(buf)myfree(buf);
+		return (*param->redirectfunc)(param);
+	}
 	RETURN(param->res);
  }