]> git.stg.codes - stg.git/blob - stglibs/srvconf.lib/netunit.cpp
Better error reporting in netunit.cpp.
[stg.git] / stglibs / srvconf.lib / netunit.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Boris Mikhailenko <stg34@stargazer.dp.ua>
19  */
20
21 #include "netunit.h"
22
23 #include "stg/servconf_types.h"
24 #include "stg/common.h"
25 #include "stg/blowfish.h"
26 #include "stg/bfstream.h"
27
28 #include <algorithm> // std::min
29
30 #include <cstdio>
31 #include <cerrno>
32 #include <cstring>
33 #include <cassert>
34
35 #include <netdb.h>
36 #include <arpa/inet.h>
37 #include <unistd.h>
38
39 #include <sys/types.h>
40 #include <sys/socket.h>
41 #include <netinet/in.h>
42
43 using namespace STG;
44
45 namespace
46 {
47
48 struct ReadState
49 {
50     bool final;
51     NETTRANSACT::CALLBACK callback;
52     void * callbackData;
53     NETTRANSACT * nt;
54 };
55
56 }
57
58 //---------------------------------------------------------------------------
59
60 #define SEND_DATA_ERROR             "Send data error!"
61 #define RECV_DATA_ANSWER_ERROR      "Recv data answer error!"
62 #define UNKNOWN_ERROR               "Unknown error!"
63 #define CONNECT_FAILED              "Connect failed!"
64 #define BIND_FAILED                 "Bind failed!"
65 #define INCORRECT_LOGIN             "Incorrect login!"
66 #define INCORRECT_HEADER            "Incorrect header!"
67 #define SEND_LOGIN_ERROR            "Send login error!"
68 #define RECV_LOGIN_ANSWER_ERROR     "Recv login answer error!"
69 #define CREATE_SOCKET_ERROR         "Create socket failed!"
70 #define WSASTARTUP_FAILED           "WSAStartup failed!"
71 #define SEND_HEADER_ERROR           "Send header error!"
72 #define RECV_HEADER_ANSWER_ERROR    "Recv header answer error!"
73
74 //---------------------------------------------------------------------------
75 NETTRANSACT::NETTRANSACT(const std::string & s, uint16_t p,
76                          const std::string & l, const std::string & pwd)
77     : server(s),
78       port(p),
79       localPort(0),
80       login(l),
81       password(pwd),
82       sock(-1)
83 {
84 }
85 //---------------------------------------------------------------------------
86 NETTRANSACT::NETTRANSACT(const std::string & s, uint16_t p,
87                          const std::string & la, uint16_t lp,
88                          const std::string & l, const std::string & pwd)
89     : server(s),
90       port(p),
91       localAddress(la),
92       localPort(lp),
93       login(l),
94       password(pwd),
95       sock(-1)
96 {
97 }
98 //---------------------------------------------------------------------------
99 NETTRANSACT::~NETTRANSACT()
100 {
101 Disconnect();
102 }
103 //---------------------------------------------------------------------------
104 int NETTRANSACT::Connect()
105 {
106 sock = socket(PF_INET, SOCK_STREAM, 0);
107 if (sock < 0)
108     {
109     errorMsg = CREATE_SOCKET_ERROR;
110     return st_conn_fail;
111     }
112
113 if (!localAddress.empty())
114     {
115     if (localPort == 0)
116         localPort = port;
117
118     unsigned long ip = inet_addr(localAddress.c_str());
119
120     if (ip == INADDR_NONE)
121         {
122         struct hostent * phe = gethostbyname(localAddress.c_str());
123         if (phe == NULL)
124             {
125             errorMsg = "Can not reslove '" + localAddress + "'";
126             return st_dns_err;
127             }
128
129         struct hostent he;
130         memcpy(&he, phe, sizeof(he));
131         ip = *((long *)he.h_addr_list[0]);
132         }
133
134     struct sockaddr_in localAddr;
135     memset(&localAddr, 0, sizeof(localAddr));
136     localAddr.sin_family = AF_INET;
137     localAddr.sin_port = htons(localPort);
138     localAddr.sin_addr.s_addr = ip;
139
140     if (bind(sock, (struct sockaddr *)&localAddr, sizeof(localAddr)) < 0)
141         {
142         errorMsg = BIND_FAILED;
143         return st_conn_fail;
144         }
145     }
146
147 struct sockaddr_in outerAddr;
148 memset(&outerAddr, 0, sizeof(outerAddr));
149
150 unsigned long ip = inet_addr(server.c_str());
151
152 if (ip == INADDR_NONE)
153     {
154     struct hostent * phe = gethostbyname(server.c_str());
155     if (phe == NULL)
156         {
157         errorMsg = "Can not reslove '" + server + "'";
158         return st_dns_err;
159         }
160
161     struct hostent he;
162     memcpy(&he, phe, sizeof(he));
163     ip = *((long *)he.h_addr_list[0]);
164     }
165
166 outerAddr.sin_family = AF_INET;
167 outerAddr.sin_port = htons(port);
168 outerAddr.sin_addr.s_addr = ip;
169
170 if (connect(sock, (struct sockaddr *)&outerAddr, sizeof(outerAddr)) < 0)
171     {
172     errorMsg = CONNECT_FAILED;
173     return st_conn_fail;
174     }
175
176 return st_ok;
177 }
178 //---------------------------------------------------------------------------
179 void NETTRANSACT::Disconnect()
180 {
181 if (sock != -1)
182     {
183     shutdown(sock, SHUT_RDWR);
184     close(sock);
185     sock = -1;
186     }
187 }
188 //---------------------------------------------------------------------------
189 int NETTRANSACT::Transact(const std::string & request, CALLBACK callback, void * data)
190 {
191 int ret;
192 if ((ret = TxHeader()) != st_ok)
193     return ret;
194
195 if ((ret = RxHeaderAnswer()) != st_ok)
196     return ret;
197
198 if ((ret = TxLogin()) != st_ok)
199     return ret;
200
201 if ((ret = RxLoginAnswer()) != st_ok)
202     return ret;
203
204 if ((ret = TxLoginS()) != st_ok)
205     return ret;
206
207 if ((ret = RxLoginSAnswer()) != st_ok)
208     return ret;
209
210 if ((ret = TxData(request)) != st_ok)
211     return ret;
212
213 if ((ret = RxDataAnswer(callback, data)) != st_ok)
214     return ret;
215
216 return st_ok;
217 }
218 //---------------------------------------------------------------------------
219 int NETTRANSACT::TxHeader()
220 {
221 if (!WriteAll(sock, STG_HEADER, strlen(STG_HEADER)))
222     {
223     errorMsg = SEND_HEADER_ERROR;
224     return st_send_fail;
225     }
226
227 return st_ok;
228 }
229 //---------------------------------------------------------------------------
230 int NETTRANSACT::RxHeaderAnswer()
231 {
232 char buffer[sizeof(STG_HEADER) + 1];
233
234 if (!ReadAll(sock, buffer, strlen(OK_HEADER)))
235     {
236     printf("Receive header answer error: '%s'\n", strerror(errno));
237     errorMsg = RECV_HEADER_ANSWER_ERROR;
238     return st_recv_fail;
239     }
240
241 if (strncmp(OK_HEADER, buffer, strlen(OK_HEADER)) == 0)
242     return st_ok;
243
244 if (strncmp(ERR_HEADER, buffer, strlen(ERR_HEADER)) == 0)
245     {
246     errorMsg = INCORRECT_HEADER;
247     return st_header_err;
248     }
249 else
250     {
251     errorMsg = UNKNOWN_ERROR;
252     return st_unknown_err;
253     }
254 }
255 //---------------------------------------------------------------------------
256 int NETTRANSACT::TxLogin()
257 {
258 char loginZ[ADM_LOGIN_LEN + 1];
259 memset(loginZ, 0, ADM_LOGIN_LEN + 1);
260 strncpy(loginZ, login.c_str(), ADM_LOGIN_LEN);
261
262 if (!WriteAll(sock, loginZ, ADM_LOGIN_LEN))
263     {
264     errorMsg = SEND_LOGIN_ERROR;
265     return st_send_fail;
266     }
267
268 return st_ok;
269 }
270 //---------------------------------------------------------------------------
271 int NETTRANSACT::RxLoginAnswer()
272 {
273 char buffer[sizeof(OK_LOGIN) + 1];
274
275 if (!ReadAll(sock, buffer, strlen(OK_LOGIN)))
276     {
277     printf("Receive login answer error: '%s'\n", strerror(errno));
278     errorMsg = RECV_LOGIN_ANSWER_ERROR;
279     return st_recv_fail;
280     }
281
282 if (strncmp(OK_LOGIN, buffer, strlen(OK_LOGIN)) == 0)
283     return st_ok;
284
285 if (strncmp(ERR_LOGIN, buffer, strlen(ERR_LOGIN)) == 0)
286     {
287     errorMsg = INCORRECT_LOGIN;
288     return st_login_err;
289     }
290 else
291     {
292     errorMsg = UNKNOWN_ERROR;
293     return st_unknown_err;
294     }
295 }
296 //---------------------------------------------------------------------------
297 int NETTRANSACT::TxLoginS()
298 {
299 char loginZ[ADM_LOGIN_LEN + 1];
300 memset(loginZ, 0, ADM_LOGIN_LEN + 1);
301
302 BLOWFISH_CTX ctx;
303 InitContext(password.c_str(), PASSWD_LEN, &ctx);
304 EncryptString(loginZ, login.c_str(), std::min<size_t>(login.length() + 1, ADM_LOGIN_LEN), &ctx);
305 if (!WriteAll(sock, loginZ, ADM_LOGIN_LEN))
306     {
307     errorMsg = SEND_LOGIN_ERROR;
308     return st_send_fail;
309     }
310
311 return st_ok;
312 }
313 //---------------------------------------------------------------------------
314 int NETTRANSACT::RxLoginSAnswer()
315 {
316 char buffer[sizeof(OK_LOGINS) + 1];
317
318 if (!ReadAll(sock, buffer, strlen(OK_LOGINS)))
319     {
320     printf("Receive secret login answer error: '%s'\n", strerror(errno));
321     errorMsg = RECV_LOGIN_ANSWER_ERROR;
322     return st_recv_fail;
323     }
324
325 if (strncmp(OK_LOGINS, buffer, strlen(OK_LOGINS)) == 0)
326     return st_ok;
327
328 if (strncmp(ERR_LOGINS, buffer, strlen(ERR_LOGINS)) == 0)
329     {
330     errorMsg = INCORRECT_LOGIN;
331     return st_logins_err;
332     }
333 else
334     {
335     errorMsg = UNKNOWN_ERROR;
336     return st_unknown_err;
337     }
338 }
339 //---------------------------------------------------------------------------
340 int NETTRANSACT::TxData(const std::string & text)
341 {
342 STG::ENCRYPT_STREAM stream(password, TxCrypto, this);
343 stream.Put(text.c_str(), text.length() + 1, true);
344 if (!stream.IsOk())
345     {
346     errorMsg = SEND_DATA_ERROR;
347     return st_send_fail;
348     }
349
350 return st_ok;
351 }
352 //---------------------------------------------------------------------------
353 int NETTRANSACT::RxDataAnswer(CALLBACK callback, void * data)
354 {
355 ReadState state = {false, callback, data, this};
356 STG::DECRYPT_STREAM stream(password, RxCrypto, &state);
357 while (!state.final)
358     {
359     char buffer[1024];
360     ssize_t res = read(sock, buffer, sizeof(buffer));
361     if (res < 0)
362         {
363         printf("Receive data error: '%s'\n", strerror(errno));
364         errorMsg = RECV_DATA_ANSWER_ERROR;
365         return st_recv_fail;
366         }
367     stream.Put(buffer, res, res == 0);
368     if (!stream.IsOk())
369         return st_xml_parse_error;
370     }
371
372 return st_ok;
373 }
374 //---------------------------------------------------------------------------
375 bool NETTRANSACT::TxCrypto(const void * block, size_t size, void * data)
376 {
377 assert(data != NULL);
378 NETTRANSACT & nt = *static_cast<NETTRANSACT *>(data);
379 if (!WriteAll(nt.sock, block, size))
380     return false;
381 return true;
382 }
383 //---------------------------------------------------------------------------
384 bool NETTRANSACT::RxCrypto(const void * block, size_t size, void * data)
385 {
386 assert(data != NULL);
387 ReadState & state = *static_cast<ReadState *>(data);
388
389 const char * buffer = static_cast<const char *>(block);
390 for (size_t pos = 0; pos < size; ++pos)
391     if (buffer[pos] == 0)
392         {
393         state.final = true;
394         size = pos; // Adjust string size
395         }
396
397 if (state.callback)
398     if (!state.callback(std::string(buffer, size), state.final, state.callbackData))
399         return false;
400
401 return true;
402 }