]> git.stg.codes - stg.git/blob - stglibs/srvconf.lib/netunit.cpp
Added stream encryption/decryption tests.
[stg.git] / stglibs / srvconf.lib / netunit.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Boris Mikhailenko <stg34@stargazer.dp.ua>
19  */
20
21 #include "netunit.h"
22
23 #include "stg/servconf_types.h"
24 #include "stg/common.h"
25 #include "stg/blowfish.h"
26
27 #include <algorithm> // std::min
28
29 #include <cstdio>
30 #include <cerrno>
31 #include <cstring>
32
33 #include <netdb.h>
34 #include <arpa/inet.h>
35 #include <unistd.h>
36
37 #include <sys/types.h>
38 #include <sys/socket.h>
39 #include <netinet/in.h>
40
41 using namespace STG;
42
43 namespace
44 {
45
46 const std::string::size_type MAX_XML_CHUNK_LENGTH = 2048;
47
48 }
49
50 //---------------------------------------------------------------------------
51
52 #define SEND_DATA_ERROR             "Send data error!"
53 #define RECV_DATA_ANSWER_ERROR      "Recv data answer error!"
54 #define UNKNOWN_ERROR               "Unknown error!"
55 #define CONNECT_FAILED              "Connect failed!"
56 #define BIND_FAILED                 "Bind failed!"
57 #define INCORRECT_LOGIN             "Incorrect login!"
58 #define INCORRECT_HEADER            "Incorrect header!"
59 #define SEND_LOGIN_ERROR            "Send login error!"
60 #define RECV_LOGIN_ANSWER_ERROR     "Recv login answer error!"
61 #define CREATE_SOCKET_ERROR         "Create socket failed!"
62 #define WSASTARTUP_FAILED           "WSAStartup failed!"
63 #define SEND_HEADER_ERROR           "Send header error!"
64 #define RECV_HEADER_ANSWER_ERROR    "Recv header answer error!"
65
66 //---------------------------------------------------------------------------
67 NETTRANSACT::NETTRANSACT(const std::string & s, uint16_t p,
68                          const std::string & l, const std::string & pwd)
69     : server(s),
70       port(p),
71       localPort(0),
72       login(l),
73       password(pwd),
74       outerSocket(-1)
75 {
76 }
77 //---------------------------------------------------------------------------
78 NETTRANSACT::NETTRANSACT(const std::string & s, uint16_t p,
79                          const std::string & la, uint16_t lp,
80                          const std::string & l, const std::string & pwd)
81     : server(s),
82       port(p),
83       localAddress(la),
84       localPort(lp),
85       login(l),
86       password(pwd),
87       outerSocket(-1)
88 {
89 }
90 //---------------------------------------------------------------------------
91 int NETTRANSACT::Connect()
92 {
93 outerSocket = socket(PF_INET, SOCK_STREAM, 0);
94 if (outerSocket < 0)
95     {
96     errorMsg = CREATE_SOCKET_ERROR;
97     return st_conn_fail;
98     }
99
100 if (!localAddress.empty())
101     {
102     if (localPort == 0)
103         localPort = port;
104
105     unsigned long ip = inet_addr(localAddress.c_str());
106
107     if (ip == INADDR_NONE)
108         {
109         struct hostent * phe = gethostbyname(localAddress.c_str());
110         if (phe == NULL)
111             {
112             errorMsg = "DNS error.\nCan not reslove " + localAddress;
113             return st_dns_err;
114             }
115
116         struct hostent he;
117         memcpy(&he, phe, sizeof(he));
118         ip = *((long *)he.h_addr_list[0]);
119         }
120
121     struct sockaddr_in localAddr;
122     memset(&localAddr, 0, sizeof(localAddr));
123     localAddr.sin_family = AF_INET;
124     localAddr.sin_port = htons(localPort);
125     localAddr.sin_addr.s_addr = ip;
126
127     if (bind(outerSocket, (struct sockaddr *)&localAddr, sizeof(localAddr)) < 0)
128         {
129         errorMsg = BIND_FAILED;
130         close(outerSocket);
131         return st_conn_fail;
132         }
133     }
134
135 struct sockaddr_in outerAddr;
136 memset(&outerAddr, 0, sizeof(outerAddr));
137
138 unsigned long ip = inet_addr(server.c_str());
139
140 if (ip == INADDR_NONE)
141     {
142     struct hostent * phe = gethostbyname(server.c_str());
143     if (phe == NULL)
144         {
145         errorMsg = "DNS error.\nCan not reslove " + server;
146         return st_dns_err;
147         }
148
149     struct hostent he;
150     memcpy(&he, phe, sizeof(he));
151     ip = *((long *)he.h_addr_list[0]);
152     }
153
154 outerAddr.sin_family = AF_INET;
155 outerAddr.sin_port = htons(port);
156 outerAddr.sin_addr.s_addr = ip;
157
158 if (connect(outerSocket, (struct sockaddr *)&outerAddr, sizeof(outerAddr)) < 0)
159     {
160     errorMsg = CONNECT_FAILED;
161     close(outerSocket);
162     return st_conn_fail;
163     }
164
165 return st_ok;
166 }
167 //---------------------------------------------------------------------------
168 void NETTRANSACT::Disconnect()
169 {
170 close(outerSocket);
171 }
172 //---------------------------------------------------------------------------
173 int NETTRANSACT::Transact(const std::string & request, CALLBACK callback, void * data)
174 {
175 int ret;
176 if ((ret = TxHeader()) != st_ok)
177     {
178     Disconnect();
179     return ret;
180     }
181
182 if ((ret = RxHeaderAnswer()) != st_ok)
183     {
184     Disconnect();
185     return ret;
186     }
187
188 if ((ret = TxLogin()) != st_ok)
189     {
190     Disconnect();
191     return ret;
192     }
193
194 if ((ret = RxLoginAnswer()) != st_ok)
195     {
196     Disconnect();
197     return ret;
198     }
199
200 if ((ret = TxLoginS()) != st_ok)
201     {
202     Disconnect();
203     return ret;
204     }
205
206 if ((ret = RxLoginSAnswer()) != st_ok)
207     {
208     Disconnect();
209     return ret;
210     }
211
212 if ((ret = TxData(request)) != st_ok)
213     {
214     Disconnect();
215     return ret;
216     }
217
218 if ((ret = RxDataAnswer(callback, data)) != st_ok)
219     {
220     Disconnect();
221     return ret;
222     }
223
224 return st_ok;
225 }
226 //---------------------------------------------------------------------------
227 int NETTRANSACT::TxHeader()
228 {
229 if (send(outerSocket, STG_HEADER, strlen(STG_HEADER), 0) <= 0)
230     {
231     errorMsg = SEND_HEADER_ERROR;
232     return st_send_fail;
233     }
234
235 return st_ok;
236 }
237 //---------------------------------------------------------------------------
238 int NETTRANSACT::RxHeaderAnswer()
239 {
240 char buffer[sizeof(STG_HEADER) + 1];
241
242 if (recv(outerSocket, buffer, strlen(OK_HEADER), 0) <= 0)
243     {
244     printf("Receive header answer error: '%s'\n", strerror(errno));
245     errorMsg = RECV_HEADER_ANSWER_ERROR;
246     return st_recv_fail;
247     }
248
249 if (strncmp(OK_HEADER, buffer, strlen(OK_HEADER)) == 0)
250     {
251     return st_ok;
252     }
253 else
254     {
255     if (strncmp(ERR_HEADER, buffer, strlen(ERR_HEADER)) == 0)
256         {
257         errorMsg = INCORRECT_HEADER;
258         return st_header_err;
259         }
260     else
261         {
262         errorMsg = UNKNOWN_ERROR;
263         return st_unknown_err;
264         }
265     }
266 }
267 //---------------------------------------------------------------------------
268 int NETTRANSACT::TxLogin()
269 {
270 char loginZ[ADM_LOGIN_LEN];
271 memset(loginZ, 0, ADM_LOGIN_LEN);
272 strncpy(loginZ, login.c_str(), ADM_LOGIN_LEN);
273
274 if (send(outerSocket, loginZ, ADM_LOGIN_LEN, 0) <= 0)
275     {
276     errorMsg = SEND_LOGIN_ERROR;
277     return st_send_fail;
278     }
279
280 return st_ok;
281 }
282 //---------------------------------------------------------------------------
283 int NETTRANSACT::RxLoginAnswer()
284 {
285 char buffer[sizeof(OK_LOGIN) + 1];
286
287 if (recv(outerSocket, buffer, strlen(OK_LOGIN), 0) <= 0)
288     {
289     printf("Receive login answer error: '%s'\n", strerror(errno));
290     errorMsg = RECV_LOGIN_ANSWER_ERROR;
291     return st_recv_fail;
292     }
293
294 if (strncmp(OK_LOGIN, buffer, strlen(OK_LOGIN)) == 0)
295     {
296     return st_ok;
297     }
298 else
299     {
300     if (strncmp(ERR_LOGIN, buffer, strlen(ERR_LOGIN)) == 0)
301         {
302         errorMsg = INCORRECT_LOGIN;
303         return st_login_err;
304         }
305     else
306         {
307         errorMsg = UNKNOWN_ERROR;
308         return st_unknown_err;
309         }
310     }
311 }
312 //---------------------------------------------------------------------------
313 int NETTRANSACT::TxLoginS()
314 {
315 char loginZ[ADM_LOGIN_LEN];
316 memset(loginZ, 0, ADM_LOGIN_LEN);
317 BLOWFISH_CTX ctx;
318 InitContext(password.c_str(), PASSWD_LEN, &ctx);
319 EncryptString(loginZ, login.c_str(), std::min(login.length(), ADM_LOGIN_LEN), &ctx);
320 if (send(outerSocket, loginZ, ADM_LOGIN_LEN, 0) <= 0)
321     {
322     errorMsg = SEND_LOGIN_ERROR;
323     return st_send_fail;
324     }
325 return st_ok;
326 }
327 //---------------------------------------------------------------------------
328 int NETTRANSACT::RxLoginSAnswer()
329 {
330 char buffer[sizeof(OK_LOGINS) + 1];
331
332 if (recv(outerSocket, buffer, strlen(OK_LOGINS), 0) <= 0)
333     {
334     printf("Receive secret login answer error: '%s'\n", strerror(errno));
335     errorMsg = RECV_LOGIN_ANSWER_ERROR;
336     return st_recv_fail;
337     }
338
339 if (strncmp(OK_LOGINS, buffer, strlen(OK_LOGINS)) == 0)
340     {
341     return st_ok;
342     }
343 else
344     {
345     if (strncmp(ERR_LOGINS, buffer, strlen(ERR_LOGINS)) == 0)
346         {
347         errorMsg = INCORRECT_LOGIN;
348         return st_logins_err;
349         }
350     else
351         {
352         errorMsg = UNKNOWN_ERROR;
353         return st_unknown_err;
354         }
355     }
356 }
357 //---------------------------------------------------------------------------
358 int NETTRANSACT::TxData(const std::string & text)
359 {
360 BLOWFISH_CTX ctx;
361 InitContext(password.c_str(), PASSWD_LEN, &ctx);
362 char buffer[text.length()];
363 EncryptString(buffer, text.c_str(), text.length(), &ctx);
364 if (send(outerSocket, buffer, text.length(), 0) <= 0)
365     {
366     errorMsg = SEND_DATA_ERROR;
367     return st_send_fail;
368     }
369 return st_ok;
370 }
371 //---------------------------------------------------------------------------
372 int NETTRANSACT::RxDataAnswer(CALLBACK callback, void * data)
373 {
374 BLOWFISH_CTX ctx;
375 InitContext(password.c_str(), PASSWD_LEN, &ctx);
376
377 std::string chunk;
378 while (true)
379     {
380     char bufferS[ENC_MSG_LEN];
381     size_t toRead = ENC_MSG_LEN;
382     while (toRead > 0)
383         {
384         int ret = recv(outerSocket, &bufferS[ENC_MSG_LEN - toRead], toRead, 0);
385         if (ret <= 0)
386             {
387             printf("Receive data error: '%s'\n", strerror(errno));
388             close(outerSocket);
389             errorMsg = RECV_DATA_ANSWER_ERROR;
390             return st_recv_fail;
391             }
392         toRead -= ret;
393         }
394
395     char buffer[ENC_MSG_LEN];
396     DecryptBlock(buffer, bufferS, &ctx);
397
398     bool final = false;
399     size_t pos = 0;
400     for (; pos < ENC_MSG_LEN && buffer[pos] != 0; pos++) ;
401     if (pos < ENC_MSG_LEN && buffer[pos] == 0)
402         final = true;
403
404     if (pos > 0)
405         chunk.append(&buffer[0], &buffer[pos]);
406
407     if (chunk.length() > MAX_XML_CHUNK_LENGTH || final)
408         {
409         if (callback)
410             if (!callback(chunk, final, data))
411                 return st_xml_parse_error;
412         chunk.clear();
413         }
414
415     if (final)
416         return st_ok;
417     }
418 }