+// socketwrapper.cpp
+// Copyright (C) 2011 Suguru Kawamoto
+// \83\\83P\83b\83g\83\89\83b\83p\81[
+// socket\8aÖ\98A\8aÖ\90\94\82ðOpenSSL\97p\82É\92u\8a·
+// \83R\83\93\83p\83C\83\8b\82É\82ÍOpenSSL\82Ì\83w\83b\83_\81[\83t\83@\83C\83\8b\82ª\95K\97v
+// \8eÀ\8ds\82É\82ÍOpenSSL\82ÌDLL\82ª\95K\97v
+
+#include <windows.h>
+#include <mmsystem.h>
+#include <openssl/ssl.h>
+
+#include "socketwrapper.h"
+
+typedef void (__stdcall* _SSL_load_error_strings)();
+typedef int (__stdcall* _SSL_library_init)();
+typedef SSL_METHOD* (__stdcall* _SSLv23_method)();
+typedef SSL_CTX* (__stdcall* _SSL_CTX_new)(SSL_METHOD*);
+typedef void (__stdcall* _SSL_CTX_free)(SSL_CTX*);
+typedef SSL* (__stdcall* _SSL_new)(SSL_CTX*);
+typedef void (__stdcall* _SSL_free)(SSL*);
+typedef int (__stdcall* _SSL_shutdown)(SSL*);
+typedef int (__stdcall* _SSL_get_fd)(SSL*);
+typedef int (__stdcall* _SSL_set_fd)(SSL*, int);
+typedef int (__stdcall* _SSL_accept)(SSL*);
+typedef int (__stdcall* _SSL_connect)(SSL*);
+typedef int (__stdcall* _SSL_write)(SSL*, const void*, int);
+typedef int (__stdcall* _SSL_peek)(SSL*, void*, int);
+typedef int (__stdcall* _SSL_read)(SSL*, void*, int);
+typedef int (__stdcall* _SSL_get_error)(SSL*, int);
+
+_SSL_load_error_strings pSSL_load_error_strings;
+_SSL_library_init pSSL_library_init;
+_SSLv23_method pSSLv23_method;
+_SSL_CTX_new pSSL_CTX_new;
+_SSL_CTX_free pSSL_CTX_free;
+_SSL_new pSSL_new;
+_SSL_free pSSL_free;
+_SSL_shutdown pSSL_shutdown;
+_SSL_get_fd pSSL_get_fd;
+_SSL_set_fd pSSL_set_fd;
+_SSL_accept pSSL_accept;
+_SSL_connect pSSL_connect;
+_SSL_write pSSL_write;
+_SSL_peek pSSL_peek;
+_SSL_read pSSL_read;
+_SSL_get_error pSSL_get_error;
+
+#define MAX_SSL_SOCKET 16
+
+BOOL g_bOpenSSLLoaded;
+HMODULE g_hOpenSSL;
+CRITICAL_SECTION g_OpenSSLLock;
+DWORD g_OpenSSLTimeout;
+LPSSLTIMEOUTCALLBACK g_pOpenSSLTimeoutCallback;
+SSL_CTX* g_pOpenSSLCTX;
+SSL* g_pOpenSSLHandle[MAX_SSL_SOCKET];
+
+BOOL __stdcall DefaultSSLTimeoutCallback()
+{
+ Sleep(100);
+ return FALSE;
+}
+
+BOOL LoadOpenSSL()
+{
+ if(g_bOpenSSLLoaded)
+ return FALSE;
+ g_hOpenSSL = LoadLibrary("ssleay32.dll");
+ if(!g_hOpenSSL)
+ g_hOpenSSL = LoadLibrary("libssl32.dll");
+ if(!g_hOpenSSL
+ || !(pSSL_load_error_strings = (_SSL_load_error_strings)GetProcAddress(g_hOpenSSL, "SSL_load_error_strings"))
+ || !(pSSL_library_init = (_SSL_library_init)GetProcAddress(g_hOpenSSL, "SSL_library_init"))
+ || !(pSSLv23_method = (_SSLv23_method)GetProcAddress(g_hOpenSSL, "SSLv23_method"))
+ || !(pSSL_CTX_new = (_SSL_CTX_new)GetProcAddress(g_hOpenSSL, "SSL_CTX_new"))
+ || !(pSSL_CTX_free = (_SSL_CTX_free)GetProcAddress(g_hOpenSSL, "SSL_CTX_free"))
+ || !(pSSL_new = (_SSL_new)GetProcAddress(g_hOpenSSL, "SSL_new"))
+ || !(pSSL_free = (_SSL_free)GetProcAddress(g_hOpenSSL, "SSL_free"))
+ || !(pSSL_shutdown = (_SSL_shutdown)GetProcAddress(g_hOpenSSL, "SSL_shutdown"))
+ || !(pSSL_get_fd = (_SSL_get_fd)GetProcAddress(g_hOpenSSL, "SSL_get_fd"))
+ || !(pSSL_set_fd = (_SSL_set_fd)GetProcAddress(g_hOpenSSL, "SSL_set_fd"))
+ || !(pSSL_accept = (_SSL_accept)GetProcAddress(g_hOpenSSL, "SSL_accept"))
+ || !(pSSL_connect = (_SSL_connect)GetProcAddress(g_hOpenSSL, "SSL_connect"))
+ || !(pSSL_write = (_SSL_write)GetProcAddress(g_hOpenSSL, "SSL_write"))
+ || !(pSSL_peek = (_SSL_peek)GetProcAddress(g_hOpenSSL, "SSL_peek"))
+ || !(pSSL_read = (_SSL_read)GetProcAddress(g_hOpenSSL, "SSL_read"))
+ || !(pSSL_get_error = (_SSL_get_error)GetProcAddress(g_hOpenSSL, "SSL_get_error")))
+ {
+ FreeLibrary(g_hOpenSSL);
+ g_hOpenSSL = NULL;
+ return FALSE;
+ }
+ InitializeCriticalSection(&g_OpenSSLLock);
+ pSSL_load_error_strings();
+ pSSL_library_init();
+ SetSSLTimeoutCallback(60000, DefaultSSLTimeoutCallback);
+ g_bOpenSSLLoaded = TRUE;
+ return TRUE;
+}
+
+void FreeOpenSSL()
+{
+ int i;
+ if(!g_bOpenSSLLoaded)
+ return;
+ EnterCriticalSection(&g_OpenSSLLock);
+ for(i = 0; i < MAX_SSL_SOCKET; i++)
+ {
+ if(g_pOpenSSLHandle[i])
+ {
+ pSSL_shutdown(g_pOpenSSLHandle[i]);
+ pSSL_free(g_pOpenSSLHandle[i]);
+ g_pOpenSSLHandle[i] = NULL;
+ }
+ }
+ if(g_pOpenSSLCTX)
+ pSSL_CTX_free(g_pOpenSSLCTX);
+ g_pOpenSSLCTX = NULL;
+ FreeLibrary(g_hOpenSSL);
+ g_hOpenSSL = NULL;
+ LeaveCriticalSection(&g_OpenSSLLock);
+ DeleteCriticalSection(&g_OpenSSLLock);
+ g_bOpenSSLLoaded = FALSE;
+}
+
+BOOL IsOpenSSLLoaded()
+{
+ return g_bOpenSSLLoaded;
+}
+
+SSL** GetUnusedSSLPointer()
+{
+ int i;
+ for(i = 0; i < MAX_SSL_SOCKET; i++)
+ {
+ if(!g_pOpenSSLHandle[i])
+ return &g_pOpenSSLHandle[i];
+ }
+ return NULL;
+}
+
+SSL** FindSSLPointerFromSocket(SOCKET s)
+{
+ int i;
+ for(i = 0; i < MAX_SSL_SOCKET; i++)
+ {
+ if(g_pOpenSSLHandle[i])
+ {
+ if(pSSL_get_fd(g_pOpenSSLHandle[i]) == s)
+ return &g_pOpenSSLHandle[i];
+ }
+ }
+ return NULL;
+}
+
+void SetSSLTimeoutCallback(DWORD Timeout, LPSSLTIMEOUTCALLBACK pCallback)
+{
+ EnterCriticalSection(&g_OpenSSLLock);
+ g_OpenSSLTimeout = Timeout;
+ g_pOpenSSLTimeoutCallback = pCallback;
+ LeaveCriticalSection(&g_OpenSSLLock);
+}
+
+BOOL AttachSSL(SOCKET s)
+{
+ BOOL r;
+ DWORD Time;
+ SSL** ppSSL;
+ r = FALSE;
+ Time = timeGetTime();
+ EnterCriticalSection(&g_OpenSSLLock);
+ if(!g_pOpenSSLCTX)
+ g_pOpenSSLCTX = pSSL_CTX_new(pSSLv23_method());
+ if(g_pOpenSSLCTX)
+ {
+ if(ppSSL = GetUnusedSSLPointer())
+ {
+ if(*ppSSL = pSSL_new(g_pOpenSSLCTX))
+ {
+ if(pSSL_set_fd(*ppSSL, s) != 0)
+ {
+ r = TRUE;
+ // SSL\82Ì\83l\83S\83V\83G\81[\83V\83\87\83\93\82É\82Í\8e\9e\8aÔ\82ª\82©\82©\82é\8fê\8d\87\82ª\82 \82é
+ while(pSSL_connect(*ppSSL) != 1)
+ {
+ LeaveCriticalSection(&g_OpenSSLLock);
+ if(g_pOpenSSLTimeoutCallback() || timeGetTime() - Time >= g_OpenSSLTimeout)
+ {
+ DetachSSL(s);
+ r = FALSE;
+ EnterCriticalSection(&g_OpenSSLLock);
+ break;
+ }
+ EnterCriticalSection(&g_OpenSSLLock);
+ }
+ }
+ else
+ DetachSSL(s);
+ }
+ }
+ }
+ LeaveCriticalSection(&g_OpenSSLLock);
+ return r;
+}
+
+BOOL DetachSSL(SOCKET s)
+{
+ BOOL r;
+ SSL** ppSSL;
+ r = FALSE;
+ EnterCriticalSection(&g_OpenSSLLock);
+ if(ppSSL = FindSSLPointerFromSocket(s))
+ {
+ pSSL_shutdown(*ppSSL);
+ pSSL_free(*ppSSL);
+ *ppSSL = NULL;
+ r = TRUE;
+ }
+ LeaveCriticalSection(&g_OpenSSLLock);
+ return r;
+}
+
+BOOL IsSSLAttached(SOCKET s)
+{
+ SSL** ppSSL;
+ EnterCriticalSection(&g_OpenSSLLock);
+ ppSSL = FindSSLPointerFromSocket(s);
+ LeaveCriticalSection(&g_OpenSSLLock);
+ if(!ppSSL)
+ return TRUE;
+ return TRUE;
+}
+
+SOCKET socketS(int af, int type, int protocol)
+{
+ return socket(af, type, protocol);
+}
+
+int bindS(SOCKET s, const struct sockaddr *addr, int namelen)
+{
+ return bind(s, addr, namelen);
+}
+
+int listenS(SOCKET s, int backlog)
+{
+ return listen(s, backlog);
+}
+
+SOCKET acceptS(SOCKET s, struct sockaddr *addr, int *addrlen)
+{
+ SOCKET r;
+ r = accept(s, addr, addrlen);
+ if(!AttachSSL(r))
+ {
+ closesocket(r);
+ return INVALID_SOCKET;
+ }
+ return r;
+}
+
+int connectS(SOCKET s, const struct sockaddr *name, int namelen)
+{
+ int r;
+ r = connect(s, name, namelen);
+ if(!AttachSSL(r))
+ return SOCKET_ERROR;
+ return r;
+}
+
+int closesocketS(SOCKET s)
+{
+ DetachSSL(s);
+ return closesocket(s);
+}
+
+int sendS(SOCKET s, const char * buf, int len, int flags)
+{
+ SSL** ppSSL;
+ EnterCriticalSection(&g_OpenSSLLock);
+ ppSSL = FindSSLPointerFromSocket(s);
+ LeaveCriticalSection(&g_OpenSSLLock);
+ if(!ppSSL)
+ return send(s, buf, len, flags);
+ return pSSL_write(*ppSSL, buf, len);
+}
+
+int recvS(SOCKET s, char * buf, int len, int flags)
+{
+ SSL** ppSSL;
+ EnterCriticalSection(&g_OpenSSLLock);
+ ppSSL = FindSSLPointerFromSocket(s);
+ LeaveCriticalSection(&g_OpenSSLLock);
+ if(!ppSSL)
+ return recv(s, buf, len, flags);
+ if(flags & MSG_PEEK)
+ return pSSL_peek(*ppSSL, buf, len);
+ return pSSL_read(*ppSSL, buf, len);
+}
+