summaryrefslogtreecommitdiff
path: root/src/modules/netlib/netlibssl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/modules/netlib/netlibssl.cpp')
-rw-r--r--src/modules/netlib/netlibssl.cpp497
1 files changed, 164 insertions, 333 deletions
diff --git a/src/modules/netlib/netlibssl.cpp b/src/modules/netlib/netlibssl.cpp
index fa9a77028d..0c0e62d32b 100644
--- a/src/modules/netlib/netlibssl.cpp
+++ b/src/modules/netlib/netlibssl.cpp
@@ -30,29 +30,18 @@ Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#include <security.h>
#include <schannel.h>
-//#include <SCHNLSP.H>
+#pragma comment(lib, "secur32.lib")
+#pragma comment(lib, "crypt32.lib")
-typedef BOOL (* SSL_EMPTY_CACHE_FN_M)(VOID);
+typedef BOOL (*SSL_EMPTY_CACHE_FN_M)(VOID);
static HMODULE g_hSchannel;
-static PSecurityFunctionTableA g_pSSPI;
+static PSecurityFunctionTable g_pSSPI;
static HANDLE g_hSslMutex;
static SSL_EMPTY_CACHE_FN_M MySslEmptyCache;
static CredHandle hCreds;
static bool bSslInitDone;
-typedef BOOL (WINAPI *pfnCertGetCertificateChain)(HCERTCHAINENGINE, PCCERT_CONTEXT, LPFILETIME, HCERTSTORE, PCERT_CHAIN_PARA, DWORD, LPVOID, PCCERT_CHAIN_CONTEXT*);
-static pfnCertGetCertificateChain fnCertGetCertificateChain;
-
-typedef VOID (WINAPI *pfnCertFreeCertificateChain)(PCCERT_CHAIN_CONTEXT);
-static pfnCertFreeCertificateChain fnCertFreeCertificateChain;
-
-typedef BOOL (WINAPI *pfnCertFreeCertificateContext)(PCCERT_CONTEXT);
-static pfnCertFreeCertificateContext fnCertFreeCertificateContext;
-
-typedef BOOL (WINAPI *pfnCertVerifyCertificateChainPolicy)(LPCSTR, PCCERT_CHAIN_CONTEXT, PCERT_CHAIN_POLICY_PARA, PCERT_CHAIN_POLICY_STATUS);
-static pfnCertVerifyCertificateChainPolicy fnCertVerifyCertificateChainPolicy;
-
typedef enum
{
sockOpen,
@@ -81,8 +70,7 @@ struct SslHandle
static void ReportSslError(SECURITY_STATUS scRet, int line, bool showPopup = false)
{
TCHAR szMsgBuf[256];
- switch (scRet)
- {
+ switch (scRet) {
case 0:
case ERROR_NOT_READY:
return;
@@ -126,9 +114,9 @@ static bool AcquireCredentials(void)
SchannelCred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_MANUAL_CRED_VALIDATION;
// Create an SSPI credential.
- scRet = g_pSSPI->AcquireCredentialsHandleA(
+ scRet = g_pSSPI->AcquireCredentialsHandle(
NULL, // Name of principal
- UNISP_NAME_A, // Name of package
+ UNISP_NAME, // Name of package
SECPKG_CRED_OUTBOUND, // Flags indicating use
NULL, // Pointer to logon ID
&SchannelCred, // Package specific data
@@ -143,41 +131,16 @@ static bool AcquireCredentials(void)
static bool SSL_library_init(void)
{
- if (bSslInitDone) return true;
+ if (bSslInitDone)
+ return true;
WaitForSingleObject(g_hSslMutex, INFINITE);
- if (!bSslInitDone)
- {
- g_hSchannel = LoadLibraryA("schannel.dll");
- if (g_hSchannel)
- {
- INIT_SECURITY_INTERFACE_A pInitSecurityInterface;
- pInitSecurityInterface = (INIT_SECURITY_INTERFACE_A)GetProcAddress(g_hSchannel, SECURITY_ENTRYPOINT_ANSIA);
- if (pInitSecurityInterface != NULL)
- g_pSSPI = pInitSecurityInterface();
-
- if (g_pSSPI)
- {
- HINSTANCE hCrypt = LoadLibraryA("crypt32.dll");
- if (hCrypt)
- {
- fnCertGetCertificateChain = (pfnCertGetCertificateChain)GetProcAddress(hCrypt, "CertGetCertificateChain");
- fnCertFreeCertificateChain = (pfnCertFreeCertificateChain)GetProcAddress(hCrypt, "CertFreeCertificateChain");
- fnCertFreeCertificateContext = (pfnCertFreeCertificateContext)GetProcAddress(hCrypt, "CertFreeCertificateContext");
- fnCertVerifyCertificateChainPolicy = (pfnCertVerifyCertificateChainPolicy)GetProcAddress(hCrypt, "CertVerifyCertificateChainPolicy");
- }
-
- MySslEmptyCache = (SSL_EMPTY_CACHE_FN_M)GetProcAddress(g_hSchannel, "SslEmptyCache");
- AcquireCredentials();
- bSslInitDone = true;
- }
- else
- {
- FreeLibrary(g_hSchannel);
- g_hSchannel = NULL;
- }
- }
+ g_pSSPI = InitSecurityInterface();
+ if (g_pSSPI) {
+ MySslEmptyCache = (SSL_EMPTY_CACHE_FN_M)GetProcAddress(g_hSchannel, "SslEmptyCache");
+ AcquireCredentials();
+ bSslInitDone = true;
}
ReleaseMutex(g_hSslMutex);
@@ -203,9 +166,6 @@ BOOL NetlibSslPending(SslHandle *ssl)
static bool VerifyCertificate(SslHandle *ssl, PCSTR pszServerName, DWORD dwCertFlags)
{
- if (!fnCertGetCertificateChain)
- return true;
-
static LPSTR rgszUsages[] =
{
szOID_PKIX_KP_SERVER_AUTH,
@@ -213,23 +173,21 @@ static bool VerifyCertificate(SslHandle *ssl, PCSTR pszServerName, DWORD dwCertF
szOID_SGC_NETSCAPE
};
- CERT_CHAIN_PARA ChainPara = {0};
- HTTPSPolicyCallbackData polHttps = {0};
- CERT_CHAIN_POLICY_PARA PolicyPara = {0};
- CERT_CHAIN_POLICY_STATUS PolicyStatus = {0};
+ CERT_CHAIN_PARA ChainPara = { 0 };
+ HTTPSPolicyCallbackData polHttps = { 0 };
+ CERT_CHAIN_POLICY_PARA PolicyPara = { 0 };
+ CERT_CHAIN_POLICY_STATUS PolicyStatus = { 0 };
PCCERT_CHAIN_CONTEXT pChainContext = NULL;
PCCERT_CONTEXT pServerCert = NULL;
DWORD scRet;
PWSTR pwszServerName = mir_a2u(pszServerName);
- scRet = g_pSSPI->QueryContextAttributesA(&ssl->hContext,
- SECPKG_ATTR_REMOTE_CERT_CONTEXT, &pServerCert);
+ scRet = g_pSSPI->QueryContextAttributes(&ssl->hContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &pServerCert);
if (scRet != SEC_E_OK)
goto cleanup;
- if (pServerCert == NULL)
- {
+ if (pServerCert == NULL) {
scRet = SEC_E_WRONG_PRINCIPAL;
goto cleanup;
}
@@ -239,9 +197,7 @@ static bool VerifyCertificate(SslHandle *ssl, PCSTR pszServerName, DWORD dwCertF
ChainPara.RequestedUsage.Usage.cUsageIdentifier = SIZEOF(rgszUsages);
ChainPara.RequestedUsage.Usage.rgpszUsageIdentifier = rgszUsages;
- if (!fnCertGetCertificateChain(NULL, pServerCert, NULL, pServerCert->hCertStore,
- &ChainPara, 0, NULL, &pChainContext))
- {
+ if (!CertGetCertificateChain(NULL, pServerCert, NULL, pServerCert->hCertStore, &ChainPara, 0, NULL, &pChainContext)) {
scRet = GetLastError();
goto cleanup;
}
@@ -256,15 +212,12 @@ static bool VerifyCertificate(SslHandle *ssl, PCSTR pszServerName, DWORD dwCertF
PolicyStatus.cbSize = sizeof(PolicyStatus);
- if (!fnCertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, pChainContext,
- &PolicyPara, &PolicyStatus))
- {
+ if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, pChainContext, &PolicyPara, &PolicyStatus)) {
scRet = GetLastError();
goto cleanup;
}
- if (PolicyStatus.dwError)
- {
+ if (PolicyStatus.dwError) {
scRet = PolicyStatus.dwError;
goto cleanup;
}
@@ -273,9 +226,9 @@ static bool VerifyCertificate(SslHandle *ssl, PCSTR pszServerName, DWORD dwCertF
cleanup:
if (pChainContext)
- fnCertFreeCertificateChain(pChainContext);
+ CertFreeCertificateChain(pChainContext);
if (pServerCert)
- fnCertFreeCertificateContext(pServerCert);
+ CertFreeCertificateContext(pServerCert);
mir_free(pwszServerName);
ReportSslError(scRet, __LINE__, true);
@@ -284,68 +237,49 @@ cleanup:
static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
{
- SecBufferDesc InBuffer;
- SecBuffer InBuffers[2];
- SecBufferDesc OutBuffer;
- SecBuffer OutBuffers[1];
- DWORD dwSSPIFlags;
- DWORD dwSSPIOutFlags;
- TimeStamp tsExpiry;
- SECURITY_STATUS scRet;
- DWORD cbData;
-
- BOOL fDoRead;
-
- dwSSPIFlags =
- ISC_REQ_SEQUENCE_DETECT |
- ISC_REQ_REPLAY_DETECT |
- ISC_REQ_CONFIDENTIALITY |
- ISC_REQ_EXTENDED_ERROR |
- ISC_REQ_ALLOCATE_MEMORY |
+ DWORD dwSSPIFlags =
+ ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_EXTENDED_ERROR |
+ ISC_REQ_ALLOCATE_MEMORY |
ISC_REQ_STREAM;
ssl->cbIoBuffer = 0;
- fDoRead = fDoInitialRead;
+ BOOL fDoRead = fDoInitialRead;
- scRet = SEC_I_CONTINUE_NEEDED;
+ SECURITY_STATUS scRet = SEC_I_CONTINUE_NEEDED;
// Loop until the handshake is finished or an error occurs.
- while (scRet == SEC_I_CONTINUE_NEEDED || scRet == SEC_E_INCOMPLETE_MESSAGE || scRet == SEC_I_INCOMPLETE_CREDENTIALS)
- {
+ while (scRet == SEC_I_CONTINUE_NEEDED || scRet == SEC_E_INCOMPLETE_MESSAGE || scRet == SEC_I_INCOMPLETE_CREDENTIALS) {
// Read server data
- if (0 == ssl->cbIoBuffer || scRet == SEC_E_INCOMPLETE_MESSAGE)
- {
- if (fDoRead)
- {
- static const TIMEVAL tv = {6, 0};
+ if (0 == ssl->cbIoBuffer || scRet == SEC_E_INCOMPLETE_MESSAGE) {
+ if (fDoRead) {
+ static const TIMEVAL tv = { 6, 0 };
fd_set fd;
// If buffer not large enough reallocate buffer
- if (ssl->sbIoBuffer <= ssl->cbIoBuffer)
- {
+ if (ssl->sbIoBuffer <= ssl->cbIoBuffer) {
ssl->sbIoBuffer += 4096;
ssl->pbIoBuffer = (PUCHAR)mir_realloc(ssl->pbIoBuffer, ssl->sbIoBuffer);
}
FD_ZERO(&fd);
FD_SET(ssl->s, &fd);
- if (select(1, &fd, NULL, NULL, &tv) != 1)
- {
+ if (select(1, &fd, NULL, NULL, &tv) != 1) {
NetlibLogf(NULL, "SSL Negotiation failure recieving data (timeout) (bytes %u)", ssl->cbIoBuffer);
scRet = ERROR_NOT_READY;
break;
}
- cbData = recv(ssl->s, (char*)ssl->pbIoBuffer + ssl->cbIoBuffer, ssl->sbIoBuffer - ssl->cbIoBuffer, 0);
- if (cbData == SOCKET_ERROR)
- {
+ DWORD cbData = recv(ssl->s, (char*)ssl->pbIoBuffer + ssl->cbIoBuffer, ssl->sbIoBuffer - ssl->cbIoBuffer, 0);
+ if (cbData == SOCKET_ERROR) {
NetlibLogf(NULL, "SSL Negotiation failure recieving data (%d)", WSAGetLastError());
scRet = ERROR_NOT_READY;
break;
}
- if (cbData == 0)
- {
+ if (cbData == 0) {
NetlibLogf(NULL, "SSL Negotiation connection gracefully closed");
scRet = ERROR_NOT_READY;
break;
@@ -362,6 +296,7 @@ static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
// of this. Leftover data (if any) will be placed in buffer 1 and
// given a buffer type of SECBUFFER_EXTRA.
+ SecBuffer InBuffers[2];
InBuffers[0].pvBuffer = ssl->pbIoBuffer;
InBuffers[0].cbBuffer = ssl->cbIoBuffer;
InBuffers[0].BufferType = SECBUFFER_TOKEN;
@@ -370,7 +305,8 @@ static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
InBuffers[1].cbBuffer = 0;
InBuffers[1].BufferType = SECBUFFER_EMPTY;
- InBuffer.cBuffers = 2;
+ SecBufferDesc InBuffer;
+ InBuffer.cBuffers = SIZEOF(InBuffers);
InBuffer.pBuffers = InBuffers;
InBuffer.ulVersion = SECBUFFER_VERSION;
@@ -378,40 +314,29 @@ static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
// so as to make it less likely we'll attempt to free random
// garbage later.
+ SecBuffer OutBuffers[1];
OutBuffers[0].pvBuffer = NULL;
OutBuffers[0].BufferType = SECBUFFER_TOKEN;
OutBuffers[0].cbBuffer = 0;
- OutBuffer.cBuffers = 1;
+ SecBufferDesc OutBuffer;
+ OutBuffer.cBuffers = SIZEOF(OutBuffers);
OutBuffer.pBuffers = OutBuffers;
OutBuffer.ulVersion = SECBUFFER_VERSION;
- scRet = g_pSSPI->InitializeSecurityContextA(
- &hCreds,
- &ssl->hContext,
- NULL,
- dwSSPIFlags,
- 0,
- SECURITY_NATIVE_DREP,
- &InBuffer,
- 0,
- NULL,
- &OutBuffer,
- &dwSSPIOutFlags,
- &tsExpiry);
+ TimeStamp tsExpiry;
+ DWORD dwSSPIOutFlags;
+ scRet = g_pSSPI->InitializeSecurityContext(&hCreds, &ssl->hContext, NULL, dwSSPIFlags, 0,
+ SECURITY_NATIVE_DREP, &InBuffer, 0, NULL, &OutBuffer, &dwSSPIOutFlags, &tsExpiry);
// If success (or if the error was one of the special extended ones),
// send the contents of the output buffer to the server.
- if (scRet == SEC_E_OK ||
- scRet == SEC_I_CONTINUE_NEEDED ||
- (FAILED(scRet) && (dwSSPIOutFlags & ISC_RET_EXTENDED_ERROR)))
- {
- if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL)
- {
+ if (scRet == SEC_E_OK || scRet == SEC_I_CONTINUE_NEEDED || (FAILED(scRet) && (dwSSPIOutFlags & ISC_RET_EXTENDED_ERROR))) {
+ if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL) {
NetlibDumpData(NULL, (unsigned char*)(OutBuffers[0].pvBuffer), OutBuffers[0].cbBuffer, 1, MSG_DUMPSSL);
- cbData = send(ssl->s, (char*)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0);
- if (cbData == SOCKET_ERROR || cbData == 0)
- {
+
+ DWORD cbData = send(ssl->s, (char*)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0);
+ if (cbData == SOCKET_ERROR || cbData == 0) {
NetlibLogf(NULL, "SSL Negotiation failure sending data (%d)", WSAGetLastError());
g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer);
return SEC_E_INTERNAL_ERROR;
@@ -424,21 +349,19 @@ static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
}
// we need to read more data from the server and try again.
- if (scRet == SEC_E_INCOMPLETE_MESSAGE) continue;
+ if (scRet == SEC_E_INCOMPLETE_MESSAGE)
+ continue;
// handshake completed successfully.
- if (scRet == SEC_E_OK)
- {
+ if (scRet == SEC_E_OK) {
// Store remaining data for further use
- if (InBuffers[1].BufferType == SECBUFFER_EXTRA)
- {
+ if (InBuffers[1].BufferType == SECBUFFER_EXTRA) {
memmove(ssl->pbIoBuffer,
ssl->pbIoBuffer + (ssl->cbIoBuffer - InBuffers[1].cbBuffer),
InBuffers[1].cbBuffer);
ssl->cbIoBuffer = InBuffers[1].cbBuffer;
}
- else
- ssl->cbIoBuffer = 0;
+ else ssl->cbIoBuffer = 0;
break;
}
@@ -446,8 +369,7 @@ static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
if (FAILED(scRet)) break;
// server just requested client authentication.
- if (scRet == SEC_I_INCOMPLETE_CREDENTIALS)
- {
+ if (scRet == SEC_I_INCOMPLETE_CREDENTIALS) {
// Server has requested client authentication and
// GetNewClientCredentials(ssl);
@@ -458,12 +380,8 @@ static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
}
// Copy any leftover data from the buffer, and go around again.
- if (InBuffers[1].BufferType == SECBUFFER_EXTRA)
- {
- memmove(ssl->pbIoBuffer,
- ssl->pbIoBuffer + (ssl->cbIoBuffer - InBuffers[1].cbBuffer),
- InBuffers[1].cbBuffer);
-
+ if (InBuffers[1].BufferType == SECBUFFER_EXTRA) {
+ memmove(ssl->pbIoBuffer, ssl->pbIoBuffer + (ssl->cbIoBuffer - InBuffers[1].cbBuffer), InBuffers[1].cbBuffer);
ssl->cbIoBuffer = InBuffers[1].cbBuffer;
}
else ssl->cbIoBuffer = 0;
@@ -472,8 +390,7 @@ static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
// Delete the security context in the case of a fatal error.
ReportSslError(scRet, __LINE__);
- if (ssl->cbIoBuffer == 0)
- {
+ if (ssl->cbIoBuffer == 0) {
mir_free(ssl->pbIoBuffer);
ssl->pbIoBuffer = NULL;
ssl->sbIoBuffer = 0;
@@ -484,66 +401,46 @@ static SECURITY_STATUS ClientHandshakeLoop(SslHandle *ssl, BOOL fDoInitialRead)
static bool ClientConnect(SslHandle *ssl, const char *host)
{
- SecBufferDesc OutBuffer;
- SecBuffer OutBuffers[1];
- DWORD dwSSPIFlags;
- DWORD dwSSPIOutFlags;
- TimeStamp tsExpiry;
- SECURITY_STATUS scRet;
- DWORD cbData;
-
- if (SecIsValidHandle(&ssl->hContext))
- {
+ if (SecIsValidHandle(&ssl->hContext)) {
g_pSSPI->DeleteSecurityContext(&ssl->hContext);
SecInvalidateHandle(&ssl->hContext);
}
if (MySslEmptyCache) MySslEmptyCache();
- dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
- ISC_REQ_REPLAY_DETECT |
- ISC_REQ_CONFIDENTIALITY |
- ISC_REQ_EXTENDED_ERROR |
- ISC_REQ_ALLOCATE_MEMORY |
+ DWORD dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_EXTENDED_ERROR |
+ ISC_REQ_ALLOCATE_MEMORY |
ISC_REQ_STREAM;
// Initiate a ClientHello message and generate a token.
-
+ SecBuffer OutBuffers[1];
OutBuffers[0].pvBuffer = NULL;
OutBuffers[0].BufferType = SECBUFFER_TOKEN;
OutBuffers[0].cbBuffer = 0;
- OutBuffer.cBuffers = 1;
+ SecBufferDesc OutBuffer;
+ OutBuffer.cBuffers = SIZEOF(OutBuffers);
OutBuffer.pBuffers = OutBuffers;
OutBuffer.ulVersion = SECBUFFER_VERSION;
- scRet = g_pSSPI->InitializeSecurityContextA(
- &hCreds,
- NULL,
- (SEC_CHAR*)host,
- dwSSPIFlags,
- 0,
- SECURITY_NATIVE_DREP,
- NULL,
- 0,
- &ssl->hContext,
- &OutBuffer,
- &dwSSPIOutFlags,
- &tsExpiry);
-
- if (scRet != SEC_I_CONTINUE_NEEDED)
- {
+ TimeStamp tsExpiry;
+ DWORD dwSSPIOutFlags;
+ SECURITY_STATUS scRet = g_pSSPI->InitializeSecurityContext(&hCreds, NULL, _A2T(host), dwSSPIFlags, 0,
+ SECURITY_NATIVE_DREP, NULL, 0, &ssl->hContext, &OutBuffer, &dwSSPIOutFlags, &tsExpiry);
+ if (scRet != SEC_I_CONTINUE_NEEDED) {
ReportSslError(scRet, __LINE__);
return 0;
}
// Send response to server if there is one.
- if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL)
- {
+ if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL) {
NetlibDumpData(NULL, (unsigned char*)(OutBuffers[0].pvBuffer), OutBuffers[0].cbBuffer, 1, MSG_DUMPSSL);
- cbData = send(ssl->s, (char*)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0);
- if (cbData == SOCKET_ERROR || cbData == 0)
- {
+
+ DWORD cbData = send(ssl->s, (char*)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0);
+ if (cbData == SOCKET_ERROR || cbData == 0) {
NetlibLogf(NULL, "SSL failure sending connection data (%d %d)", ssl->s, WSAGetLastError());
g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer);
return 0;
@@ -557,7 +454,7 @@ static bool ClientConnect(SslHandle *ssl, const char *host)
return ClientHandshakeLoop(ssl, TRUE) == SEC_E_OK;
}
-SslHandle *NetlibSslConnect(SOCKET s, const char* host, int verify)
+SslHandle* NetlibSslConnect(SOCKET s, const char* host, int verify)
{
SslHandle *ssl = (SslHandle*)mir_calloc(sizeof(SslHandle));
ssl->s = s;
@@ -574,8 +471,7 @@ SslHandle *NetlibSslConnect(SOCKET s, const char* host, int verify)
if (res) res = ClientConnect(ssl, host);
if (res && verify) res = VerifyCertificate(ssl, host, dwFlags);
- if (!res)
- {
+ if (!res) {
NetlibSslFree(ssl);
ssl = NULL;
}
@@ -584,40 +480,32 @@ SslHandle *NetlibSslConnect(SOCKET s, const char* host, int verify)
void NetlibSslShutdown(SslHandle *ssl)
{
- DWORD dwType;
-
- SecBufferDesc OutBuffer;
- SecBuffer OutBuffers[1];
- DWORD dwSSPIFlags;
- DWORD dwSSPIOutFlags;
- TimeStamp tsExpiry;
- DWORD scRet;
-
if (ssl == NULL || !SecIsValidHandle(&ssl->hContext))
return;
- dwType = SCHANNEL_SHUTDOWN;
+ DWORD dwType = SCHANNEL_SHUTDOWN;
+ SecBuffer OutBuffers[1];
OutBuffers[0].pvBuffer = &dwType;
OutBuffers[0].BufferType = SECBUFFER_TOKEN;
OutBuffers[0].cbBuffer = sizeof(dwType);
- OutBuffer.cBuffers = 1;
+ SecBufferDesc OutBuffer;
+ OutBuffer.cBuffers = SIZEOF(OutBuffers);
OutBuffer.pBuffers = OutBuffers;
OutBuffer.ulVersion = SECBUFFER_VERSION;
- scRet = g_pSSPI->ApplyControlToken(&ssl->hContext, &OutBuffer);
- if (FAILED(scRet)) return;
+ SECURITY_STATUS scRet = g_pSSPI->ApplyControlToken(&ssl->hContext, &OutBuffer);
+ if (FAILED(scRet))
+ return;
- //
// Build an SSL close notify message.
- //
- dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
- ISC_REQ_REPLAY_DETECT |
- ISC_REQ_CONFIDENTIALITY |
- ISC_RET_EXTENDED_ERROR |
- ISC_REQ_ALLOCATE_MEMORY |
+ DWORD dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
+ ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_CONFIDENTIALITY |
+ ISC_RET_EXTENDED_ERROR |
+ ISC_REQ_ALLOCATE_MEMORY |
ISC_REQ_STREAM;
OutBuffers[0].pvBuffer = NULL;
@@ -628,25 +516,15 @@ void NetlibSslShutdown(SslHandle *ssl)
OutBuffer.pBuffers = OutBuffers;
OutBuffer.ulVersion = SECBUFFER_VERSION;
- scRet = g_pSSPI->InitializeSecurityContextA(
- &hCreds,
- &ssl->hContext,
- NULL,
- dwSSPIFlags,
- 0,
- SECURITY_NATIVE_DREP,
- NULL,
- 0,
- &ssl->hContext,
- &OutBuffer,
- &dwSSPIOutFlags,
- &tsExpiry);
-
- if (FAILED(scRet)) return;
+ TimeStamp tsExpiry;
+ DWORD dwSSPIOutFlags;
+ scRet = g_pSSPI->InitializeSecurityContext(&hCreds, &ssl->hContext, NULL, dwSSPIFlags, 0, SECURITY_NATIVE_DREP, NULL, 0,
+ &ssl->hContext, &OutBuffer, &dwSSPIOutFlags, &tsExpiry);
+ if (FAILED(scRet))
+ return;
// Send the close notify message to the server.
- if (OutBuffers[0].pvBuffer != NULL && OutBuffers[0].cbBuffer != 0)
- {
+ if (OutBuffers[0].pvBuffer != NULL && OutBuffers[0].cbBuffer != 0) {
NetlibDumpData(NULL, (unsigned char*)(OutBuffers[0].pvBuffer), OutBuffers[0].cbBuffer, 1, MSG_DUMPSSL);
send(ssl->s, (char*)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0);
g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer);
@@ -656,16 +534,13 @@ void NetlibSslShutdown(SslHandle *ssl)
static int NetlibSslReadSetResult(SslHandle *ssl, char *buf, int num, int peek)
{
if (ssl->cbRecDataBuf == 0)
- {
- return (ssl->state == sockClosed ? 0: SOCKET_ERROR);
- }
+ return (ssl->state == sockClosed ? 0 : SOCKET_ERROR);
int bytes = min(num, ssl->cbRecDataBuf);
int rbytes = ssl->cbRecDataBuf - bytes;
memcpy(buf, ssl->pbRecDataBuf, bytes);
- if (!peek)
- {
+ if (!peek) {
memmove(ssl->pbRecDataBuf, ssl->pbRecDataBuf + bytes, rbytes);
ssl->cbRecDataBuf = rbytes;
}
@@ -675,47 +550,30 @@ static int NetlibSslReadSetResult(SslHandle *ssl, char *buf, int num, int peek)
int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
{
- SECURITY_STATUS scRet;
- DWORD cbData;
- DWORD resNum = 0;
- int i;
-
- SecBufferDesc Message;
- SecBuffer Buffers[4];
- SecBuffer * pDataBuffer;
- SecBuffer * pExtraBuffer;
-
if (ssl == NULL) return SOCKET_ERROR;
if (num <= 0) return 0;
if (ssl->state != sockOpen || (ssl->cbRecDataBuf != 0 && (!peek || ssl->cbRecDataBuf >= num)))
- {
return NetlibSslReadSetResult(ssl, buf, num, peek);
- }
- scRet = SEC_E_OK;
+ SECURITY_STATUS scRet = SEC_E_OK;
- while(true)
- {
- if (0 == ssl->cbIoBuffer || scRet == SEC_E_INCOMPLETE_MESSAGE)
- {
- if (ssl->sbIoBuffer <= ssl->cbIoBuffer)
- {
+ while (true) {
+ if (0 == ssl->cbIoBuffer || scRet == SEC_E_INCOMPLETE_MESSAGE) {
+ if (ssl->sbIoBuffer <= ssl->cbIoBuffer) {
ssl->sbIoBuffer += 2048;
ssl->pbIoBuffer = (PUCHAR)mir_realloc(ssl->pbIoBuffer, ssl->sbIoBuffer);
}
- if (peek)
- {
- static const TIMEVAL tv = {0};
+ if (peek) {
+ static const TIMEVAL tv = { 0 };
fd_set fd;
FD_ZERO(&fd);
FD_SET(ssl->s, &fd);
- cbData = select(1, &fd, NULL, NULL, &tv);
- if (cbData == SOCKET_ERROR)
- {
+ DWORD cbData = select(1, &fd, NULL, NULL, &tv);
+ if (cbData == SOCKET_ERROR) {
ssl->state = sockError;
return NetlibSslReadSetResult(ssl, buf, num, peek);
}
@@ -724,40 +582,36 @@ int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
return NetlibSslReadSetResult(ssl, buf, num, peek);
}
- cbData = recv(ssl->s, (char*)ssl->pbIoBuffer + ssl->cbIoBuffer, ssl->sbIoBuffer - ssl->cbIoBuffer, 0);
- if (cbData == SOCKET_ERROR)
- {
+ DWORD cbData = recv(ssl->s, (char*)ssl->pbIoBuffer + ssl->cbIoBuffer, ssl->sbIoBuffer - ssl->cbIoBuffer, 0);
+ if (cbData == SOCKET_ERROR) {
NetlibLogf(NULL, "SSL failure recieving data (%d)", WSAGetLastError());
ssl->state = sockError;
return NetlibSslReadSetResult(ssl, buf, num, peek);
}
- if (cbData == 0)
- {
+ if (cbData == 0) {
NetlibLogf(NULL, "SSL connection gracefully closed");
- if (peek && ssl->cbRecDataBuf)
- {
+ if (peek && ssl->cbRecDataBuf) {
ssl->state = sockClosed;
return NetlibSslReadSetResult(ssl, buf, num, peek);
}
// Server disconnected.
- if (ssl->cbIoBuffer)
- {
+ if (ssl->cbIoBuffer) {
ssl->state = sockError;
return NetlibSslReadSetResult(ssl, buf, num, peek);
}
return 0;
}
- else
- {
+ else {
NetlibDumpData(NULL, ssl->pbIoBuffer + ssl->cbIoBuffer, cbData, 0, MSG_DUMPSSL);
ssl->cbIoBuffer += cbData;
}
}
// Attempt to decrypt the received data.
+ SecBuffer Buffers[4];
Buffers[0].pvBuffer = ssl->pbIoBuffer;
Buffers[0].cbBuffer = ssl->cbIoBuffer;
Buffers[0].BufferType = SECBUFFER_DATA;
@@ -766,8 +620,9 @@ int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
Buffers[2].BufferType = SECBUFFER_EMPTY;
Buffers[3].BufferType = SECBUFFER_EMPTY;
+ SecBufferDesc Message;
Message.ulVersion = SECBUFFER_VERSION;
- Message.cBuffers = 4;
+ Message.cBuffers = SIZEOF(Buffers);
Message.pBuffers = Buffers;
if (g_pSSPI->DecryptMessage != NULL && g_pSSPI->DecryptMessage != PVOID(0x80000000))
@@ -781,18 +636,16 @@ int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
if (scRet == SEC_E_INCOMPLETE_MESSAGE)
continue;
- if (scRet != SEC_E_OK && scRet != SEC_I_RENEGOTIATE && scRet != SEC_I_CONTEXT_EXPIRED)
- {
+ if (scRet != SEC_E_OK && scRet != SEC_I_RENEGOTIATE && scRet != SEC_I_CONTEXT_EXPIRED) {
ReportSslError(scRet, __LINE__);
ssl->state = sockError;
return NetlibSslReadSetResult(ssl, buf, num, peek);
}
// Locate data and (optional) extra buffers.
- pDataBuffer = NULL;
- pExtraBuffer = NULL;
- for (i = 1; i < 4; i++)
- {
+ SecBuffer *pDataBuffer = NULL;
+ SecBuffer *pExtraBuffer = NULL;
+ for (int i = 1; i < SIZEOF(Buffers); i++) {
if (pDataBuffer == NULL && Buffers[i].BufferType == SECBUFFER_DATA)
pDataBuffer = &Buffers[i];
@@ -801,20 +654,16 @@ int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
}
// Return decrypted data.
- if (pDataBuffer)
- {
- DWORD bytes, rbytes;
-
- bytes = peek ? 0 : min((DWORD)num, pDataBuffer->cbBuffer);
- rbytes = pDataBuffer->cbBuffer - bytes;
+ DWORD resNum = 0;
+ if (pDataBuffer) {
+ DWORD bytes = peek ? 0 : min((DWORD)num, pDataBuffer->cbBuffer);
+ DWORD rbytes = pDataBuffer->cbBuffer - bytes;
NetlibDumpData(NULL, (PBYTE)pDataBuffer->pvBuffer, pDataBuffer->cbBuffer, 0, MSG_DUMPSSL);
- if (rbytes > 0)
- {
+ if (rbytes > 0) {
int nbytes = ssl->cbRecDataBuf + rbytes;
- if (ssl->sbRecDataBuf < nbytes)
- {
+ if (ssl->sbRecDataBuf < nbytes) {
ssl->sbRecDataBuf = nbytes;
ssl->pbRecDataBuf = (PUCHAR)mir_realloc(ssl->pbRecDataBuf, nbytes);
}
@@ -822,21 +671,18 @@ int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
ssl->cbRecDataBuf = nbytes;
}
- if (peek)
- {
+ if (peek) {
resNum = bytes = min(num, ssl->cbRecDataBuf);
memcpy(buf, ssl->pbRecDataBuf, bytes);
}
- else
- {
+ else {
resNum = bytes;
memcpy(buf, pDataBuffer->pvBuffer, bytes);
}
}
// Move any "extra" data to the input buffer.
- if (pExtraBuffer)
- {
+ if (pExtraBuffer) {
memmove(ssl->pbIoBuffer, pExtraBuffer->pvBuffer, pExtraBuffer->cbBuffer);
ssl->cbIoBuffer = pExtraBuffer->cbBuffer;
}
@@ -846,21 +692,18 @@ int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
return resNum;
// Server signaled end of session
- if (scRet == SEC_I_CONTEXT_EXPIRED)
- {
+ if (scRet == SEC_I_CONTEXT_EXPIRED) {
NetlibLogf(NULL, "SSL Server signaled SSL Shutdown");
ssl->state = sockClosed;
return NetlibSslReadSetResult(ssl, buf, num, peek);
}
- if (scRet == SEC_I_RENEGOTIATE)
- {
+ if (scRet == SEC_I_RENEGOTIATE) {
// The server wants to perform another handshake
// sequence.
scRet = ClientHandshakeLoop(ssl, FALSE);
- if (scRet != SEC_E_OK)
- {
+ if (scRet != SEC_E_OK) {
ssl->state = sockError;
return NetlibSslReadSetResult(ssl, buf, num, peek);
}
@@ -870,34 +713,23 @@ int NetlibSslRead(SslHandle *ssl, char *buf, int num, int peek)
int NetlibSslWrite(SslHandle *ssl, const char *buf, int num)
{
- SecPkgContext_StreamSizes Sizes;
- SECURITY_STATUS scRet;
- DWORD cbData;
-
- SecBufferDesc Message;
- SecBuffer Buffers[4] = {0};
-
- PUCHAR pbDataBuffer;
-
- PUCHAR pbMessage;
- DWORD cbMessage;
-
- DWORD sendOff = 0;
-
if (ssl == NULL) return SOCKET_ERROR;
- scRet = g_pSSPI->QueryContextAttributesA(&ssl->hContext, SECPKG_ATTR_STREAM_SIZES, &Sizes);
- if (scRet != SEC_E_OK) return scRet;
+ SecPkgContext_StreamSizes Sizes;
+ SECURITY_STATUS scRet = g_pSSPI->QueryContextAttributes(&ssl->hContext, SECPKG_ATTR_STREAM_SIZES, &Sizes);
+ if (scRet != SEC_E_OK)
+ return scRet;
- pbDataBuffer = (PUCHAR)mir_calloc(Sizes.cbMaximumMessage + Sizes.cbHeader + Sizes.cbTrailer);
+ PUCHAR pbDataBuffer = (PUCHAR)mir_calloc(Sizes.cbMaximumMessage + Sizes.cbHeader + Sizes.cbTrailer);
- pbMessage = pbDataBuffer + Sizes.cbHeader;
+ PUCHAR pbMessage = pbDataBuffer + Sizes.cbHeader;
- while (sendOff < (DWORD)num)
- {
- cbMessage = min(Sizes.cbMaximumMessage, (DWORD)num - sendOff);
- memcpy(pbMessage, buf+sendOff, cbMessage);
+ DWORD sendOff = 0;
+ while (sendOff < (DWORD)num) {
+ DWORD cbMessage = min(Sizes.cbMaximumMessage, (DWORD)num - sendOff);
+ memcpy(pbMessage, buf + sendOff, cbMessage);
+ SecBuffer Buffers[4] = { 0 };
Buffers[0].pvBuffer = pbDataBuffer;
Buffers[0].cbBuffer = Sizes.cbHeader;
Buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
@@ -912,8 +744,9 @@ int NetlibSslWrite(SslHandle *ssl, const char *buf, int num)
Buffers[3].BufferType = SECBUFFER_EMPTY;
+ SecBufferDesc Message;
Message.ulVersion = SECBUFFER_VERSION;
- Message.cBuffers = 4;
+ Message.cBuffers = SIZEOF(Buffers);
Message.pBuffers = Buffers;
if (g_pSSPI->EncryptMessage != NULL)
@@ -924,13 +757,12 @@ int NetlibSslWrite(SslHandle *ssl, const char *buf, int num)
if (FAILED(scRet)) break;
// Calculate encrypted packet size
- cbData = Buffers[0].cbBuffer + Buffers[1].cbBuffer + Buffers[2].cbBuffer;
+ DWORD cbData = Buffers[0].cbBuffer + Buffers[1].cbBuffer + Buffers[2].cbBuffer;
// Send the encrypted data to the server.
NetlibDumpData(NULL, pbDataBuffer, cbData, 1, MSG_DUMPSSL);
cbData = send(ssl->s, (char*)pbDataBuffer, cbData, 0);
- if (cbData == SOCKET_ERROR || cbData == 0)
- {
+ if (cbData == SOCKET_ERROR || cbData == 0) {
NetlibLogf(NULL, "SSL failure sending data (%d)", WSAGetLastError());
scRet = SEC_E_INTERNAL_ERROR;
break;
@@ -945,19 +777,19 @@ int NetlibSslWrite(SslHandle *ssl, const char *buf, int num)
static INT_PTR GetSslApi(WPARAM, LPARAM lParam)
{
- SSL_API* si = (SSL_API*)lParam;
- if (si == NULL) return FALSE;
+ SSL_API *si = (SSL_API*)lParam;
+ if (si == NULL)
+ return FALSE;
if (si->cbSize != sizeof(SSL_API))
return FALSE;
- si->connect = (HSSL (__cdecl *)(SOCKET, const char *, int))NetlibSslConnect;
- si->pending = (BOOL (__cdecl *)(HSSL))NetlibSslPending;
- si->read = (int (__cdecl *)(HSSL, char *, int, int))NetlibSslRead;
- si->write = (int (__cdecl *)(HSSL, const char *, int))NetlibSslWrite;
- si->shutdown = (void (__cdecl *)(HSSL))NetlibSslShutdown;
- si->sfree = (void (__cdecl *)(HSSL))NetlibSslFree;
-
+ si->connect = (HSSL(__cdecl *)(SOCKET, const char *, int))NetlibSslConnect;
+ si->pending = (BOOL(__cdecl *)(HSSL))NetlibSslPending;
+ si->read = (int(__cdecl *)(HSSL, char *, int, int))NetlibSslRead;
+ si->write = (int(__cdecl *)(HSSL, const char *, int))NetlibSslWrite;
+ si->shutdown = (void(__cdecl *)(HSSL))NetlibSslShutdown;
+ si->sfree = (void(__cdecl *)(HSSL))NetlibSslFree;
return TRUE;
}
@@ -966,7 +798,6 @@ int LoadSslModule(void)
CreateServiceFunction(MS_SYSTEM_GET_SI, GetSslApi);
g_hSslMutex = CreateMutex(NULL, FALSE, NULL);
SecInvalidateHandle(&hCreds);
-
return 0;
}