OSDN Git Service

0fa4b0d874c33d20f98b70aa7fb4e62883a47d71
[ffftp/ffftp.git] / socketwrapper.c
// socketwrapper.cpp
// Copyright (C) 2011 Suguru Kawamoto
// Socket wrapper
// Replaces the socket-related functions with OpenSSL-backed versions.
// Building requires the OpenSSL header files.
// Running requires the OpenSSL DLLs.
7
8 #include <windows.h>
9 #include <mmsystem.h>
10 #include <openssl/ssl.h>
11
12 #include "socketwrapper.h"
13
14 typedef void (__stdcall* _SSL_load_error_strings)();
15 typedef int (__stdcall* _SSL_library_init)();
16 typedef SSL_METHOD* (__stdcall* _SSLv23_method)();
17 typedef SSL_CTX* (__stdcall* _SSL_CTX_new)(SSL_METHOD*);
18 typedef void (__stdcall* _SSL_CTX_free)(SSL_CTX*);
19 typedef SSL* (__stdcall* _SSL_new)(SSL_CTX*);
20 typedef void (__stdcall* _SSL_free)(SSL*);
21 typedef int (__stdcall* _SSL_shutdown)(SSL*);
22 typedef int (__stdcall* _SSL_get_fd)(SSL*);
23 typedef int (__stdcall* _SSL_set_fd)(SSL*, int);
24 typedef int (__stdcall* _SSL_accept)(SSL*);
25 typedef int (__stdcall* _SSL_connect)(SSL*);
26 typedef int (__stdcall* _SSL_write)(SSL*, const void*, int);
27 typedef int (__stdcall* _SSL_peek)(SSL*, void*, int);
28 typedef int (__stdcall* _SSL_read)(SSL*, void*, int);
29 typedef int (__stdcall* _SSL_get_error)(SSL*, int);
30
31 _SSL_load_error_strings pSSL_load_error_strings;
32 _SSL_library_init pSSL_library_init;
33 _SSLv23_method pSSLv23_method;
34 _SSL_CTX_new pSSL_CTX_new;
35 _SSL_CTX_free pSSL_CTX_free;
36 _SSL_new pSSL_new;
37 _SSL_free pSSL_free;
38 _SSL_shutdown pSSL_shutdown;
39 _SSL_get_fd pSSL_get_fd;
40 _SSL_set_fd pSSL_set_fd;
41 _SSL_accept pSSL_accept;
42 _SSL_connect pSSL_connect;
43 _SSL_write pSSL_write;
44 _SSL_peek pSSL_peek;
45 _SSL_read pSSL_read;
46 _SSL_get_error pSSL_get_error;
47
48 #define MAX_SSL_SOCKET 16
49
50 BOOL g_bOpenSSLLoaded;
51 HMODULE g_hOpenSSL;
52 CRITICAL_SECTION g_OpenSSLLock;
53 DWORD g_OpenSSLTimeout;
54 LPSSLTIMEOUTCALLBACK g_pOpenSSLTimeoutCallback;
55 SSL_CTX* g_pOpenSSLCTX;
56 SSL* g_pOpenSSLHandle[MAX_SSL_SOCKET];
57
58 BOOL __stdcall DefaultSSLTimeoutCallback()
59 {
60         Sleep(100);
61         return FALSE;
62 }
63
64 BOOL LoadOpenSSL()
65 {
66         if(g_bOpenSSLLoaded)
67                 return FALSE;
68         g_hOpenSSL = LoadLibrary("ssleay32.dll");
69         if(!g_hOpenSSL)
70                 g_hOpenSSL = LoadLibrary("libssl32.dll");
71         if(!g_hOpenSSL
72                 || !(pSSL_load_error_strings = (_SSL_load_error_strings)GetProcAddress(g_hOpenSSL, "SSL_load_error_strings"))
73                 || !(pSSL_library_init = (_SSL_library_init)GetProcAddress(g_hOpenSSL, "SSL_library_init"))
74                 || !(pSSLv23_method = (_SSLv23_method)GetProcAddress(g_hOpenSSL, "SSLv23_method"))
75                 || !(pSSL_CTX_new = (_SSL_CTX_new)GetProcAddress(g_hOpenSSL, "SSL_CTX_new"))
76                 || !(pSSL_CTX_free = (_SSL_CTX_free)GetProcAddress(g_hOpenSSL, "SSL_CTX_free"))
77                 || !(pSSL_new = (_SSL_new)GetProcAddress(g_hOpenSSL, "SSL_new"))
78                 || !(pSSL_free = (_SSL_free)GetProcAddress(g_hOpenSSL, "SSL_free"))
79                 || !(pSSL_shutdown = (_SSL_shutdown)GetProcAddress(g_hOpenSSL, "SSL_shutdown"))
80                 || !(pSSL_get_fd = (_SSL_get_fd)GetProcAddress(g_hOpenSSL, "SSL_get_fd"))
81                 || !(pSSL_set_fd = (_SSL_set_fd)GetProcAddress(g_hOpenSSL, "SSL_set_fd"))
82                 || !(pSSL_accept = (_SSL_accept)GetProcAddress(g_hOpenSSL, "SSL_accept"))
83                 || !(pSSL_connect = (_SSL_connect)GetProcAddress(g_hOpenSSL, "SSL_connect"))
84                 || !(pSSL_write = (_SSL_write)GetProcAddress(g_hOpenSSL, "SSL_write"))
85                 || !(pSSL_peek = (_SSL_peek)GetProcAddress(g_hOpenSSL, "SSL_peek"))
86                 || !(pSSL_read = (_SSL_read)GetProcAddress(g_hOpenSSL, "SSL_read"))
87                 || !(pSSL_get_error = (_SSL_get_error)GetProcAddress(g_hOpenSSL, "SSL_get_error")))
88         {
89                 if(g_hOpenSSL)
90                         FreeLibrary(g_hOpenSSL);
91                 g_hOpenSSL = NULL;
92                 return FALSE;
93         }
94         InitializeCriticalSection(&g_OpenSSLLock);
95         pSSL_load_error_strings();
96         pSSL_library_init();
97         SetSSLTimeoutCallback(60000, DefaultSSLTimeoutCallback);
98         g_bOpenSSLLoaded = TRUE;
99         return TRUE;
100 }
101
102 void FreeOpenSSL()
103 {
104         int i;
105         if(!g_bOpenSSLLoaded)
106                 return;
107         EnterCriticalSection(&g_OpenSSLLock);
108         for(i = 0; i < MAX_SSL_SOCKET; i++)
109         {
110                 if(g_pOpenSSLHandle[i])
111                 {
112                         pSSL_shutdown(g_pOpenSSLHandle[i]);
113                         pSSL_free(g_pOpenSSLHandle[i]);
114                         g_pOpenSSLHandle[i] = NULL;
115                 }
116         }
117         if(g_pOpenSSLCTX)
118                 pSSL_CTX_free(g_pOpenSSLCTX);
119         g_pOpenSSLCTX = NULL;
120         FreeLibrary(g_hOpenSSL);
121         g_hOpenSSL = NULL;
122         LeaveCriticalSection(&g_OpenSSLLock);
123         DeleteCriticalSection(&g_OpenSSLLock);
124         g_bOpenSSLLoaded = FALSE;
125 }
126
127 BOOL IsOpenSSLLoaded()
128 {
129         return g_bOpenSSLLoaded;
130 }
131
132 SSL** GetUnusedSSLPointer()
133 {
134         int i;
135         for(i = 0; i < MAX_SSL_SOCKET; i++)
136         {
137                 if(!g_pOpenSSLHandle[i])
138                         return &g_pOpenSSLHandle[i];
139         }
140         return NULL;
141 }
142
143 SSL** FindSSLPointerFromSocket(SOCKET s)
144 {
145         int i;
146         for(i = 0; i < MAX_SSL_SOCKET; i++)
147         {
148                 if(g_pOpenSSLHandle[i])
149                 {
150                         if(pSSL_get_fd(g_pOpenSSLHandle[i]) == s)
151                                 return &g_pOpenSSLHandle[i];
152                 }
153         }
154         return NULL;
155 }
156
157 void SetSSLTimeoutCallback(DWORD Timeout, LPSSLTIMEOUTCALLBACK pCallback)
158 {
159         if(!g_bOpenSSLLoaded)
160                 return;
161         EnterCriticalSection(&g_OpenSSLLock);
162         g_OpenSSLTimeout = Timeout;
163         g_pOpenSSLTimeoutCallback = pCallback;
164         LeaveCriticalSection(&g_OpenSSLLock);
165 }
166
167 BOOL AttachSSL(SOCKET s)
168 {
169         BOOL r;
170         DWORD Time;
171         SSL** ppSSL;
172         if(!g_bOpenSSLLoaded)
173                 return FALSE;
174         r = FALSE;
175         Time = timeGetTime();
176         EnterCriticalSection(&g_OpenSSLLock);
177         if(!g_pOpenSSLCTX)
178                 g_pOpenSSLCTX = pSSL_CTX_new(pSSLv23_method());
179         if(g_pOpenSSLCTX)
180         {
181                 if(ppSSL = GetUnusedSSLPointer())
182                 {
183                         if(*ppSSL = pSSL_new(g_pOpenSSLCTX))
184                         {
185                                 if(pSSL_set_fd(*ppSSL, s) != 0)
186                                 {
187                                         r = TRUE;
188                                         // SSL\82Ì\83l\83S\83V\83G\81[\83V\83\87\83\93\82É\82Í\8e\9e\8aÔ\82ª\82©\82©\82é\8fê\8d\87\82ª\82 \82é
189                                         while(pSSL_connect(*ppSSL) != 1)
190                                         {
191                                                 LeaveCriticalSection(&g_OpenSSLLock);
192                                                 if(g_pOpenSSLTimeoutCallback() || timeGetTime() - Time >= g_OpenSSLTimeout)
193                                                 {
194                                                         DetachSSL(s);
195                                                         r = FALSE;
196                                                         EnterCriticalSection(&g_OpenSSLLock);
197                                                         break;
198                                                 }
199                                                 EnterCriticalSection(&g_OpenSSLLock);
200                                         }
201                                 }
202                                 else
203                                         DetachSSL(s);
204                         }
205                 }
206         }
207         LeaveCriticalSection(&g_OpenSSLLock);
208         return r;
209 }
210
211 BOOL DetachSSL(SOCKET s)
212 {
213         BOOL r;
214         SSL** ppSSL;
215         if(!g_bOpenSSLLoaded)
216                 return FALSE;
217         r = FALSE;
218         EnterCriticalSection(&g_OpenSSLLock);
219         if(ppSSL = FindSSLPointerFromSocket(s))
220         {
221                 pSSL_shutdown(*ppSSL);
222                 pSSL_free(*ppSSL);
223                 *ppSSL = NULL;
224                 r = TRUE;
225         }
226         LeaveCriticalSection(&g_OpenSSLLock);
227         return r;
228 }
229
230 BOOL IsSSLAttached(SOCKET s)
231 {
232         SSL** ppSSL;
233         if(!g_bOpenSSLLoaded)
234                 return FALSE;
235         EnterCriticalSection(&g_OpenSSLLock);
236         ppSSL = FindSSLPointerFromSocket(s);
237         LeaveCriticalSection(&g_OpenSSLLock);
238         if(!ppSSL)
239                 return TRUE;
240         return TRUE;
241 }
242
243 SOCKET socketS(int af, int type, int protocol)
244 {
245         return socket(af, type, protocol);
246 }
247
248 int bindS(SOCKET s, const struct sockaddr *addr, int namelen)
249 {
250         return bind(s, addr, namelen);
251 }
252
253 int listenS(SOCKET s, int backlog)
254 {
255         return listen(s, backlog);
256 }
257
258 SOCKET acceptS(SOCKET s, struct sockaddr *addr, int *addrlen)
259 {
260         SOCKET r;
261         r = accept(s, addr, addrlen);
262         if(!AttachSSL(r))
263         {
264                 closesocket(r);
265                 return INVALID_SOCKET;
266         }
267         return r;
268 }
269
270 int connectS(SOCKET s, const struct sockaddr *name, int namelen)
271 {
272         int r;
273         r = connect(s, name, namelen);
274         if(!AttachSSL(r))
275                 return SOCKET_ERROR;
276         return r;
277 }
278
279 int closesocketS(SOCKET s)
280 {
281         DetachSSL(s);
282         return closesocket(s);
283 }
284
285 int sendS(SOCKET s, const char * buf, int len, int flags)
286 {
287         SSL** ppSSL;
288         if(!g_bOpenSSLLoaded)
289                 return send(s, buf, len, flags);
290         EnterCriticalSection(&g_OpenSSLLock);
291         ppSSL = FindSSLPointerFromSocket(s);
292         LeaveCriticalSection(&g_OpenSSLLock);
293         if(!ppSSL)
294                 return send(s, buf, len, flags);
295         return pSSL_write(*ppSSL, buf, len);
296 }
297
298 int recvS(SOCKET s, char * buf, int len, int flags)
299 {
300         SSL** ppSSL;
301         if(!g_bOpenSSLLoaded)
302                 return recv(s, buf, len, flags);
303         EnterCriticalSection(&g_OpenSSLLock);
304         ppSSL = FindSSLPointerFromSocket(s);
305         LeaveCriticalSection(&g_OpenSSLLock);
306         if(!ppSSL)
307                 return recv(s, buf, len, flags);
308         if(flags & MSG_PEEK)
309                 return pSSL_peek(*ppSSL, buf, len);
310         return pSSL_read(*ppSSL, buf, len);
311 }
312