OSDN Git Service

Fix line feed codes.
[ffftp/ffftp.git] / socketwrapper.c
// socketwrapper.c
// Copyright (C) 2011 Suguru Kawamoto
// Socket wrapper
// Replaces socket-related functions with OpenSSL-enabled versions
// The OpenSSL header files are required to compile this file
// The OpenSSL DLLs are required at run time
7 \r
8 #include <windows.h>\r
9 #include <mmsystem.h>\r
10 #include <openssl/ssl.h>\r
11 \r
12 #include "socketwrapper.h"\r
13 #include "protectprocess.h"\r
14 \r
15 typedef void (__cdecl* _SSL_load_error_strings)();\r
16 typedef int (__cdecl* _SSL_library_init)();\r
17 typedef SSL_METHOD* (__cdecl* _SSLv23_method)();\r
18 typedef SSL_CTX* (__cdecl* _SSL_CTX_new)(SSL_METHOD*);\r
19 typedef void (__cdecl* _SSL_CTX_free)(SSL_CTX*);\r
20 typedef SSL* (__cdecl* _SSL_new)(SSL_CTX*);\r
21 typedef void (__cdecl* _SSL_free)(SSL*);\r
22 typedef int (__cdecl* _SSL_shutdown)(SSL*);\r
23 typedef int (__cdecl* _SSL_get_fd)(SSL*);\r
24 typedef int (__cdecl* _SSL_set_fd)(SSL*, int);\r
25 typedef int (__cdecl* _SSL_accept)(SSL*);\r
26 typedef int (__cdecl* _SSL_connect)(SSL*);\r
27 typedef int (__cdecl* _SSL_write)(SSL*, const void*, int);\r
28 typedef int (__cdecl* _SSL_peek)(SSL*, void*, int);\r
29 typedef int (__cdecl* _SSL_read)(SSL*, void*, int);\r
30 typedef int (__cdecl* _SSL_get_error)(SSL*, int);\r
31 \r
32 _SSL_load_error_strings pSSL_load_error_strings;\r
33 _SSL_library_init pSSL_library_init;\r
34 _SSLv23_method pSSLv23_method;\r
35 _SSL_CTX_new pSSL_CTX_new;\r
36 _SSL_CTX_free pSSL_CTX_free;\r
37 _SSL_new pSSL_new;\r
38 _SSL_free pSSL_free;\r
39 _SSL_shutdown pSSL_shutdown;\r
40 _SSL_get_fd pSSL_get_fd;\r
41 _SSL_set_fd pSSL_set_fd;\r
42 _SSL_accept pSSL_accept;\r
43 _SSL_connect pSSL_connect;\r
44 _SSL_write pSSL_write;\r
45 _SSL_peek pSSL_peek;\r
46 _SSL_read pSSL_read;\r
47 _SSL_get_error pSSL_get_error;\r
48 \r
49 #define MAX_SSL_SOCKET 64\r
50 \r
51 BOOL g_bOpenSSLLoaded;\r
52 HMODULE g_hOpenSSL;\r
53 CRITICAL_SECTION g_OpenSSLLock;\r
54 DWORD g_OpenSSLTimeout;\r
55 LPSSLTIMEOUTCALLBACK g_pOpenSSLTimeoutCallback;\r
56 SSL_CTX* g_pOpenSSLCTX;\r
57 SSL* g_pOpenSSLHandle[MAX_SSL_SOCKET];\r
58 \r
59 BOOL __stdcall DefaultSSLTimeoutCallback()\r
60 {\r
61         Sleep(100);\r
62         return FALSE;\r
63 }\r
64 \r
65 BOOL LoadOpenSSL()\r
66 {\r
67         if(g_bOpenSSLLoaded)\r
68                 return FALSE;\r
69 #ifdef ENABLE_PROCESS_PROTECTION\r
70         // 同梱するOpenSSLのバージョンに合わせてSHA1ハッシュ値を変更すること\r
71         // ssleay32.dll 1.0.0e\r
72         // libssl32.dll 1.0.0e\r
73         RegisterTrustedModuleSHA1Hash("\x4E\xB7\xA0\x22\x14\x4B\x58\x6D\xBC\xF5\x21\x0D\x96\x78\x0D\x79\x7D\x66\xB2\xB0");\r
74         // libeay32.dll 1.0.0e\r
75         RegisterTrustedModuleSHA1Hash("\x01\x32\x7A\xAE\x69\x26\xE6\x58\xC7\x63\x22\x1E\x53\x5A\x78\xBC\x61\xC7\xB5\xC1");\r
76 #endif\r
77         g_hOpenSSL = LoadLibrary("ssleay32.dll");\r
78         if(!g_hOpenSSL)\r
79                 g_hOpenSSL = LoadLibrary("libssl32.dll");\r
80         if(!g_hOpenSSL\r
81                 || !(pSSL_load_error_strings = (_SSL_load_error_strings)GetProcAddress(g_hOpenSSL, "SSL_load_error_strings"))\r
82                 || !(pSSL_library_init = (_SSL_library_init)GetProcAddress(g_hOpenSSL, "SSL_library_init"))\r
83                 || !(pSSLv23_method = (_SSLv23_method)GetProcAddress(g_hOpenSSL, "SSLv23_method"))\r
84                 || !(pSSL_CTX_new = (_SSL_CTX_new)GetProcAddress(g_hOpenSSL, "SSL_CTX_new"))\r
85                 || !(pSSL_CTX_free = (_SSL_CTX_free)GetProcAddress(g_hOpenSSL, "SSL_CTX_free"))\r
86                 || !(pSSL_new = (_SSL_new)GetProcAddress(g_hOpenSSL, "SSL_new"))\r
87                 || !(pSSL_free = (_SSL_free)GetProcAddress(g_hOpenSSL, "SSL_free"))\r
88                 || !(pSSL_shutdown = (_SSL_shutdown)GetProcAddress(g_hOpenSSL, "SSL_shutdown"))\r
89                 || !(pSSL_get_fd = (_SSL_get_fd)GetProcAddress(g_hOpenSSL, "SSL_get_fd"))\r
90                 || !(pSSL_set_fd = (_SSL_set_fd)GetProcAddress(g_hOpenSSL, "SSL_set_fd"))\r
91                 || !(pSSL_accept = (_SSL_accept)GetProcAddress(g_hOpenSSL, "SSL_accept"))\r
92                 || !(pSSL_connect = (_SSL_connect)GetProcAddress(g_hOpenSSL, "SSL_connect"))\r
93                 || !(pSSL_write = (_SSL_write)GetProcAddress(g_hOpenSSL, "SSL_write"))\r
94                 || !(pSSL_peek = (_SSL_peek)GetProcAddress(g_hOpenSSL, "SSL_peek"))\r
95                 || !(pSSL_read = (_SSL_read)GetProcAddress(g_hOpenSSL, "SSL_read"))\r
96                 || !(pSSL_get_error = (_SSL_get_error)GetProcAddress(g_hOpenSSL, "SSL_get_error")))\r
97         {\r
98                 if(g_hOpenSSL)\r
99                         FreeLibrary(g_hOpenSSL);\r
100                 g_hOpenSSL = NULL;\r
101                 return FALSE;\r
102         }\r
103         InitializeCriticalSection(&g_OpenSSLLock);\r
104         pSSL_load_error_strings();\r
105         pSSL_library_init();\r
106         SetSSLTimeoutCallback(60000, DefaultSSLTimeoutCallback);\r
107         g_bOpenSSLLoaded = TRUE;\r
108         return TRUE;\r
109 }\r
110 \r
111 void FreeOpenSSL()\r
112 {\r
113         int i;\r
114         if(!g_bOpenSSLLoaded)\r
115                 return;\r
116         EnterCriticalSection(&g_OpenSSLLock);\r
117         for(i = 0; i < MAX_SSL_SOCKET; i++)\r
118         {\r
119                 if(g_pOpenSSLHandle[i])\r
120                 {\r
121                         pSSL_shutdown(g_pOpenSSLHandle[i]);\r
122                         pSSL_free(g_pOpenSSLHandle[i]);\r
123                         g_pOpenSSLHandle[i] = NULL;\r
124                 }\r
125         }\r
126         if(g_pOpenSSLCTX)\r
127                 pSSL_CTX_free(g_pOpenSSLCTX);\r
128         g_pOpenSSLCTX = NULL;\r
129         FreeLibrary(g_hOpenSSL);\r
130         g_hOpenSSL = NULL;\r
131         LeaveCriticalSection(&g_OpenSSLLock);\r
132         DeleteCriticalSection(&g_OpenSSLLock);\r
133         g_bOpenSSLLoaded = FALSE;\r
134 }\r
135 \r
136 BOOL IsOpenSSLLoaded()\r
137 {\r
138         return g_bOpenSSLLoaded;\r
139 }\r
140 \r
141 SSL** GetUnusedSSLPointer()\r
142 {\r
143         int i;\r
144         for(i = 0; i < MAX_SSL_SOCKET; i++)\r
145         {\r
146                 if(!g_pOpenSSLHandle[i])\r
147                         return &g_pOpenSSLHandle[i];\r
148         }\r
149         return NULL;\r
150 }\r
151 \r
152 SSL** FindSSLPointerFromSocket(SOCKET s)\r
153 {\r
154         int i;\r
155         for(i = 0; i < MAX_SSL_SOCKET; i++)\r
156         {\r
157                 if(g_pOpenSSLHandle[i])\r
158                 {\r
159                         if(pSSL_get_fd(g_pOpenSSLHandle[i]) == s)\r
160                                 return &g_pOpenSSLHandle[i];\r
161                 }\r
162         }\r
163         return NULL;\r
164 }\r
165 \r
166 void SetSSLTimeoutCallback(DWORD Timeout, LPSSLTIMEOUTCALLBACK pCallback)\r
167 {\r
168         if(!g_bOpenSSLLoaded)\r
169                 return;\r
170         EnterCriticalSection(&g_OpenSSLLock);\r
171         g_OpenSSLTimeout = Timeout;\r
172         g_pOpenSSLTimeoutCallback = pCallback;\r
173         LeaveCriticalSection(&g_OpenSSLLock);\r
174 }\r
175 \r
176 BOOL AttachSSL(SOCKET s)\r
177 {\r
178         BOOL r;\r
179         DWORD Time;\r
180         SSL** ppSSL;\r
181         if(!g_bOpenSSLLoaded)\r
182                 return FALSE;\r
183         r = FALSE;\r
184         Time = timeGetTime();\r
185         EnterCriticalSection(&g_OpenSSLLock);\r
186         if(!g_pOpenSSLCTX)\r
187                 g_pOpenSSLCTX = pSSL_CTX_new(pSSLv23_method());\r
188         if(g_pOpenSSLCTX)\r
189         {\r
190                 if(ppSSL = GetUnusedSSLPointer())\r
191                 {\r
192                         if(*ppSSL = pSSL_new(g_pOpenSSLCTX))\r
193                         {\r
194                                 if(pSSL_set_fd(*ppSSL, s) != 0)\r
195                                 {\r
196                                         r = TRUE;\r
197                                         // SSLのネゴシエーションには時間がかかる場合がある\r
198                                         while(pSSL_connect(*ppSSL) != 1)\r
199                                         {\r
200                                                 LeaveCriticalSection(&g_OpenSSLLock);\r
201                                                 if(g_pOpenSSLTimeoutCallback() || timeGetTime() - Time >= g_OpenSSLTimeout)\r
202                                                 {\r
203                                                         DetachSSL(s);\r
204                                                         r = FALSE;\r
205                                                         EnterCriticalSection(&g_OpenSSLLock);\r
206                                                         break;\r
207                                                 }\r
208                                                 EnterCriticalSection(&g_OpenSSLLock);\r
209                                         }\r
210                                 }\r
211                                 else\r
212                                 {\r
213                                         LeaveCriticalSection(&g_OpenSSLLock);\r
214                                         DetachSSL(s);\r
215                                         EnterCriticalSection(&g_OpenSSLLock);\r
216                                 }\r
217                         }\r
218                 }\r
219         }\r
220         LeaveCriticalSection(&g_OpenSSLLock);\r
221         return r;\r
222 }\r
223 \r
224 BOOL DetachSSL(SOCKET s)\r
225 {\r
226         BOOL r;\r
227         SSL** ppSSL;\r
228         if(!g_bOpenSSLLoaded)\r
229                 return FALSE;\r
230         r = FALSE;\r
231         EnterCriticalSection(&g_OpenSSLLock);\r
232         if(ppSSL = FindSSLPointerFromSocket(s))\r
233         {\r
234                 pSSL_shutdown(*ppSSL);\r
235                 pSSL_free(*ppSSL);\r
236                 *ppSSL = NULL;\r
237                 r = TRUE;\r
238         }\r
239         LeaveCriticalSection(&g_OpenSSLLock);\r
240         return r;\r
241 }\r
242 \r
243 BOOL IsSSLAttached(SOCKET s)\r
244 {\r
245         SSL** ppSSL;\r
246         if(!g_bOpenSSLLoaded)\r
247                 return FALSE;\r
248         EnterCriticalSection(&g_OpenSSLLock);\r
249         ppSSL = FindSSLPointerFromSocket(s);\r
250         LeaveCriticalSection(&g_OpenSSLLock);\r
251         if(!ppSSL)\r
252                 return TRUE;\r
253         return TRUE;\r
254 }\r
255 \r
256 SOCKET socketS(int af, int type, int protocol)\r
257 {\r
258         return socket(af, type, protocol);\r
259 }\r
260 \r
261 int bindS(SOCKET s, const struct sockaddr *addr, int namelen)\r
262 {\r
263         return bind(s, addr, namelen);\r
264 }\r
265 \r
266 int listenS(SOCKET s, int backlog)\r
267 {\r
268         return listen(s, backlog);\r
269 }\r
270 \r
271 SOCKET acceptS(SOCKET s, struct sockaddr *addr, int *addrlen)\r
272 {\r
273         SOCKET r;\r
274         r = accept(s, addr, addrlen);\r
275         if(!AttachSSL(r))\r
276         {\r
277                 closesocket(r);\r
278                 return INVALID_SOCKET;\r
279         }\r
280         return r;\r
281 }\r
282 \r
283 int connectS(SOCKET s, const struct sockaddr *name, int namelen)\r
284 {\r
285         int r;\r
286         r = connect(s, name, namelen);\r
287         if(!AttachSSL(r))\r
288                 return SOCKET_ERROR;\r
289         return r;\r
290 }\r
291 \r
292 int closesocketS(SOCKET s)\r
293 {\r
294         DetachSSL(s);\r
295         return closesocket(s);\r
296 }\r
297 \r
298 int sendS(SOCKET s, const char * buf, int len, int flags)\r
299 {\r
300         SSL** ppSSL;\r
301         if(!g_bOpenSSLLoaded)\r
302                 return send(s, buf, len, flags);\r
303         EnterCriticalSection(&g_OpenSSLLock);\r
304         ppSSL = FindSSLPointerFromSocket(s);\r
305         LeaveCriticalSection(&g_OpenSSLLock);\r
306         if(!ppSSL)\r
307                 return send(s, buf, len, flags);\r
308         return pSSL_write(*ppSSL, buf, len);\r
309 }\r
310 \r
311 int recvS(SOCKET s, char * buf, int len, int flags)\r
312 {\r
313         SSL** ppSSL;\r
314         if(!g_bOpenSSLLoaded)\r
315                 return recv(s, buf, len, flags);\r
316         EnterCriticalSection(&g_OpenSSLLock);\r
317         ppSSL = FindSSLPointerFromSocket(s);\r
318         LeaveCriticalSection(&g_OpenSSLLock);\r
319         if(!ppSSL)\r
320                 return recv(s, buf, len, flags);\r
321         if(flags & MSG_PEEK)\r
322                 return pSSL_peek(*ppSSL, buf, len);\r
323         return pSSL_read(*ppSSL, buf, len);\r
324 }\r
325 \r