OSDN Git Service

e3d4623054a7678ce90a52e404ef783f3df5429d
[ffftp/ffftp.git] / socketwrapper.c
1 // socketwrapper.c
2 // Copyright (C) 2011 Suguru Kawamoto
3 // ソケットラッパー
4 // socket関連関数をOpenSSL用に置換
5 // コンパイルにはOpenSSLのヘッダーファイルが必要
6 // 実行にはOpenSSLのDLLが必要
7
8 #include <windows.h>
9 #include <mmsystem.h>
10 #include <openssl/ssl.h>
11
12 #include "socketwrapper.h"
13
14 typedef void (__cdecl* _SSL_load_error_strings)();
15 typedef int (__cdecl* _SSL_library_init)();
16 typedef SSL_METHOD* (__cdecl* _SSLv23_method)();
17 typedef SSL_CTX* (__cdecl* _SSL_CTX_new)(SSL_METHOD*);
18 typedef void (__cdecl* _SSL_CTX_free)(SSL_CTX*);
19 typedef SSL* (__cdecl* _SSL_new)(SSL_CTX*);
20 typedef void (__cdecl* _SSL_free)(SSL*);
21 typedef int (__cdecl* _SSL_shutdown)(SSL*);
22 typedef int (__cdecl* _SSL_get_fd)(SSL*);
23 typedef int (__cdecl* _SSL_set_fd)(SSL*, int);
24 typedef int (__cdecl* _SSL_accept)(SSL*);
25 typedef int (__cdecl* _SSL_connect)(SSL*);
26 typedef int (__cdecl* _SSL_write)(SSL*, const void*, int);
27 typedef int (__cdecl* _SSL_peek)(SSL*, void*, int);
28 typedef int (__cdecl* _SSL_read)(SSL*, void*, int);
29 typedef int (__cdecl* _SSL_get_error)(SSL*, int);
30
31 _SSL_load_error_strings pSSL_load_error_strings;
32 _SSL_library_init pSSL_library_init;
33 _SSLv23_method pSSLv23_method;
34 _SSL_CTX_new pSSL_CTX_new;
35 _SSL_CTX_free pSSL_CTX_free;
36 _SSL_new pSSL_new;
37 _SSL_free pSSL_free;
38 _SSL_shutdown pSSL_shutdown;
39 _SSL_get_fd pSSL_get_fd;
40 _SSL_set_fd pSSL_set_fd;
41 _SSL_accept pSSL_accept;
42 _SSL_connect pSSL_connect;
43 _SSL_write pSSL_write;
44 _SSL_peek pSSL_peek;
45 _SSL_read pSSL_read;
46 _SSL_get_error pSSL_get_error;
47
48 #define MAX_SSL_SOCKET 16
49
50 BOOL g_bOpenSSLLoaded;
51 HMODULE g_hOpenSSL;
52 CRITICAL_SECTION g_OpenSSLLock;
53 DWORD g_OpenSSLTimeout;
54 LPSSLTIMEOUTCALLBACK g_pOpenSSLTimeoutCallback;
55 SSL_CTX* g_pOpenSSLCTX;
56 SSL* g_pOpenSSLHandle[MAX_SSL_SOCKET];
57
58 BOOL __stdcall DefaultSSLTimeoutCallback()
59 {
60         Sleep(100);
61         return FALSE;
62 }
63
64 BOOL LoadOpenSSL()
65 {
66         if(g_bOpenSSLLoaded)
67                 return FALSE;
68         g_hOpenSSL = LoadLibrary("ssleay32.dll");
69         if(!g_hOpenSSL)
70                 g_hOpenSSL = LoadLibrary("libssl32.dll");
71         if(!g_hOpenSSL
72                 || !(pSSL_load_error_strings = (_SSL_load_error_strings)GetProcAddress(g_hOpenSSL, "SSL_load_error_strings"))
73                 || !(pSSL_library_init = (_SSL_library_init)GetProcAddress(g_hOpenSSL, "SSL_library_init"))
74                 || !(pSSLv23_method = (_SSLv23_method)GetProcAddress(g_hOpenSSL, "SSLv23_method"))
75                 || !(pSSL_CTX_new = (_SSL_CTX_new)GetProcAddress(g_hOpenSSL, "SSL_CTX_new"))
76                 || !(pSSL_CTX_free = (_SSL_CTX_free)GetProcAddress(g_hOpenSSL, "SSL_CTX_free"))
77                 || !(pSSL_new = (_SSL_new)GetProcAddress(g_hOpenSSL, "SSL_new"))
78                 || !(pSSL_free = (_SSL_free)GetProcAddress(g_hOpenSSL, "SSL_free"))
79                 || !(pSSL_shutdown = (_SSL_shutdown)GetProcAddress(g_hOpenSSL, "SSL_shutdown"))
80                 || !(pSSL_get_fd = (_SSL_get_fd)GetProcAddress(g_hOpenSSL, "SSL_get_fd"))
81                 || !(pSSL_set_fd = (_SSL_set_fd)GetProcAddress(g_hOpenSSL, "SSL_set_fd"))
82                 || !(pSSL_accept = (_SSL_accept)GetProcAddress(g_hOpenSSL, "SSL_accept"))
83                 || !(pSSL_connect = (_SSL_connect)GetProcAddress(g_hOpenSSL, "SSL_connect"))
84                 || !(pSSL_write = (_SSL_write)GetProcAddress(g_hOpenSSL, "SSL_write"))
85                 || !(pSSL_peek = (_SSL_peek)GetProcAddress(g_hOpenSSL, "SSL_peek"))
86                 || !(pSSL_read = (_SSL_read)GetProcAddress(g_hOpenSSL, "SSL_read"))
87                 || !(pSSL_get_error = (_SSL_get_error)GetProcAddress(g_hOpenSSL, "SSL_get_error")))
88         {
89                 if(g_hOpenSSL)
90                         FreeLibrary(g_hOpenSSL);
91                 g_hOpenSSL = NULL;
92                 return FALSE;
93         }
94         InitializeCriticalSection(&g_OpenSSLLock);
95         pSSL_load_error_strings();
96         pSSL_library_init();
97         SetSSLTimeoutCallback(60000, DefaultSSLTimeoutCallback);
98         g_bOpenSSLLoaded = TRUE;
99         return TRUE;
100 }
101
102 void FreeOpenSSL()
103 {
104         int i;
105         if(!g_bOpenSSLLoaded)
106                 return;
107         EnterCriticalSection(&g_OpenSSLLock);
108         for(i = 0; i < MAX_SSL_SOCKET; i++)
109         {
110                 if(g_pOpenSSLHandle[i])
111                 {
112                         pSSL_shutdown(g_pOpenSSLHandle[i]);
113                         pSSL_free(g_pOpenSSLHandle[i]);
114                         g_pOpenSSLHandle[i] = NULL;
115                 }
116         }
117         if(g_pOpenSSLCTX)
118                 pSSL_CTX_free(g_pOpenSSLCTX);
119         g_pOpenSSLCTX = NULL;
120         FreeLibrary(g_hOpenSSL);
121         g_hOpenSSL = NULL;
122         LeaveCriticalSection(&g_OpenSSLLock);
123         DeleteCriticalSection(&g_OpenSSLLock);
124         g_bOpenSSLLoaded = FALSE;
125 }
126
127 BOOL IsOpenSSLLoaded()
128 {
129         return g_bOpenSSLLoaded;
130 }
131
132 SSL** GetUnusedSSLPointer()
133 {
134         int i;
135         for(i = 0; i < MAX_SSL_SOCKET; i++)
136         {
137                 if(!g_pOpenSSLHandle[i])
138                         return &g_pOpenSSLHandle[i];
139         }
140         return NULL;
141 }
142
143 SSL** FindSSLPointerFromSocket(SOCKET s)
144 {
145         int i;
146         for(i = 0; i < MAX_SSL_SOCKET; i++)
147         {
148                 if(g_pOpenSSLHandle[i])
149                 {
150                         if(pSSL_get_fd(g_pOpenSSLHandle[i]) == s)
151                                 return &g_pOpenSSLHandle[i];
152                 }
153         }
154         return NULL;
155 }
156
157 void SetSSLTimeoutCallback(DWORD Timeout, LPSSLTIMEOUTCALLBACK pCallback)
158 {
159         if(!g_bOpenSSLLoaded)
160                 return;
161         EnterCriticalSection(&g_OpenSSLLock);
162         g_OpenSSLTimeout = Timeout;
163         g_pOpenSSLTimeoutCallback = pCallback;
164         LeaveCriticalSection(&g_OpenSSLLock);
165 }
166
167 BOOL AttachSSL(SOCKET s)
168 {
169         BOOL r;
170         DWORD Time;
171         SSL** ppSSL;
172         if(!g_bOpenSSLLoaded)
173                 return FALSE;
174         r = FALSE;
175         Time = timeGetTime();
176         EnterCriticalSection(&g_OpenSSLLock);
177         if(!g_pOpenSSLCTX)
178                 g_pOpenSSLCTX = pSSL_CTX_new(pSSLv23_method());
179         if(g_pOpenSSLCTX)
180         {
181                 if(ppSSL = GetUnusedSSLPointer())
182                 {
183                         if(*ppSSL = pSSL_new(g_pOpenSSLCTX))
184                         {
185                                 if(pSSL_set_fd(*ppSSL, s) != 0)
186                                 {
187                                         r = TRUE;
188                                         // SSLのネゴシエーションには時間がかかる場合がある
189                                         while(pSSL_connect(*ppSSL) != 1)
190                                         {
191                                                 LeaveCriticalSection(&g_OpenSSLLock);
192                                                 if(g_pOpenSSLTimeoutCallback() || timeGetTime() - Time >= g_OpenSSLTimeout)
193                                                 {
194                                                         DetachSSL(s);
195                                                         r = FALSE;
196                                                         EnterCriticalSection(&g_OpenSSLLock);
197                                                         break;
198                                                 }
199                                                 EnterCriticalSection(&g_OpenSSLLock);
200                                         }
201                                 }
202                                 else
203                                 {
204                                         LeaveCriticalSection(&g_OpenSSLLock);
205                                         DetachSSL(s);
206                                         EnterCriticalSection(&g_OpenSSLLock);
207                                 }
208                         }
209                 }
210         }
211         LeaveCriticalSection(&g_OpenSSLLock);
212         return r;
213 }
214
215 BOOL DetachSSL(SOCKET s)
216 {
217         BOOL r;
218         SSL** ppSSL;
219         if(!g_bOpenSSLLoaded)
220                 return FALSE;
221         r = FALSE;
222         EnterCriticalSection(&g_OpenSSLLock);
223         if(ppSSL = FindSSLPointerFromSocket(s))
224         {
225                 pSSL_shutdown(*ppSSL);
226                 pSSL_free(*ppSSL);
227                 *ppSSL = NULL;
228                 r = TRUE;
229         }
230         LeaveCriticalSection(&g_OpenSSLLock);
231         return r;
232 }
233
234 BOOL IsSSLAttached(SOCKET s)
235 {
236         SSL** ppSSL;
237         if(!g_bOpenSSLLoaded)
238                 return FALSE;
239         EnterCriticalSection(&g_OpenSSLLock);
240         ppSSL = FindSSLPointerFromSocket(s);
241         LeaveCriticalSection(&g_OpenSSLLock);
242         if(!ppSSL)
243                 return TRUE;
244         return TRUE;
245 }
246
247 SOCKET socketS(int af, int type, int protocol)
248 {
249         return socket(af, type, protocol);
250 }
251
252 int bindS(SOCKET s, const struct sockaddr *addr, int namelen)
253 {
254         return bind(s, addr, namelen);
255 }
256
257 int listenS(SOCKET s, int backlog)
258 {
259         return listen(s, backlog);
260 }
261
262 SOCKET acceptS(SOCKET s, struct sockaddr *addr, int *addrlen)
263 {
264         SOCKET r;
265         r = accept(s, addr, addrlen);
266         if(!AttachSSL(r))
267         {
268                 closesocket(r);
269                 return INVALID_SOCKET;
270         }
271         return r;
272 }
273
274 int connectS(SOCKET s, const struct sockaddr *name, int namelen)
275 {
276         int r;
277         r = connect(s, name, namelen);
278         if(!AttachSSL(r))
279                 return SOCKET_ERROR;
280         return r;
281 }
282
283 int closesocketS(SOCKET s)
284 {
285         DetachSSL(s);
286         return closesocket(s);
287 }
288
289 int sendS(SOCKET s, const char * buf, int len, int flags)
290 {
291         SSL** ppSSL;
292         if(!g_bOpenSSLLoaded)
293                 return send(s, buf, len, flags);
294         EnterCriticalSection(&g_OpenSSLLock);
295         ppSSL = FindSSLPointerFromSocket(s);
296         LeaveCriticalSection(&g_OpenSSLLock);
297         if(!ppSSL)
298                 return send(s, buf, len, flags);
299         return pSSL_write(*ppSSL, buf, len);
300 }
301
302 int recvS(SOCKET s, char * buf, int len, int flags)
303 {
304         SSL** ppSSL;
305         if(!g_bOpenSSLLoaded)
306                 return recv(s, buf, len, flags);
307         EnterCriticalSection(&g_OpenSSLLock);
308         ppSSL = FindSSLPointerFromSocket(s);
309         LeaveCriticalSection(&g_OpenSSLLock);
310         if(!ppSSL)
311                 return recv(s, buf, len, flags);
312         if(flags & MSG_PEEK)
313                 return pSSL_peek(*ppSSL, buf, len);
314         return pSSL_read(*ppSSL, buf, len);
315 }
316