]> git.stg.codes - stg.git/blob - projects/sgauthstress/proto.cpp
Fix tests for inclusive fee charge rules
[stg.git] / projects / sgauthstress / proto.cpp
1 #include <netdb.h>
2 #include <arpa/inet.h>
3
4 #include <csignal>
5 #include <cerrno>
6 #include <cstring>
7 #include <cassert>
8 #include <stdexcept>
9 #include <algorithm>
10
11 #include "stg/common.h"
12 #include "stg/ia_packets.h"
13 #include "stg/locker.h"
14
15 #include "proto.h"
16
17 class HasIP : public std::unary_function<std::pair<uint32_t, USER>, bool> {
18     public:
19         explicit HasIP(uint32_t i) : ip(i) {}
20         bool operator()(const std::pair<uint32_t, USER> & value) { return value.first == ip; }
21     private:
22         uint32_t ip;
23 };
24
25 PROTO::PROTO(const std::string & server,
26              uint16_t port,
27              uint16_t localPort,
28              int to)
29     : timeout(to),
30       running(false),
31       stopped(true)
32 {
33 uint32_t ip = inet_addr(server.c_str());
34 if (ip == INADDR_NONE)
35     {
36     struct hostent * hePtr = gethostbyname(server.c_str());
37     if (hePtr)
38         {
39         ip = *((uint32_t *)hePtr->h_addr_list[0]);
40         }
41     else
42         {
43         errorStr = "Unknown host: '";
44         errorStr += server;
45         errorStr += "'";
46         printfd(__FILE__, "PROTO::PROTO() - %s\n", errorStr.c_str());
47         throw std::runtime_error(errorStr);
48         }
49     }
50
51 localAddr.sin_family = AF_INET;
52 localAddr.sin_port = htons(localPort);
53 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
54
55 serverAddr.sin_family = AF_INET;
56 serverAddr.sin_port = htons(port);
57 serverAddr.sin_addr.s_addr = ip;
58
59 unsigned char key[IA_PASSWD_LEN];
60 memset(key, 0, IA_PASSWD_LEN);
61 strncpy(reinterpret_cast<char *>(key), "pr7Hhen", 8);
62 Blowfish_Init(&ctx, key, IA_PASSWD_LEN);
63
64 processors["CONN_SYN_ACK"] = &PROTO::CONN_SYN_ACK_Proc;
65 processors["ALIVE_SYN"] = &PROTO::ALIVE_SYN_Proc;
66 processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
67 processors["FIN"] = &PROTO::FIN_Proc;
68 processors["INFO"] = &PROTO::INFO_Proc;
69 // ERR_Proc will be handled explicitly
70
71 pthread_mutex_init(&mutex, NULL);
72 }
73
74 PROTO::~PROTO()
75 {
76 pthread_mutex_destroy(&mutex);
77 }
78
79 void * PROTO::Runner(void * data)
80 {
81 sigset_t signalSet;
82 sigfillset(&signalSet);
83 pthread_sigmask(SIG_BLOCK, &signalSet, NULL);
84
85 PROTO * protoPtr = static_cast<PROTO *>(data);
86 protoPtr->Run();
87 return NULL;
88 }
89
90 bool PROTO::Start()
91 {
92 stopped = false;
93 running = true;
94 if (pthread_create(&tid, NULL, &Runner, this))
95     {
96     errorStr = "Failed to create listening thread: '";
97     errorStr += strerror(errno);
98     errorStr += "'";
99     printfd(__FILE__, "PROTO::Start() - %s\n", errorStr.c_str());
100     return false;
101     }
102 return true;
103 }
104
105 bool PROTO::Stop()
106 {
107 running = false;
108 int time = 0;
109 while (!stopped && time < timeout)
110     {
111     struct timespec ts = {1, 0};
112     nanosleep(&ts, NULL);
113     ++time;
114     }
115 if (!stopped)
116     {
117     errorStr = "Failed to stop listening thread - timed out";
118     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
119     return false;
120     }
121 if (pthread_join(tid, NULL))
122     {
123     errorStr = "Failed to join listening thread after stop: '";
124     errorStr += strerror(errno);
125     errorStr += "'";
126     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
127     return false;
128     }
129 return true;
130 }
131
132 void PROTO::AddUser(const USER & user, bool connect)
133 {
134 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
135 users.push_back(std::make_pair(user.GetIP(), user));
136 users.back().second.InitNetwork();
137
138 struct pollfd pfd;
139 pfd.fd = users.back().second.GetSocket();
140 pfd.events = POLLIN;
141 pfd.revents = 0;
142 pollFds.push_back(pfd);
143
144 if (connect)
145     {
146     RealConnect(&users.back().second);
147     }
148 }
149
150 bool PROTO::Connect(uint32_t ip)
151 {
152 std::list<std::pair<uint32_t, USER> >::iterator it;
153 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
154 it = std::find_if(users.begin(), users.end(), HasIP(ip));
155 if (it == users.end())
156     return false;
157
158 // Do something
159
160 return RealConnect(&it->second);
161 }
162
163 bool PROTO::Disconnect(uint32_t ip)
164 {
165 std::list<std::pair<uint32_t, USER> >::iterator it;
166 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
167 it = std::find_if(users.begin(), users.end(), HasIP(ip));
168 if (it == users.end())
169     return false;
170
171 // Do something
172
173 return RealDisconnect(&it->second);
174 }
175
176 void PROTO::Run()
177 {
178 while (running)
179     {
180     int res;
181         {
182         STG_LOCKER lock(&mutex, __FILE__, __LINE__);
183         res = poll(&pollFds.front(), pollFds.size(), timeout);
184         }
185     if (res < 0)
186         break;
187     if (!running)
188         break;
189     if (res)
190         {
191         printfd(__FILE__, "PROTO::Run() - events: %d\n", res);
192         RecvPacket();
193         }
194     else
195         {
196         CheckTimeouts();
197         }
198     }
199
200 stopped = true;
201 }
202
203 void PROTO::CheckTimeouts()
204 {
205 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
206 std::list<std::pair<uint32_t, USER> >::iterator it;
207 for (it = users.begin(); it != users.end(); ++it)
208     {
209     int delta = difftime(time(NULL), it->second.GetPhaseChangeTime());
210     if ((it->second.GetPhase() == 3) &&
211         (delta > it->second.GetUserTimeout()))
212         {
213         printfd(__FILE__, "PROTO::CheckTimeouts() - user alive timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetUserTimeout());
214         it->second.SetPhase(1);
215         }
216     if ((it->second.GetPhase() == 2) &&
217         (delta > it->second.GetAliveTimeout()))
218         {
219         printfd(__FILE__, "PROTO::CheckTimeouts() - user connect timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetAliveTimeout());
220         it->second.SetPhase(1);
221         }
222     }
223 }
224
225 bool PROTO::RecvPacket()
226 {
227 bool result = true;
228 std::vector<struct pollfd>::iterator it;
229 std::list<std::pair<uint32_t, USER> >::iterator userIt;
230 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
231 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
232     {
233     if (it->revents)
234         {
235         it->revents = 0;
236         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
237         struct sockaddr_in addr;
238         socklen_t fromLen = sizeof(addr);
239         char buffer[2048];
240         int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &fromLen);
241
242         if (res == -1)
243             {
244             result = false;
245             continue;
246             }
247
248         result = result && HandlePacket(buffer, res, &(userIt->second));
249         }
250     }
251
252 return result;
253 }
254
255 bool PROTO::HandlePacket(const char * buffer, size_t length, USER * user)
256 {
257 if (!strncmp(buffer + 4 + sizeof(HDR_8), "ERR", 3))
258     {
259     return ERR_Proc(buffer, user);
260     }
261
262 for (size_t i = 0; i < length / 8; i++)
263     Blowfish_Decrypt(user->GetCtx(),
264                      (uint32_t *)(buffer + i * 8),
265                      (uint32_t *)(buffer + i * 8 + 4));
266
267 std::string packetName(buffer + 12);
268
269 std::map<std::string, PacketProcessor>::const_iterator it;
270 it = processors.find(packetName);
271 if (it != processors.end())
272     return (this->*it->second)(buffer, user);
273
274 printfd(__FILE__, "PROTO::HandlePacket() - invalid packet signature: '%s'\n", packetName.c_str());
275
276 return false;
277 }
278
279 bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
280 {
281 const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
282
283 uint32_t rnd = packet->rnd;
284 uint32_t userTimeout = packet->userTimeOut;
285 uint32_t aliveTimeout = packet->aliveDelay;
286
287 #ifdef ARCH_BE
288 SwapBytes(rnd);
289 SwapBytes(userTimeout);
290 SwapBytes(aliveDelay);
291 #endif
292
293 if (user->GetPhase() != 2)
294     {
295     errorStr = "Unexpected CONN_SYN_ACK";
296     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
297     return false;
298     }
299
300 user->SetPhase(3);
301 user->SetAliveTimeout(aliveTimeout);
302 user->SetUserTimeout(userTimeout);
303 user->SetRnd(rnd);
304
305 Send_CONN_ACK(user);
306
307 printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str());
308
309 return true;
310 }
311
312 bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
313 {
314 const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
315
316 uint32_t rnd = packet->rnd;
317
318 #ifdef ARCH_BE
319 SwapBytes(rnd);
320 #endif
321
322 if (user->GetPhase() != 3)
323     {
324     errorStr = "Unexpected ALIVE_SYN";
325     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
326     return false;
327     }
328
329 user->SetPhase(3);
330 user->SetRnd(rnd); // Set new rnd value for ALIVE_ACK
331
332 Send_ALIVE_ACK(user);
333
334 return true;
335 }
336
337 bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
338 {
339 const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
340
341 uint32_t rnd = packet->rnd;
342
343 #ifdef ARCH_BE
344 SwapBytes(rnd);
345 #endif
346
347 if (user->GetPhase() != 4)
348     {
349     errorStr = "Unexpected DISCONN_SYN_ACK";
350     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
351     return false;
352     }
353
354 if (user->GetRnd() + 1 != rnd)
355     {
356     errorStr = "Wrong control value at DISCONN_SYN_ACK";
357     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
358     }
359
360 user->SetPhase(5);
361 user->SetRnd(rnd);
362
363 Send_DISCONN_ACK(user);
364
365 return true;
366 }
367
368 bool PROTO::FIN_Proc(const void * buffer, USER * user)
369 {
370 if (user->GetPhase() != 5)
371     {
372     errorStr = "Unexpected FIN";
373     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
374     return false;
375     }
376
377 user->SetPhase(1);
378
379 return true;
380 }
381
382 bool PROTO::INFO_Proc(const void * buffer, USER * user)
383 {
384 //const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
385
386 return true;
387 }
388
389 bool PROTO::ERR_Proc(const void * buffer, USER * user)
390 {
391 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
392 const char * ptr = static_cast<const char *>(buffer);
393
394 //uint32_t len = packet->len;
395
396 #ifdef ARCH_BE
397 //SwapBytes(len);
398 #endif
399
400 user->SetPhase(1); //TODO: Check
401 /*KOIToWin((const char*)err.text, &messageText);
402 if (pErrorCb != NULL)
403     pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
404 phaseTime = GetTickCount();
405 codeError = IA_SERVER_ERROR;*/
406
407 return true;
408 }
409
410 bool PROTO::Send_CONN_SYN(USER * user)
411 {
412 CONN_SYN_8 packet;
413
414 packet.len = sizeof(packet);
415
416 #ifdef ARCH_BE
417 SwapBytes(packet.len);
418 #endif
419
420 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
421 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
422 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
423 packet.dirs = 0xFFffFFff;
424
425 return SendPacket(&packet, sizeof(packet), user);
426 }
427
428 bool PROTO::Send_CONN_ACK(USER * user)
429 {
430 CONN_ACK_8 packet;
431
432 packet.len = sizeof(packet);
433 packet.rnd = user->IncRnd();
434
435 #ifdef ARCH_BE
436 SwapBytes(packet.len);
437 SwapBytes(packet.rnd);
438 #endif
439
440 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
441 strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
442
443 return SendPacket(&packet, sizeof(packet), user);
444 }
445
446 bool PROTO::Send_ALIVE_ACK(USER * user)
447 {
448 ALIVE_ACK_8 packet;
449
450 packet.len = sizeof(packet);
451 packet.rnd = user->IncRnd();
452
453 #ifdef ARCH_BE
454 SwapBytes(packet.len);
455 SwapBytes(packet.rnd);
456 #endif
457
458 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
459 strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
460
461 return SendPacket(&packet, sizeof(packet), user);
462 }
463
464 bool PROTO::Send_DISCONN_SYN(USER * user)
465 {
466 DISCONN_SYN_8 packet;
467
468 packet.len = sizeof(packet);
469
470 #ifdef ARCH_BE
471 SwapBytes(packet.len);
472 #endif
473
474 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
475 strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
476 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
477
478 return SendPacket(&packet, sizeof(packet), user);
479 }
480
481 bool PROTO::Send_DISCONN_ACK(USER * user)
482 {
483 DISCONN_ACK_8 packet;
484
485 packet.len = sizeof(packet);
486 packet.rnd = user->IncRnd();
487
488 #ifdef ARCH_BE
489 SwapBytes(packet.len);
490 SwapBytes(packet.rnd);
491 #endif
492
493 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
494 strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
495
496 return SendPacket(&packet, sizeof(packet), user);
497 }
498
499 bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
500 {
501 HDR_8 hdr;
502
503 assert(length < 2048 && "Packet length must not exceed 2048 bytes");
504
505 strncpy((char *)hdr.magic, IA_ID, sizeof(hdr.magic));
506 hdr.protoVer[0] = 0;
507 hdr.protoVer[1] = 8; // IA_PROTO_VER
508
509 unsigned char buffer[2048];
510 memset(buffer, 0, sizeof(buffer));
511 memcpy(buffer, packet, length);
512 memcpy(buffer, &hdr, sizeof(hdr));
513
514 size_t offset = sizeof(HDR_8);
515 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
516     {
517     Blowfish_Encrypt(&ctx,
518                      (uint32_t *)(buffer + offset + i * 8),
519                      (uint32_t *)(buffer + offset + i * 8 + 4));
520     }
521
522 offset += IA_LOGIN_LEN;
523 size_t encLen = (length - IA_LOGIN_LEN) / 8;
524 for (size_t i = 0; i < encLen; i++)
525     {
526     Blowfish_Encrypt(user->GetCtx(),
527                      (uint32_t*)(buffer + offset + i * 8),
528                      (uint32_t*)(buffer + offset + i * 8 + 4));
529     }
530
531 int res = sendto(user->GetSocket(), buffer, length, 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
532
533 if (res < 0)
534     {
535     errorStr = "Failed to send packet: '";
536     errorStr += strerror(errno);
537     errorStr += "'";
538     printfd(__FILE__, "PROTO::SendPacket() - %s, fd: %d\n", errorStr.c_str(), user->GetSocket());
539     return false;
540     }
541
542 if (res < length)
543     {
544     errorStr = "Packet sent partially";
545     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
546     return false;
547     }
548
549 return true;
550 }
551
552 bool PROTO::RealConnect(USER * user)
553 {
554 if (user->GetPhase() != 1 &&
555     user->GetPhase() != 5)
556     {
557     errorStr = "Unexpected connect";
558     printfd(__FILE__, "PROTO::RealConnect() - wrong phase: %d\n", user->GetPhase());
559     }
560 user->SetPhase(2);
561
562 return Send_CONN_SYN(user);
563 }
564
565 bool PROTO::RealDisconnect(USER * user)
566 {
567 if (user->GetPhase() != 3)
568     {
569     errorStr = "Unexpected disconnect";
570     printfd(__FILE__, "PROTO::RealDisconnect() - wrong phase: %d\n", user->GetPhase());
571     }
572 user->SetPhase(4);
573
574 return Send_DISCONN_SYN(user);
575 }