OSDN Git Service

Add support for process protection (prevents from loading untrustworthy DLLs with...
[ffftp/ffftp.git] / socketwrapper.c
1 // socketwrapper.c
2 // Copyright (C) 2011 Suguru Kawamoto
3 // ソケットラッパー
4 // socket関連関数をOpenSSL用に置換
5 // コンパイルにはOpenSSLのヘッダーファイルが必要
6 // 実行にはOpenSSLのDLLが必要
7
8 #define _WIN32_WINNT 0x0600
9
10 #include <windows.h>
11 #include <mmsystem.h>
12 #include <openssl/ssl.h>
13
14 #include "socketwrapper.h"
15 #include "protectprocess.h"
16
17 typedef void (__cdecl* _SSL_load_error_strings)();
18 typedef int (__cdecl* _SSL_library_init)();
19 typedef SSL_METHOD* (__cdecl* _SSLv23_method)();
20 typedef SSL_CTX* (__cdecl* _SSL_CTX_new)(SSL_METHOD*);
21 typedef void (__cdecl* _SSL_CTX_free)(SSL_CTX*);
22 typedef SSL* (__cdecl* _SSL_new)(SSL_CTX*);
23 typedef void (__cdecl* _SSL_free)(SSL*);
24 typedef int (__cdecl* _SSL_shutdown)(SSL*);
25 typedef int (__cdecl* _SSL_get_fd)(SSL*);
26 typedef int (__cdecl* _SSL_set_fd)(SSL*, int);
27 typedef int (__cdecl* _SSL_accept)(SSL*);
28 typedef int (__cdecl* _SSL_connect)(SSL*);
29 typedef int (__cdecl* _SSL_write)(SSL*, const void*, int);
30 typedef int (__cdecl* _SSL_peek)(SSL*, void*, int);
31 typedef int (__cdecl* _SSL_read)(SSL*, void*, int);
32 typedef int (__cdecl* _SSL_get_error)(SSL*, int);
33
34 _SSL_load_error_strings pSSL_load_error_strings;
35 _SSL_library_init pSSL_library_init;
36 _SSLv23_method pSSLv23_method;
37 _SSL_CTX_new pSSL_CTX_new;
38 _SSL_CTX_free pSSL_CTX_free;
39 _SSL_new pSSL_new;
40 _SSL_free pSSL_free;
41 _SSL_shutdown pSSL_shutdown;
42 _SSL_get_fd pSSL_get_fd;
43 _SSL_set_fd pSSL_set_fd;
44 _SSL_accept pSSL_accept;
45 _SSL_connect pSSL_connect;
46 _SSL_write pSSL_write;
47 _SSL_peek pSSL_peek;
48 _SSL_read pSSL_read;
49 _SSL_get_error pSSL_get_error;
50
51 #define MAX_SSL_SOCKET 64
52
53 BOOL g_bOpenSSLLoaded;
54 HMODULE g_hOpenSSL;
55 CRITICAL_SECTION g_OpenSSLLock;
56 DWORD g_OpenSSLTimeout;
57 LPSSLTIMEOUTCALLBACK g_pOpenSSLTimeoutCallback;
58 SSL_CTX* g_pOpenSSLCTX;
59 SSL* g_pOpenSSLHandle[MAX_SSL_SOCKET];
60
61 BOOL __stdcall DefaultSSLTimeoutCallback()
62 {
63         Sleep(100);
64         return FALSE;
65 }
66
67 BOOL LoadOpenSSL()
68 {
69         if(g_bOpenSSLLoaded)
70                 return FALSE;
71 #ifdef ENABLE_PROCESS_PROTECTION
72         // ssleay32.dll 1.0.0e
73         // libssl32.dll 1.0.0e
74         RegisterModuleMD5Hash("\x8B\xA3\xB7\xB3\xCE\x2E\x4F\x07\x8C\xB8\x93\x7D\x77\xE1\x09\x3A");
75         // libeay32.dll 1.0.0e
76         RegisterModuleMD5Hash("\xA6\x4C\xAF\x9E\xF3\xDC\xFC\x68\xAE\xCA\xCC\x61\xD2\xF6\x70\x8B");
77 #endif
78         g_hOpenSSL = LoadLibrary("ssleay32.dll");
79         if(!g_hOpenSSL)
80                 g_hOpenSSL = LoadLibrary("libssl32.dll");
81         if(!g_hOpenSSL
82                 || !(pSSL_load_error_strings = (_SSL_load_error_strings)GetProcAddress(g_hOpenSSL, "SSL_load_error_strings"))
83                 || !(pSSL_library_init = (_SSL_library_init)GetProcAddress(g_hOpenSSL, "SSL_library_init"))
84                 || !(pSSLv23_method = (_SSLv23_method)GetProcAddress(g_hOpenSSL, "SSLv23_method"))
85                 || !(pSSL_CTX_new = (_SSL_CTX_new)GetProcAddress(g_hOpenSSL, "SSL_CTX_new"))
86                 || !(pSSL_CTX_free = (_SSL_CTX_free)GetProcAddress(g_hOpenSSL, "SSL_CTX_free"))
87                 || !(pSSL_new = (_SSL_new)GetProcAddress(g_hOpenSSL, "SSL_new"))
88                 || !(pSSL_free = (_SSL_free)GetProcAddress(g_hOpenSSL, "SSL_free"))
89                 || !(pSSL_shutdown = (_SSL_shutdown)GetProcAddress(g_hOpenSSL, "SSL_shutdown"))
90                 || !(pSSL_get_fd = (_SSL_get_fd)GetProcAddress(g_hOpenSSL, "SSL_get_fd"))
91                 || !(pSSL_set_fd = (_SSL_set_fd)GetProcAddress(g_hOpenSSL, "SSL_set_fd"))
92                 || !(pSSL_accept = (_SSL_accept)GetProcAddress(g_hOpenSSL, "SSL_accept"))
93                 || !(pSSL_connect = (_SSL_connect)GetProcAddress(g_hOpenSSL, "SSL_connect"))
94                 || !(pSSL_write = (_SSL_write)GetProcAddress(g_hOpenSSL, "SSL_write"))
95                 || !(pSSL_peek = (_SSL_peek)GetProcAddress(g_hOpenSSL, "SSL_peek"))
96                 || !(pSSL_read = (_SSL_read)GetProcAddress(g_hOpenSSL, "SSL_read"))
97                 || !(pSSL_get_error = (_SSL_get_error)GetProcAddress(g_hOpenSSL, "SSL_get_error")))
98         {
99                 if(g_hOpenSSL)
100                         FreeLibrary(g_hOpenSSL);
101                 g_hOpenSSL = NULL;
102                 return FALSE;
103         }
104         InitializeCriticalSection(&g_OpenSSLLock);
105         pSSL_load_error_strings();
106         pSSL_library_init();
107         SetSSLTimeoutCallback(60000, DefaultSSLTimeoutCallback);
108         g_bOpenSSLLoaded = TRUE;
109         return TRUE;
110 }
111
112 void FreeOpenSSL()
113 {
114         int i;
115         if(!g_bOpenSSLLoaded)
116                 return;
117         EnterCriticalSection(&g_OpenSSLLock);
118         for(i = 0; i < MAX_SSL_SOCKET; i++)
119         {
120                 if(g_pOpenSSLHandle[i])
121                 {
122                         pSSL_shutdown(g_pOpenSSLHandle[i]);
123                         pSSL_free(g_pOpenSSLHandle[i]);
124                         g_pOpenSSLHandle[i] = NULL;
125                 }
126         }
127         if(g_pOpenSSLCTX)
128                 pSSL_CTX_free(g_pOpenSSLCTX);
129         g_pOpenSSLCTX = NULL;
130         FreeLibrary(g_hOpenSSL);
131         g_hOpenSSL = NULL;
132         LeaveCriticalSection(&g_OpenSSLLock);
133         DeleteCriticalSection(&g_OpenSSLLock);
134         g_bOpenSSLLoaded = FALSE;
135 }
136
137 BOOL IsOpenSSLLoaded()
138 {
139         return g_bOpenSSLLoaded;
140 }
141
142 SSL** GetUnusedSSLPointer()
143 {
144         int i;
145         for(i = 0; i < MAX_SSL_SOCKET; i++)
146         {
147                 if(!g_pOpenSSLHandle[i])
148                         return &g_pOpenSSLHandle[i];
149         }
150         return NULL;
151 }
152
153 SSL** FindSSLPointerFromSocket(SOCKET s)
154 {
155         int i;
156         for(i = 0; i < MAX_SSL_SOCKET; i++)
157         {
158                 if(g_pOpenSSLHandle[i])
159                 {
160                         if(pSSL_get_fd(g_pOpenSSLHandle[i]) == s)
161                                 return &g_pOpenSSLHandle[i];
162                 }
163         }
164         return NULL;
165 }
166
167 void SetSSLTimeoutCallback(DWORD Timeout, LPSSLTIMEOUTCALLBACK pCallback)
168 {
169         if(!g_bOpenSSLLoaded)
170                 return;
171         EnterCriticalSection(&g_OpenSSLLock);
172         g_OpenSSLTimeout = Timeout;
173         g_pOpenSSLTimeoutCallback = pCallback;
174         LeaveCriticalSection(&g_OpenSSLLock);
175 }
176
177 BOOL AttachSSL(SOCKET s)
178 {
179         BOOL r;
180         DWORD Time;
181         SSL** ppSSL;
182         if(!g_bOpenSSLLoaded)
183                 return FALSE;
184         r = FALSE;
185         Time = timeGetTime();
186         EnterCriticalSection(&g_OpenSSLLock);
187         if(!g_pOpenSSLCTX)
188                 g_pOpenSSLCTX = pSSL_CTX_new(pSSLv23_method());
189         if(g_pOpenSSLCTX)
190         {
191                 if(ppSSL = GetUnusedSSLPointer())
192                 {
193                         if(*ppSSL = pSSL_new(g_pOpenSSLCTX))
194                         {
195                                 if(pSSL_set_fd(*ppSSL, s) != 0)
196                                 {
197                                         r = TRUE;
198                                         // SSLのネゴシエーションには時間がかかる場合がある
199                                         while(pSSL_connect(*ppSSL) != 1)
200                                         {
201                                                 LeaveCriticalSection(&g_OpenSSLLock);
202                                                 if(g_pOpenSSLTimeoutCallback() || timeGetTime() - Time >= g_OpenSSLTimeout)
203                                                 {
204                                                         DetachSSL(s);
205                                                         r = FALSE;
206                                                         EnterCriticalSection(&g_OpenSSLLock);
207                                                         break;
208                                                 }
209                                                 EnterCriticalSection(&g_OpenSSLLock);
210                                         }
211                                 }
212                                 else
213                                 {
214                                         LeaveCriticalSection(&g_OpenSSLLock);
215                                         DetachSSL(s);
216                                         EnterCriticalSection(&g_OpenSSLLock);
217                                 }
218                         }
219                 }
220         }
221         LeaveCriticalSection(&g_OpenSSLLock);
222         return r;
223 }
224
225 BOOL DetachSSL(SOCKET s)
226 {
227         BOOL r;
228         SSL** ppSSL;
229         if(!g_bOpenSSLLoaded)
230                 return FALSE;
231         r = FALSE;
232         EnterCriticalSection(&g_OpenSSLLock);
233         if(ppSSL = FindSSLPointerFromSocket(s))
234         {
235                 pSSL_shutdown(*ppSSL);
236                 pSSL_free(*ppSSL);
237                 *ppSSL = NULL;
238                 r = TRUE;
239         }
240         LeaveCriticalSection(&g_OpenSSLLock);
241         return r;
242 }
243
244 BOOL IsSSLAttached(SOCKET s)
245 {
246         SSL** ppSSL;
247         if(!g_bOpenSSLLoaded)
248                 return FALSE;
249         EnterCriticalSection(&g_OpenSSLLock);
250         ppSSL = FindSSLPointerFromSocket(s);
251         LeaveCriticalSection(&g_OpenSSLLock);
252         if(!ppSSL)
253                 return TRUE;
254         return TRUE;
255 }
256
257 SOCKET socketS(int af, int type, int protocol)
258 {
259         return socket(af, type, protocol);
260 }
261
262 int bindS(SOCKET s, const struct sockaddr *addr, int namelen)
263 {
264         return bind(s, addr, namelen);
265 }
266
267 int listenS(SOCKET s, int backlog)
268 {
269         return listen(s, backlog);
270 }
271
272 SOCKET acceptS(SOCKET s, struct sockaddr *addr, int *addrlen)
273 {
274         SOCKET r;
275         r = accept(s, addr, addrlen);
276         if(!AttachSSL(r))
277         {
278                 closesocket(r);
279                 return INVALID_SOCKET;
280         }
281         return r;
282 }
283
284 int connectS(SOCKET s, const struct sockaddr *name, int namelen)
285 {
286         int r;
287         r = connect(s, name, namelen);
288         if(!AttachSSL(r))
289                 return SOCKET_ERROR;
290         return r;
291 }
292
293 int closesocketS(SOCKET s)
294 {
295         DetachSSL(s);
296         return closesocket(s);
297 }
298
299 int sendS(SOCKET s, const char * buf, int len, int flags)
300 {
301         SSL** ppSSL;
302         if(!g_bOpenSSLLoaded)
303                 return send(s, buf, len, flags);
304         EnterCriticalSection(&g_OpenSSLLock);
305         ppSSL = FindSSLPointerFromSocket(s);
306         LeaveCriticalSection(&g_OpenSSLLock);
307         if(!ppSSL)
308                 return send(s, buf, len, flags);
309         return pSSL_write(*ppSSL, buf, len);
310 }
311
312 int recvS(SOCKET s, char * buf, int len, int flags)
313 {
314         SSL** ppSSL;
315         if(!g_bOpenSSLLoaded)
316                 return recv(s, buf, len, flags);
317         EnterCriticalSection(&g_OpenSSLLock);
318         ppSSL = FindSSLPointerFromSocket(s);
319         LeaveCriticalSection(&g_OpenSSLLock);
320         if(!ppSSL)
321                 return recv(s, buf, len, flags);
322         if(flags & MSG_PEEK)
323                 return pSSL_peek(*ppSSL, buf, len);
324         return pSSL_read(*ppSSL, buf, len);
325 }
326