]> git.stg.codes - stg.git/blob - projects/sgauthstress/proto.cpp
Use common filtering algorithm with predicate specialization
[stg.git] / projects / sgauthstress / proto.cpp
1 #include <netdb.h>
2 #include <arpa/inet.h>
3
4 #include <cerrno>
5 #include <cstring>
6 #include <cassert>
7 #include <stdexcept>
8 #include <algorithm>
9
10 #include "stg/common.h"
11 #include "stg/ia_packets.h"
12 #include "stg/locker.h"
13
14 #include "proto.h"
15
16 class HasIP : public std::unary_function<std::pair<uint32_t, USER>, bool> {
17     public:
18         explicit HasIP(uint32_t i) : ip(i) {}
19         bool operator()(const std::pair<uint32_t, USER> & value) { return value.first == ip; }
20     private:
21         uint32_t ip;
22 };
23
24 PROTO::PROTO(const std::string & server,
25              uint16_t port,
26              uint16_t localPort,
27              int to)
28     : timeout(to),
29       running(false),
30       stopped(true)
31 {
32 uint32_t ip = inet_addr(server.c_str());
33 if (ip == INADDR_NONE)
34     {
35     struct hostent * hePtr = gethostbyname(server.c_str());
36     if (hePtr)
37         {
38         ip = *((uint32_t *)hePtr->h_addr_list[0]);
39         }
40     else
41         {
42         errorStr = "Unknown host: '";
43         errorStr += server;
44         errorStr += "'";
45         printfd(__FILE__, "PROTO::PROTO() - %s\n", errorStr.c_str());
46         throw std::runtime_error(errorStr);
47         }
48     }
49
50 localAddr.sin_family = AF_INET;
51 localAddr.sin_port = htons(localPort);
52 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
53
54 serverAddr.sin_family = AF_INET;
55 serverAddr.sin_port = htons(port);
56 serverAddr.sin_addr.s_addr = ip;
57
58 unsigned char key[IA_PASSWD_LEN];
59 memset(key, 0, IA_PASSWD_LEN);
60 strncpy(reinterpret_cast<char *>(key), "pr7Hhen", 8);
61 Blowfish_Init(&ctx, key, IA_PASSWD_LEN);
62
63 processors["CONN_SYN_ACK"] = &PROTO::CONN_SYN_ACK_Proc;
64 processors["ALIVE_SYN"] = &PROTO::ALIVE_SYN_Proc;
65 processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
66 processors["FIN"] = &PROTO::FIN_Proc;
67 processors["INFO"] = &PROTO::INFO_Proc;
68 // ERR_Proc will be handled explicitly
69
70 pthread_mutex_init(&mutex, NULL);
71 }
72
73 PROTO::~PROTO()
74 {
75 pthread_mutex_destroy(&mutex);
76 }
77
78 void * PROTO::Runner(void * data)
79 {
80 PROTO * protoPtr = static_cast<PROTO *>(data);
81 protoPtr->Run();
82 return NULL;
83 }
84
85 bool PROTO::Start()
86 {
87 stopped = false;
88 running = true;
89 if (pthread_create(&tid, NULL, &Runner, this))
90     {
91     errorStr = "Failed to create listening thread: '";
92     errorStr += strerror(errno);
93     errorStr += "'";
94     printfd(__FILE__, "PROTO::Start() - %s\n", errorStr.c_str());
95     return false;
96     }
97 return true;
98 }
99
100 bool PROTO::Stop()
101 {
102 running = false;
103 int time = 0;
104 while (!stopped && time < timeout)
105     {
106     struct timespec ts = {1, 0};
107     nanosleep(&ts, NULL);
108     ++time;
109     }
110 if (!stopped)
111     {
112     errorStr = "Failed to stop listening thread - timed out";
113     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
114     return false;
115     }
116 if (pthread_join(tid, NULL))
117     {
118     errorStr = "Failed to join listening thread after stop: '";
119     errorStr += strerror(errno);
120     errorStr += "'";
121     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
122     return false;
123     }
124 return true;
125 }
126
127 void PROTO::AddUser(const USER & user, bool connect)
128 {
129 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
130 users.push_back(std::make_pair(user.GetIP(), user));
131 users.back().second.InitNetwork();
132
133 struct pollfd pfd;
134 pfd.fd = users.back().second.GetSocket();
135 pfd.events = POLLIN;
136 pfd.revents = 0;
137 pollFds.push_back(pfd);
138
139 if (connect)
140     {
141     RealConnect(&users.back().second);
142     }
143 }
144
145 bool PROTO::Connect(uint32_t ip)
146 {
147 std::list<std::pair<uint32_t, USER> >::iterator it;
148 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
149 it = std::find_if(users.begin(), users.end(), HasIP(ip));
150 if (it == users.end())
151     return false;
152
153 // Do something
154
155 return RealConnect(&it->second);
156 }
157
158 bool PROTO::Disconnect(uint32_t ip)
159 {
160 std::list<std::pair<uint32_t, USER> >::iterator it;
161 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
162 it = std::find_if(users.begin(), users.end(), HasIP(ip));
163 if (it == users.end())
164     return false;
165
166 // Do something
167
168 return RealDisconnect(&it->second);
169 }
170
171 void PROTO::Run()
172 {
173 while (running)
174     {
175     int res;
176         {
177         STG_LOCKER lock(&mutex, __FILE__, __LINE__);
178         res = poll(&pollFds.front(), pollFds.size(), timeout);
179         }
180     if (res < 0)
181         break;
182     if (!running)
183         break;
184     if (res)
185         {
186         printfd(__FILE__, "PROTO::Run() - events: %d\n", res);
187         RecvPacket();
188         }
189     else
190         {
191         CheckTimeouts();
192         }
193     }
194
195 stopped = true;
196 }
197
198 void PROTO::CheckTimeouts()
199 {
200 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
201 std::list<std::pair<uint32_t, USER> >::iterator it;
202 for (it = users.begin(); it != users.end(); ++it)
203     {
204     int delta = difftime(time(NULL), it->second.GetPhaseChangeTime());
205     if ((it->second.GetPhase() == 3) &&
206         (delta > it->second.GetUserTimeout()))
207         {
208         printfd(__FILE__, "PROTO::CheckTimeouts() - user alive timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetUserTimeout());
209         it->second.SetPhase(1);
210         }
211     if ((it->second.GetPhase() == 2) &&
212         (delta > it->second.GetAliveTimeout()))
213         {
214         printfd(__FILE__, "PROTO::CheckTimeouts() - user connect timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetAliveTimeout());
215         it->second.SetPhase(1);
216         }
217     }
218 }
219
220 bool PROTO::RecvPacket()
221 {
222 bool result = true;
223 std::vector<struct pollfd>::iterator it;
224 std::list<std::pair<uint32_t, USER> >::iterator userIt;
225 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
226 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
227     {
228     if (it->revents)
229         {
230         it->revents = 0;
231         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
232         struct sockaddr_in addr;
233         socklen_t fromLen = sizeof(addr);
234         char buffer[2048];
235         int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &fromLen);
236
237         if (res == -1)
238             {
239             result = false;
240             continue;
241             }
242
243         result = result && HandlePacket(buffer, res, &(userIt->second));
244         }
245     }
246
247 return result;
248 }
249
250 bool PROTO::HandlePacket(const char * buffer, size_t length, USER * user)
251 {
252 if (!strncmp(buffer + 4 + sizeof(HDR_8), "ERR", 3))
253     {
254     return ERR_Proc(buffer, user);
255     }
256
257 for (size_t i = 0; i < length / 8; i++)
258     Blowfish_Decrypt(user->GetCtx(),
259                      (uint32_t *)(buffer + i * 8),
260                      (uint32_t *)(buffer + i * 8 + 4));
261
262 std::string packetName(buffer + 12);
263
264 std::map<std::string, PacketProcessor>::const_iterator it;
265 it = processors.find(packetName);
266 if (it != processors.end())
267     return (this->*it->second)(buffer, user);
268
269 printfd(__FILE__, "PROTO::HandlePacket() - invalid packet signature: '%s'\n", packetName.c_str());
270
271 return false;
272 }
273
274 bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
275 {
276 const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
277
278 uint32_t rnd = packet->rnd;
279 uint32_t userTimeout = packet->userTimeOut;
280 uint32_t aliveTimeout = packet->aliveDelay;
281
282 #ifdef ARCH_BE
283 SwapBytes(rnd);
284 SwapBytes(userTimeout);
285 SwapBytes(aliveDelay);
286 #endif
287
288 if (user->GetPhase() != 2)
289     {
290     errorStr = "Unexpected CONN_SYN_ACK";
291     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
292     return false;
293     }
294
295 user->SetPhase(3);
296 user->SetAliveTimeout(aliveTimeout);
297 user->SetUserTimeout(userTimeout);
298 user->SetRnd(rnd);
299
300 Send_CONN_ACK(user);
301
302 printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str());
303
304 return true;
305 }
306
307 bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
308 {
309 const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
310
311 uint32_t rnd = packet->rnd;
312
313 #ifdef ARCH_BE
314 SwapBytes(rnd);
315 #endif
316
317 if (user->GetPhase() != 3)
318     {
319     errorStr = "Unexpected ALIVE_SYN";
320     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
321     return false;
322     }
323
324 user->SetPhase(3);
325 user->SetRnd(rnd); // Set new rnd value for ALIVE_ACK
326
327 Send_ALIVE_ACK(user);
328
329 return true;
330 }
331
332 bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
333 {
334 const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
335
336 uint32_t rnd = packet->rnd;
337
338 #ifdef ARCH_BE
339 SwapBytes(rnd);
340 #endif
341
342 if (user->GetPhase() != 4)
343     {
344     errorStr = "Unexpected DISCONN_SYN_ACK";
345     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
346     return false;
347     }
348
349 if (user->GetRnd() + 1 != rnd)
350     {
351     errorStr = "Wrong control value at DISCONN_SYN_ACK";
352     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
353     }
354
355 user->SetPhase(5);
356 user->SetRnd(rnd);
357
358 Send_DISCONN_ACK(user);
359
360 return true;
361 }
362
363 bool PROTO::FIN_Proc(const void * buffer, USER * user)
364 {
365 if (user->GetPhase() != 5)
366     {
367     errorStr = "Unexpected FIN";
368     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
369     return false;
370     }
371
372 user->SetPhase(1);
373
374 return true;
375 }
376
377 bool PROTO::INFO_Proc(const void * buffer, USER * user)
378 {
379 //const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
380
381 return true;
382 }
383
384 bool PROTO::ERR_Proc(const void * buffer, USER * user)
385 {
386 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
387 const char * ptr = static_cast<const char *>(buffer);
388
389 //uint32_t len = packet->len;
390
391 #ifdef ARCH_BE
392 //SwapBytes(len);
393 #endif
394
395 user->SetPhase(1); //TODO: Check
396 /*KOIToWin((const char*)err.text, &messageText);
397 if (pErrorCb != NULL)
398     pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
399 phaseTime = GetTickCount();
400 codeError = IA_SERVER_ERROR;*/
401
402 return true;
403 }
404
405 bool PROTO::Send_CONN_SYN(USER * user)
406 {
407 CONN_SYN_8 packet;
408
409 packet.len = sizeof(packet);
410
411 #ifdef ARCH_BE
412 SwapBytes(packet.len);
413 #endif
414
415 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
416 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
417 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
418 packet.dirs = 0xFFffFFff;
419
420 return SendPacket(&packet, sizeof(packet), user);
421 }
422
423 bool PROTO::Send_CONN_ACK(USER * user)
424 {
425 CONN_ACK_8 packet;
426
427 packet.len = sizeof(packet);
428 packet.rnd = user->IncRnd();
429
430 #ifdef ARCH_BE
431 SwapBytes(packet.len);
432 SwapBytes(packet.rnd);
433 #endif
434
435 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
436 strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
437
438 return SendPacket(&packet, sizeof(packet), user);
439 }
440
441 bool PROTO::Send_ALIVE_ACK(USER * user)
442 {
443 ALIVE_ACK_8 packet;
444
445 packet.len = sizeof(packet);
446 packet.rnd = user->IncRnd();
447
448 #ifdef ARCH_BE
449 SwapBytes(packet.len);
450 SwapBytes(packet.rnd);
451 #endif
452
453 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
454 strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
455
456 return SendPacket(&packet, sizeof(packet), user);
457 }
458
459 bool PROTO::Send_DISCONN_SYN(USER * user)
460 {
461 DISCONN_SYN_8 packet;
462
463 packet.len = sizeof(packet);
464
465 #ifdef ARCH_BE
466 SwapBytes(packet.len);
467 #endif
468
469 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
470 strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
471 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
472
473 return SendPacket(&packet, sizeof(packet), user);
474 }
475
476 bool PROTO::Send_DISCONN_ACK(USER * user)
477 {
478 DISCONN_ACK_8 packet;
479
480 packet.len = sizeof(packet);
481 packet.rnd = user->IncRnd();
482
483 #ifdef ARCH_BE
484 SwapBytes(packet.len);
485 SwapBytes(packet.rnd);
486 #endif
487
488 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
489 strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
490
491 return SendPacket(&packet, sizeof(packet), user);
492 }
493
494 bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
495 {
496 HDR_8 hdr;
497
498 assert(length < 2048 && "Packet length must not exceed 2048 bytes");
499
500 strncpy((char *)hdr.magic, IA_ID, sizeof(hdr.magic));
501 hdr.protoVer[0] = 0;
502 hdr.protoVer[1] = 8; // IA_PROTO_VER
503
504 unsigned char buffer[2048];
505 memset(buffer, 0, sizeof(buffer));
506 memcpy(buffer, packet, length);
507 memcpy(buffer, &hdr, sizeof(hdr));
508
509 size_t offset = sizeof(HDR_8);
510 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
511     {
512     Blowfish_Encrypt(&ctx,
513                      (uint32_t *)(buffer + offset + i * 8),
514                      (uint32_t *)(buffer + offset + i * 8 + 4));
515     }
516
517 offset += IA_LOGIN_LEN;
518 size_t encLen = (length - IA_LOGIN_LEN) / 8;
519 for (size_t i = 0; i < encLen; i++)
520     {
521     Blowfish_Encrypt(user->GetCtx(),
522                      (uint32_t*)(buffer + offset + i * 8),
523                      (uint32_t*)(buffer + offset + i * 8 + 4));
524     }
525
526 int res = sendto(user->GetSocket(), buffer, length, 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
527
528 if (res < 0)
529     {
530     errorStr = "Failed to send packet: '";
531     errorStr += strerror(errno);
532     errorStr += "'";
533     printfd(__FILE__, "PROTO::SendPacket() - %s, fd: %d\n", errorStr.c_str(), user->GetSocket());
534     return false;
535     }
536
537 if (res < length)
538     {
539     errorStr = "Packet sent partially";
540     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
541     return false;
542     }
543
544 return true;
545 }
546
547 bool PROTO::RealConnect(USER * user)
548 {
549 if (user->GetPhase() != 1 &&
550     user->GetPhase() != 5)
551     {
552     errorStr = "Unexpected connect";
553     printfd(__FILE__, "PROTO::RealConnect() - wrong phase: %d\n", user->GetPhase());
554     }
555 user->SetPhase(2);
556
557 return Send_CONN_SYN(user);
558 }
559
560 bool PROTO::RealDisconnect(USER * user)
561 {
562 if (user->GetPhase() != 3)
563     {
564     errorStr = "Unexpected disconnect";
565     printfd(__FILE__, "PROTO::RealDisconnect() - wrong phase: %d\n", user->GetPhase());
566     }
567 user->SetPhase(4);
568
569 return Send_DISCONN_SYN(user);
570 }