diff options
Diffstat (limited to 'src/modules/netlib/netlibssl.cpp')
-rw-r--r-- | src/modules/netlib/netlibssl.cpp | 981 |
1 files changed, 981 insertions, 0 deletions
diff --git a/src/modules/netlib/netlibssl.cpp b/src/modules/netlib/netlibssl.cpp new file mode 100644 index 0000000000..2db769a3da --- /dev/null +++ b/src/modules/netlib/netlibssl.cpp @@ -0,0 +1,981 @@ +/*
+
+Miranda IM: the free IM client for Microsoft* Windows*
+
+Copyright 2000-2009 Miranda ICQ/IM project,
+all portions of this codebase are copyrighted to the people
+listed in contributors.txt.
+
+This program is free software; you can redistribute it and/or
+modify it under the terms of the GNU General Public License
+as published by the Free Software Foundation; either version 2
+of the License, or (at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program; if not, write to the Free Software
+Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
+*/
+
+#include "commonheaders.h"
+#include <m_popup.h>
+#include "netlib.h"
+
+#define SECURITY_WIN32
+#include <security.h>
+#include <schannel.h>
+
+//#include <SCHNLSP.H>
+
+typedef BOOL (* SSL_EMPTY_CACHE_FN_M)(VOID);
+
+static HMODULE g_hSchannel;
+static PSecurityFunctionTableA g_pSSPI;
+static HANDLE g_hSslMutex;
+static SSL_EMPTY_CACHE_FN_M MySslEmptyCache;
+static CredHandle hCreds;
+static bool bSslInitDone;
+
+typedef BOOL (WINAPI *pfnCertGetCertificateChain)(HCERTCHAINENGINE, PCCERT_CONTEXT, LPFILETIME, HCERTSTORE, PCERT_CHAIN_PARA, DWORD, LPVOID, PCCERT_CHAIN_CONTEXT*);
+static pfnCertGetCertificateChain fnCertGetCertificateChain;
+
+typedef VOID (WINAPI *pfnCertFreeCertificateChain)(PCCERT_CHAIN_CONTEXT);
+static pfnCertFreeCertificateChain fnCertFreeCertificateChain;
+
+typedef BOOL (WINAPI *pfnCertFreeCertificateContext)(PCCERT_CONTEXT);
+static pfnCertFreeCertificateContext fnCertFreeCertificateContext;
+
+typedef BOOL (WINAPI *pfnCertVerifyCertificateChainPolicy)(LPCSTR, PCCERT_CHAIN_CONTEXT, PCERT_CHAIN_POLICY_PARA, PCERT_CHAIN_POLICY_STATUS);
+static pfnCertVerifyCertificateChainPolicy fnCertVerifyCertificateChainPolicy;
+
+typedef enum
+{
+ sockOpen,
+ sockClosed,
+ sockError
+} SocketState;
+
+
+struct SslHandle
+{
+ SOCKET s;
+
+ CtxtHandle hContext;
+
+ BYTE *pbRecDataBuf;
+ int cbRecDataBuf;
+ int sbRecDataBuf;
+
+ BYTE *pbIoBuffer;
+ int cbIoBuffer;
+ int sbIoBuffer;
+
+ SocketState state;
+};
+
+static void ReportSslError(SECURITY_STATUS scRet, int line, bool showPopup = false)
+{
+ TCHAR szMsgBuf[256];
+ switch (scRet)
+ {
+ case 0:
+ case ERROR_NOT_READY:
+ return;
+
+ case SEC_E_INVALID_TOKEN:
+ _tcscpy(szMsgBuf, TranslateT("Client cannot decode host message. Possible causes: Host does not support SSL or requires not existing security package"));
+ break;
+
+ case CERT_E_CN_NO_MATCH:
+ case SEC_E_WRONG_PRINCIPAL:
+ _tcscpy(szMsgBuf, TranslateT("Host we are connecting to is not the one certificate was issued for"));
+ break;
+
+ default:
+ FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
+ NULL, scRet, LANG_USER_DEFAULT, szMsgBuf, SIZEOF(szMsgBuf), NULL);
+ }
+
+ TCHAR szMsgBuf2[512];
+ mir_sntprintf(szMsgBuf2, SIZEOF(szMsgBuf2), _T("SSL connection failure (%x %u): %s"), scRet, line, szMsgBuf);
+
+ char* szMsg = Utf8EncodeT(szMsgBuf2);
+ NetlibLogf(NULL, szMsg);
+ mir_free(szMsg);
+
+ SetLastError(scRet);
+ PUShowMessageT(szMsgBuf2, SM_WARNING);
+}
+
+static bool AcquireCredentials(void)
+{
+ SCHANNEL_CRED SchannelCred;
+ TimeStamp tsExpiry;
+ SECURITY_STATUS scRet;
+
+ ZeroMemory(&SchannelCred, sizeof(SchannelCred));
+
+ SchannelCred.dwVersion = SCHANNEL_CRED_VERSION;
+ SchannelCred.grbitEnabledProtocols = SP_PROT_SSL3TLS1_CLIENTS /*| 0xA00 TLS1.1 & 1.2*/;
+
+ SchannelCred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION;
+
+ // Create an SSPI credential.
+ scRet = g_pSSPI->AcquireCredentialsHandleA(
+ NULL, // Name of principal
+ UNISP_NAME_A, // Name of package
+ SECPKG_CRED_OUTBOUND, // Flags indicating use
+ NULL, // Pointer to logon ID
+ &SchannelCred, // Package specific data
+ NULL, // Pointer to GetKey() func
+ NULL, // Value to pass to GetKey()
+ &hCreds, // (out) Cred Handle
+ &tsExpiry); // (out) Lifetime (optional)
+
+ ReportSslError(scRet, __LINE__);
+ return scRet == SEC_E_OK;
+}
+
+static bool SSL_library_init(void)
+{
+ if (bSslInitDone) return true;
+
+ WaitForSingleObject(g_hSslMutex, INFINITE);
+
+ if (!bSslInitDone)
+ {
+ g_hSchannel = LoadLibraryA("schannel.dll");
+ if (g_hSchannel)
+ {
+ INIT_SECURITY_INTERFACE_A pInitSecurityInterface;
+ pInitSecurityInterface = (INIT_SECURITY_INTERFACE_A)GetProcAddress(g_hSchannel, SECURITY_ENTRYPOINT_ANSIA);
+ if (pInitSecurityInterface != NULL)
+ g_pSSPI = pInitSecurityInterface();
+
+ if (g_pSSPI)
+ {
+ HINSTANCE hCrypt = LoadLibraryA("crypt32.dll");
+ if (hCrypt)
+ {
+ fnCertGetCertificateChain = (pfnCertGetCertificateChain)GetProcAddress(hCrypt, "CertGetCertificateChain");
+ fnCertFreeCertificateChain = (pfnCertFreeCertificateChain)GetProcAddress(hCrypt, "CertFreeCertificateChain");
+ fnCertFreeCertificateContext = (pfnCertFreeCertificateContext)GetProcAddress(hCrypt, "CertFreeCertificateContext");
+ fnCertVerifyCertificateChainPolicy = (pfnCertVerifyCertificateChainPolicy)GetProcAddress(hCrypt, "CertVerifyCertificateChainPolicy");
+ }
+
+ MySslEmptyCache = (SSL_EMPTY_CACHE_FN_M)GetProcAddress(g_hSchannel, "SslEmptyCache");
+ AcquireCredentials();
+ bSslInitDone = true;
+ }
+ else
+ {
+ FreeLibrary(g_hSchannel);
+ g_hSchannel = NULL;
+ }
+ }
+ }
+
+ ReleaseMutex(g_hSslMutex);
+ return bSslInitDone;
+}
+
+void NetlibSslFree(SslHandle *ssl)
+{
+ if (ssl == NULL) return;
+
+ g_pSSPI->DeleteSecurityContext(&ssl->hContext);
+
+ mir_free(ssl->pbRecDataBuf);
+ mir_free(ssl->pbIoBuffer);
+ memset(ssl, 0, sizeof(SslHandle));
+ mir_free(ssl);
+}
+
+BOOL NetlibSslPending(SslHandle *ssl)
+{
+ return ssl != NULL && ( ssl->cbRecDataBuf != 0 || ssl->cbIoBuffer != 0 );
+}
+
+static bool VerifyCertificate(SslHandle *ssl, PCSTR pszServerName, DWORD dwCertFlags)
+{
+ if (!fnCertGetCertificateChain)
+ return true;
+
+ static LPSTR rgszUsages[] =
+ {
+ szOID_PKIX_KP_SERVER_AUTH,
+ szOID_SERVER_GATED_CRYPTO,
+ szOID_SGC_NETSCAPE
+ };
+
+ CERT_CHAIN_PARA ChainPara = {0};
+ HTTPSPolicyCallbackData polHttps = {0};
+ CERT_CHAIN_POLICY_PARA PolicyPara = {0};
+ CERT_CHAIN_POLICY_STATUS PolicyStatus = {0};
+ PCCERT_CHAIN_CONTEXT pChainContext = NULL;
+ PCCERT_CONTEXT pServerCert = NULL;
+ DWORD scRet;
+
+ PWSTR pwszServerName = mir_a2u(pszServerName);
+
+ scRet = g_pSSPI->QueryContextAttributesA(&ssl->hContext,
+ SECPKG_ATTR_REMOTE_CERT_CONTEXT, &pServerCert);
+ if (scRet != SEC_E_OK)
+ goto cleanup;
+
+ if (pServerCert == NULL)
+ {
+ scRet = SEC_E_WRONG_PRINCIPAL;
+ goto cleanup;
+ }
+
+ ChainPara.cbSize = sizeof(ChainPara);
+ ChainPara.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR;
+ ChainPara.RequestedUsage.Usage.cUsageIdentifier = SIZEOF(rgszUsages);
+ ChainPara.RequestedUsage.Usage.rgpszUsageIdentifier = rgszUsages;
+
+ if (!fnCertGetCertificateChain(NULL, pServerCert, NULL, pServerCert->hCertStore,
+ &ChainPara, 0, NULL, &pChainContext))
+ {
+ scRet = GetLastError();
+ goto cleanup;
+ }
+
+ polHttps.cbStruct = sizeof(HTTPSPolicyCallbackData);
+ polHttps.dwAuthType = AUTHTYPE_SERVER;
+ polHttps.fdwChecks = dwCertFlags;
+ polHttps.pwszServerName = pwszServerName;
+
+ PolicyPara.cbSize = sizeof(PolicyPara);
+ PolicyPara.pvExtraPolicyPara = &polHttps;
+
+ PolicyStatus.cbSize = sizeof(PolicyStatus);
+
+ if (!fnCertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, pChainContext,
+ &PolicyPara, &PolicyStatus))
+ {
+ scRet = GetLastError();
+ goto cleanup;
+ }
+
+ if (PolicyStatus.dwError)
+ {
+ scRet = PolicyStatus.dwError;
+ goto cleanup;
+ }
+
+ scRet = SEC_E_OK;
+
+cleanup:
+ if (pChainContext)
+ fnCertFreeCertificateChain(pChainContext);
+ if (pServerCert)
+ fnCertFreeCertificateContext(pServerCert);
+ mir_free(pwszServerName);
+
+ ReportSslError(scRet, __LINE__, true);
+ return scRet == SEC_E_OK;
+}
+
+static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
+{
+ SecBufferDesc InBuffer;
+ SecBuffer InBuffers[2];
+ SecBufferDesc OutBuffer;
+ SecBuffer OutBuffers[1];
+ DWORD dwSSPIFlags;
+ DWORD dwSSPIOutFlags;
+ TimeStamp tsExpiry;
+ SECURITY_STATUS scRet;
+ DWORD cbData;
+
+ BOOL fDoRead;
+
+ dwSSPIFlags =
+ ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_EXTENDED_ERROR |
+ ISC_REQ_ALLOCATE_MEMORY |
+ ISC_REQ_STREAM;
+
+ ssl->cbIoBuffer = 0;
+
+ fDoRead = fDoInitialRead;
+
+ scRet = SEC_I_CONTINUE_NEEDED;
+
+ // Loop until the handshake is finished or an error occurs.
+ while (scRet == SEC_I_CONTINUE_NEEDED || scRet == SEC_E_INCOMPLETE_MESSAGE || scRet == SEC_I_INCOMPLETE_CREDENTIALS)
+ {
+ // Read server data
+ if (0 == ssl->cbIoBuffer || scRet == SEC_E_INCOMPLETE_MESSAGE)
+ {
+ if (fDoRead)
+ {
+ static const TIMEVAL tv = {6, 0};
+ fd_set fd;
+
+ // If buffer not large enough reallocate buffer
+ if (ssl->sbIoBuffer <= ssl->cbIoBuffer)
+ {
+ ssl->sbIoBuffer += 4096;
+ ssl->pbIoBuffer = (PUCHAR)mir_realloc(ssl->pbIoBuffer, ssl->sbIoBuffer);
+ }
+
+ FD_ZERO(&fd);
+ FD_SET(ssl->s, &fd);
+ if (select(1, &fd, NULL, NULL, &tv) != 1)
+ {
+ NetlibLogf(NULL, "SSL Negotiation failure recieving data (timeout) (bytes %u)", ssl->cbIoBuffer);
+ scRet = ERROR_NOT_READY;
+ break;
+ }
+
+ cbData = recv(ssl->s, (char*)ssl->pbIoBuffer + ssl->cbIoBuffer, ssl->sbIoBuffer - ssl->cbIoBuffer, 0);
+ if (cbData == SOCKET_ERROR)
+ {
+ NetlibLogf(NULL, "SSL Negotiation failure recieving data (%d)", WSAGetLastError());
+ scRet = ERROR_NOT_READY;
+ break;
+ }
+ if (cbData == 0)
+ {
+ NetlibLogf(NULL, "SSL Negotiation connection gracefully closed");
+ scRet = ERROR_NOT_READY;
+ break;
+ }
+
+ NetlibDumpData(NULL, ssl->pbIoBuffer + ssl->cbIoBuffer, cbData, 0, MSG_DUMPSSL);
+ ssl->cbIoBuffer += cbData;
+ }
+ else fDoRead = TRUE;
+ }
+
+ // Set up the input buffers. Buffer 0 is used to pass in data
+ // received from the server. Schannel will consume some or all
+ // of this. Leftover data (if any) will be placed in buffer 1 and
+ // given a buffer type of SECBUFFER_EXTRA.
+
+ InBuffers[0].pvBuffer = ssl->pbIoBuffer;
+ InBuffers[0].cbBuffer = ssl->cbIoBuffer;
+ InBuffers[0].BufferType = SECBUFFER_TOKEN;
+
+ InBuffers[1].pvBuffer = NULL;
+ InBuffers[1].cbBuffer = 0;
+ InBuffers[1].BufferType = SECBUFFER_EMPTY;
+
+ InBuffer.cBuffers = 2;
+ InBuffer.pBuffers = InBuffers;
+ InBuffer.ulVersion = SECBUFFER_VERSION;
+
+ // Set up the output buffers. These are initialized to NULL
+ // so as to make it less likely we'll attempt to free random
+ // garbage later.
+
+ OutBuffers[0].pvBuffer = NULL;
+ OutBuffers[0].BufferType= SECBUFFER_TOKEN;
+ OutBuffers[0].cbBuffer = 0;
+
+ OutBuffer.cBuffers = 1;
+ OutBuffer.pBuffers = OutBuffers;
+ OutBuffer.ulVersion = SECBUFFER_VERSION;
+
+ scRet = g_pSSPI->InitializeSecurityContextA(
+ &hCreds,
+ &ssl->hContext,
+ NULL,
+ dwSSPIFlags,
+ 0,
+ SECURITY_NATIVE_DREP,
+ &InBuffer,
+ 0,
+ NULL,
+ &OutBuffer,
+ &dwSSPIOutFlags,
+ &tsExpiry);
+
+ // If success (or if the error was one of the special extended ones),
+ // send the contents of the output buffer to the server.
+ if (scRet == SEC_E_OK ||
+ scRet == SEC_I_CONTINUE_NEEDED ||
+ (FAILED(scRet) && (dwSSPIOutFlags & ISC_RET_EXTENDED_ERROR)))
+ {
+ if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL)
+ {
+ NetlibDumpData(NULL, (unsigned char*)(OutBuffers[0].pvBuffer), OutBuffers[0].cbBuffer, 1, MSG_DUMPSSL);
+ cbData = send(ssl->s, (char*)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0);
+ if (cbData == SOCKET_ERROR || cbData == 0)
+ {
+ NetlibLogf(NULL, "SSL Negotiation failure sending data (%d)", WSAGetLastError());
+ g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer);
+ return SEC_E_INTERNAL_ERROR;
+ }
+
+ // Free output buffer.
+ g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer);
+ OutBuffers[0].pvBuffer = NULL;
+ }
+ }
+
+ // we need to read more data from the server and try again.
+ if (scRet == SEC_E_INCOMPLETE_MESSAGE) continue;
+
+ // handshake completed successfully.
+ if (scRet == SEC_E_OK)
+ {
+ // Store remaining data for further use
+ if (InBuffers[1].BufferType == SECBUFFER_EXTRA)
+ {
+ memmove(ssl->pbIoBuffer,
+ ssl->pbIoBuffer + (ssl->cbIoBuffer - InBuffers[1].cbBuffer),
+ InBuffers[1].cbBuffer);
+ ssl->cbIoBuffer = InBuffers[1].cbBuffer;
+ }
+ else
+ ssl->cbIoBuffer = 0;
+ break;
+ }
+
+ // Check for fatal error.
+ if (FAILED(scRet)) break;
+
+ // server just requested client authentication.
+ if (scRet == SEC_I_INCOMPLETE_CREDENTIALS)
+ {
+ // Server has requested client authentication and
+ // GetNewClientCredentials(ssl);
+
+ // Go around again.
+ fDoRead = FALSE;
+ scRet = SEC_I_CONTINUE_NEEDED;
+ continue;
+ }
+
+
+ // Copy any leftover data from the buffer, and go around again.
+ if (InBuffers[1].BufferType == SECBUFFER_EXTRA)
+ {
+ memmove(ssl->pbIoBuffer,
+ ssl->pbIoBuffer + (ssl->cbIoBuffer - InBuffers[1].cbBuffer),
+ InBuffers[1].cbBuffer);
+
+ ssl->cbIoBuffer = InBuffers[1].cbBuffer;
+ }
+ else ssl->cbIoBuffer = 0;
+ }
+
+ // Delete the security context in the case of a fatal error.
+ ReportSslError(scRet, __LINE__);
+
+ if (ssl->cbIoBuffer == 0)
+ {
+ mir_free(ssl->pbIoBuffer);
+ ssl->pbIoBuffer = NULL;
+ ssl->sbIoBuffer = 0;
+ }
+
+ return scRet;
+}
+
+static bool ClientConnect(SslHandle *ssl, const char *host)
+{
+ SecBufferDesc OutBuffer;
+ SecBuffer OutBuffers[1];
+ DWORD dwSSPIFlags;
+ DWORD dwSSPIOutFlags;
+ TimeStamp tsExpiry;
+ SECURITY_STATUS scRet;
+ DWORD cbData;
+
+ if (SecIsValidHandle(&ssl->hContext))
+ {
+ g_pSSPI->DeleteSecurityContext(&ssl->hContext);
+ SecInvalidateHandle(&ssl->hContext);
+ }
+
+ if (MySslEmptyCache) MySslEmptyCache();
+
+ dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_EXTENDED_ERROR |
+ ISC_REQ_ALLOCATE_MEMORY |
+ ISC_REQ_STREAM;
+
+ // Initiate a ClientHello message and generate a token.
+
+ OutBuffers[0].pvBuffer = NULL;
+ OutBuffers[0].BufferType = SECBUFFER_TOKEN;
+ OutBuffers[0].cbBuffer = 0;
+
+ OutBuffer.cBuffers = 1;
+ OutBuffer.pBuffers = OutBuffers;
+ OutBuffer.ulVersion = SECBUFFER_VERSION;
+
+ scRet = g_pSSPI->InitializeSecurityContextA(
+ &hCreds,
+ NULL,
+ (SEC_CHAR*)host,
+ dwSSPIFlags,
+ 0,
+ SECURITY_NATIVE_DREP,
+ NULL,
+ 0,
+ &ssl->hContext,
+ &OutBuffer,
+ &dwSSPIOutFlags,
+ &tsExpiry);
+
+ if (scRet != SEC_I_CONTINUE_NEEDED)
+ {
+ ReportSslError(scRet, __LINE__);
+ return 0;
+ }
+
+ // Send response to server if there is one.
+ if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL)
+ {
+ NetlibDumpData(NULL, (unsigned char*)(OutBuffers[0].pvBuffer), OutBuffers[0].cbBuffer, 1, MSG_DUMPSSL);
+ cbData = send(ssl->s, (char*)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0);
+ if (cbData == SOCKET_ERROR || cbData == 0)
+ {
+ NetlibLogf(NULL, "SSL failure sending connection data (%d %d)", ssl->s, WSAGetLastError());
+ g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer);
+ return 0;
+ }
+
+ // Free output buffer.
+ g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer);
+ OutBuffers[0].pvBuffer = NULL;
+ }
+
+ return ClientHandshakeLoop(ssl, TRUE) == SEC_E_OK;
+}
+
+
+SslHandle *NetlibSslConnect(SOCKET s, const char* host, int verify)
+{
+ SslHandle *ssl = (SslHandle*)mir_calloc(sizeof(SslHandle));
+ ssl->s = s;
+
+ SecInvalidateHandle(&ssl->hContext);
+
+ DWORD dwFlags = 0;
+
+ if (!host || inet_addr(host) != INADDR_NONE)
+ dwFlags |= 0x00001000;
+
+ bool res = SSL_library_init();
+
+ if (res) res = ClientConnect(ssl, host);
+ if (res && verify) res = VerifyCertificate(ssl, host, dwFlags);
+
+ if (!res)
+ {
+ NetlibSslFree(ssl);
+ ssl = NULL;
+ }
+ return ssl;
+}
+
+
+void NetlibSslShutdown(SslHandle *ssl)
+{
+ DWORD dwType;
+
+ SecBufferDesc OutBuffer;
+ SecBuffer OutBuffers[1];
+ DWORD dwSSPIFlags;
+ DWORD dwSSPIOutFlags;
+ TimeStamp tsExpiry;
+ DWORD scRet;
+
+ if (ssl == NULL || !SecIsValidHandle(&ssl->hContext))
+ return;
+
+ dwType = SCHANNEL_SHUTDOWN;
+
+ OutBuffers[0].pvBuffer = &dwType;
+ OutBuffers[0].BufferType = SECBUFFER_TOKEN;
+ OutBuffers[0].cbBuffer = sizeof(dwType);
+
+ OutBuffer.cBuffers = 1;
+ OutBuffer.pBuffers = OutBuffers;
+ OutBuffer.ulVersion = SECBUFFER_VERSION;
+
+ scRet = g_pSSPI->ApplyControlToken(&ssl->hContext, &OutBuffer);
+ if (FAILED(scRet)) return;
+
+ //
+ // Build an SSL close notify message.
+ //
+
+ dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_RET_EXTENDED_ERROR |
+ ISC_REQ_ALLOCATE_MEMORY |
+ ISC_REQ_STREAM;
+
+ OutBuffers[0].pvBuffer = NULL;
+ OutBuffers[0].BufferType = SECBUFFER_TOKEN;
+ OutBuffers[0].cbBuffer = 0;
+
+ OutBuffer.cBuffers = 1;
+ OutBuffer.pBuffers = OutBuffers;
+ OutBuffer.ulVersion = SECBUFFER_VERSION;
+
+ scRet = g_pSSPI->InitializeSecurityContextA(
+ &hCreds,
+ &ssl->hContext,
+ NULL,
+ dwSSPIFlags,
+ 0,
+ SECURITY_NATIVE_DREP,
+ NULL,
+ 0,
+ &ssl->hContext,
+ &OutBuffer,
+ &dwSSPIOutFlags,
+ &tsExpiry);
+
+ if (FAILED(scRet)) return;
+
+ // Send the close notify message to the server.
+ if (OutBuffers[0].pvBuffer != NULL && OutBuffers[0].cbBuffer != 0)
+ {
+ NetlibDumpData(NULL, (unsigned char*)(OutBuffers[0].pvBuffer), OutBuffers[0].cbBuffer, 1, MSG_DUMPSSL);
+ send(ssl->s, (char*)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0);
+ g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer);
+ }
+}
+
+static int NetlibSslReadSetResult(SslHandle *ssl, char *buf, int num, int peek)
+{
+ if (ssl->cbRecDataBuf == 0)
+ {
+ return (ssl->state == sockClosed ? 0: SOCKET_ERROR);
+ }
+
+ int bytes = min(num, ssl->cbRecDataBuf);
+ int rbytes = ssl->cbRecDataBuf - bytes;
+
+ memcpy(buf, ssl->pbRecDataBuf, bytes);
+ if (!peek)
+ {
+ memmove(ssl->pbRecDataBuf, ssl->pbRecDataBuf + bytes, rbytes);
+ ssl->cbRecDataBuf = rbytes;
+ }
+
+ return bytes;
+}
+
+int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
+{
+ SECURITY_STATUS scRet;
+ DWORD cbData;
+ DWORD resNum = 0;
+ int i;
+
+ SecBufferDesc Message;
+ SecBuffer Buffers[4];
+ SecBuffer * pDataBuffer;
+ SecBuffer * pExtraBuffer;
+
+ if (ssl == NULL) return SOCKET_ERROR;
+
+ if (num <= 0) return 0;
+
+ if (ssl->state != sockOpen || (ssl->cbRecDataBuf != 0 && (!peek || ssl->cbRecDataBuf >= num)))
+ {
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+
+ scRet = SEC_E_OK;
+
+ for (;;)
+ {
+ if (0 == ssl->cbIoBuffer || scRet == SEC_E_INCOMPLETE_MESSAGE)
+ {
+ if (ssl->sbIoBuffer <= ssl->cbIoBuffer)
+ {
+ ssl->sbIoBuffer += 2048;
+ ssl->pbIoBuffer = (PUCHAR)mir_realloc(ssl->pbIoBuffer, ssl->sbIoBuffer);
+ }
+
+ if (peek)
+ {
+ static const TIMEVAL tv = {0};
+ fd_set fd;
+ FD_ZERO(&fd);
+ FD_SET(ssl->s, &fd);
+
+ cbData = select(1, &fd, NULL, NULL, &tv);
+ if (cbData == SOCKET_ERROR)
+ {
+ ssl->state = sockError;
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+
+ if (cbData == 0 && ssl->cbRecDataBuf)
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+
+ cbData = recv(ssl->s, (char*)ssl->pbIoBuffer + ssl->cbIoBuffer, ssl->sbIoBuffer - ssl->cbIoBuffer, 0);
+ if (cbData == SOCKET_ERROR)
+ {
+ NetlibLogf(NULL, "SSL failure recieving data (%d)", WSAGetLastError());
+ ssl->state = sockError;
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+
+ if (cbData == 0)
+ {
+ NetlibLogf(NULL, "SSL connection gracefully closed");
+ if (peek && ssl->cbRecDataBuf)
+ {
+ ssl->state = sockClosed;
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+
+ // Server disconnected.
+ if (ssl->cbIoBuffer)
+ {
+ ssl->state = sockError;
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+
+ return 0;
+ }
+ else
+ {
+ NetlibDumpData(NULL, ssl->pbIoBuffer + ssl->cbIoBuffer, cbData, 0, MSG_DUMPSSL);
+ ssl->cbIoBuffer += cbData;
+ }
+ }
+
+ // Attempt to decrypt the received data.
+ Buffers[0].pvBuffer = ssl->pbIoBuffer;
+ Buffers[0].cbBuffer = ssl->cbIoBuffer;
+ Buffers[0].BufferType = SECBUFFER_DATA;
+
+ Buffers[1].BufferType = SECBUFFER_EMPTY;
+ Buffers[2].BufferType = SECBUFFER_EMPTY;
+ Buffers[3].BufferType = SECBUFFER_EMPTY;
+
+ Message.ulVersion = SECBUFFER_VERSION;
+ Message.cBuffers = 4;
+ Message.pBuffers = Buffers;
+
+ if (g_pSSPI->DecryptMessage != NULL && g_pSSPI->DecryptMessage != PVOID(0x80000000))
+ scRet = g_pSSPI->DecryptMessage(&ssl->hContext, &Message, 0, NULL);
+ else
+ scRet = ((DECRYPT_MESSAGE_FN)g_pSSPI->Reserved4)(&ssl->hContext, &Message, 0, NULL);
+
+ // The input buffer contains only a fragment of an
+ // encrypted record. Loop around and read some more
+ // data.
+ if (scRet == SEC_E_INCOMPLETE_MESSAGE)
+ continue;
+
+ if ( scRet != SEC_E_OK && scRet != SEC_I_RENEGOTIATE && scRet != SEC_I_CONTEXT_EXPIRED)
+ {
+ ReportSslError(scRet, __LINE__);
+ ssl->state = sockError;
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+
+ // Locate data and (optional) extra buffers.
+ pDataBuffer = NULL;
+ pExtraBuffer = NULL;
+ for(i = 1; i < 4; i++)
+ {
+ if (pDataBuffer == NULL && Buffers[i].BufferType == SECBUFFER_DATA)
+ pDataBuffer = &Buffers[i];
+
+ if (pExtraBuffer == NULL && Buffers[i].BufferType == SECBUFFER_EXTRA)
+ pExtraBuffer = &Buffers[i];
+ }
+
+ // Return decrypted data.
+ if (pDataBuffer)
+ {
+ DWORD bytes, rbytes;
+
+ bytes = peek ? 0 : min((DWORD)num, pDataBuffer->cbBuffer);
+ rbytes = pDataBuffer->cbBuffer - bytes;
+
+ NetlibDumpData(NULL, (PBYTE)pDataBuffer->pvBuffer, pDataBuffer->cbBuffer, 0, MSG_DUMPSSL);
+
+ if (rbytes > 0)
+ {
+ int nbytes = ssl->cbRecDataBuf + rbytes;
+ if (ssl->sbRecDataBuf < nbytes)
+ {
+ ssl->sbRecDataBuf = nbytes;
+ ssl->pbRecDataBuf = (PUCHAR)mir_realloc(ssl->pbRecDataBuf, nbytes);
+ }
+ memcpy(ssl->pbRecDataBuf + ssl->cbRecDataBuf, (char*)pDataBuffer->pvBuffer + bytes, rbytes);
+ ssl->cbRecDataBuf = nbytes;
+ }
+
+ if (peek)
+ {
+ resNum = bytes = min(num, ssl->cbRecDataBuf);
+ memcpy(buf, ssl->pbRecDataBuf, bytes);
+ }
+ else
+ {
+ resNum = bytes;
+ memcpy(buf, pDataBuffer->pvBuffer, bytes);
+ }
+ }
+
+ // Move any "extra" data to the input buffer.
+ if (pExtraBuffer)
+ {
+ memmove(ssl->pbIoBuffer, pExtraBuffer->pvBuffer, pExtraBuffer->cbBuffer);
+ ssl->cbIoBuffer = pExtraBuffer->cbBuffer;
+ }
+ else ssl->cbIoBuffer = 0;
+
+ if (pDataBuffer && resNum)
+ return resNum;
+
+ // Server signaled end of session
+ if (scRet == SEC_I_CONTEXT_EXPIRED)
+ {
+ NetlibLogf(NULL, "SSL Server signaled SSL Shutdown");
+ ssl->state = sockClosed;
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+
+ if (scRet == SEC_I_RENEGOTIATE)
+ {
+ // The server wants to perform another handshake
+ // sequence.
+
+ scRet = ClientHandshakeLoop(ssl, FALSE);
+ if (scRet != SEC_E_OK)
+ {
+ ssl->state = sockError;
+ return NetlibSslReadSetResult(ssl, buf, num, peek);
+ }
+ }
+ }
+}
+
+int NetlibSslWrite(SslHandle *ssl, const char *buf, int num)
+{
+ SecPkgContext_StreamSizes Sizes;
+ SECURITY_STATUS scRet;
+ DWORD cbData;
+
+ SecBufferDesc Message;
+ SecBuffer Buffers[4] = {0};
+
+ PUCHAR pbDataBuffer;
+
+ PUCHAR pbMessage;
+ DWORD cbMessage;
+
+ DWORD sendOff = 0;
+
+ if (ssl == NULL) return SOCKET_ERROR;
+
+ scRet = g_pSSPI->QueryContextAttributesA(&ssl->hContext, SECPKG_ATTR_STREAM_SIZES, &Sizes);
+ if (scRet != SEC_E_OK) return scRet;
+
+ pbDataBuffer = (PUCHAR)mir_calloc(Sizes.cbMaximumMessage + Sizes.cbHeader + Sizes.cbTrailer);
+
+ pbMessage = pbDataBuffer + Sizes.cbHeader;
+
+ while (sendOff < (DWORD)num)
+ {
+ cbMessage = min(Sizes.cbMaximumMessage, (DWORD)num - sendOff);
+ CopyMemory(pbMessage, buf+sendOff, cbMessage);
+
+ Buffers[0].pvBuffer = pbDataBuffer;
+ Buffers[0].cbBuffer = Sizes.cbHeader;
+ Buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
+
+ Buffers[1].pvBuffer = pbMessage;
+ Buffers[1].cbBuffer = cbMessage;
+ Buffers[1].BufferType = SECBUFFER_DATA;
+
+ Buffers[2].pvBuffer = pbMessage + cbMessage;
+ Buffers[2].cbBuffer = Sizes.cbTrailer;
+ Buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
+
+ Buffers[3].BufferType = SECBUFFER_EMPTY;
+
+ Message.ulVersion = SECBUFFER_VERSION;
+ Message.cBuffers = 4;
+ Message.pBuffers = Buffers;
+
+ if (g_pSSPI->EncryptMessage != NULL)
+ scRet = g_pSSPI->EncryptMessage(&ssl->hContext, 0, &Message, 0);
+ else
+ scRet = ((ENCRYPT_MESSAGE_FN)g_pSSPI->Reserved3)(&ssl->hContext, 0, &Message, 0);
+
+ if (FAILED(scRet)) break;
+
+ // Calculate encrypted packet size
+ cbData = Buffers[0].cbBuffer + Buffers[1].cbBuffer + Buffers[2].cbBuffer;
+
+ // Send the encrypted data to the server.
+ NetlibDumpData(NULL, pbDataBuffer, cbData, 1, MSG_DUMPSSL);
+ cbData = send(ssl->s, (char*)pbDataBuffer, cbData, 0);
+ if (cbData == SOCKET_ERROR || cbData == 0)
+ {
+ NetlibLogf(NULL, "SSL failure sending data (%d)", WSAGetLastError());
+ scRet = SEC_E_INTERNAL_ERROR;
+ break;
+ }
+
+ sendOff += cbMessage;
+ }
+
+ mir_free(pbDataBuffer);
+ return scRet == SEC_E_OK ? num : SOCKET_ERROR;
+}
+
+static INT_PTR GetSslApi(WPARAM, LPARAM lParam)
+{
+ SSL_API* si = (SSL_API*)lParam;
+ if (si == NULL) return FALSE;
+
+ if (si->cbSize != sizeof(SSL_API))
+ return FALSE;
+
+ si->connect = (HSSL (__cdecl *)(SOCKET,const char *,int))NetlibSslConnect;
+ si->pending = (BOOL (__cdecl *)(HSSL))NetlibSslPending;
+ si->read = (int (__cdecl *)(HSSL,char *,int,int))NetlibSslRead;
+ si->write = (int (__cdecl *)(HSSL,const char *,int))NetlibSslWrite;
+ si->shutdown = (void (__cdecl *)(HSSL))NetlibSslShutdown;
+ si->sfree = (void (__cdecl *)(HSSL))NetlibSslFree;
+
+ return TRUE;
+}
+
+int LoadSslModule(void)
+{
+ CreateServiceFunction(MS_SYSTEM_GET_SI, GetSslApi);
+ g_hSslMutex = CreateMutex(NULL, FALSE, NULL);
+ SecInvalidateHandle(&hCreds);
+
+ return 0;
+}
+
+void UnloadSslModule(void)
+{
+ if (g_pSSPI && SecIsValidHandle(&hCreds))
+ g_pSSPI->FreeCredentialsHandle(&hCreds);
+ CloseHandle(g_hSslMutex);
+ if (g_hSchannel) FreeLibrary(g_hSchannel);
+}
|