]> git.stg.codes - stg.git/blob - projects/sgauthstress/proto.cpp
[NY Flight] Improved socket reading for sgconf.
[stg.git] / projects / sgauthstress / proto.cpp
1 #include <sys/types.h>
2 #include <sys/socket.h>
3 #include <netinet/in.h>
4 #include <netdb.h>
5 #include <arpa/inet.h>
6
7 #include <csignal>
8 #include <cerrno>
9 #include <cstring>
10 #include <cassert>
11 #include <stdexcept>
12 #include <algorithm>
13
14 #include "stg/common.h"
15 #include "stg/ia_packets.h"
16 #include "stg/locker.h"
17
18 #include "proto.h"
19
20 class HasIP : public std::unary_function<std::pair<uint32_t, USER>, bool> {
21     public:
22         explicit HasIP(uint32_t i) : ip(i) {}
23         bool operator()(const std::pair<uint32_t, USER> & value) { return value.first == ip; }
24     private:
25         uint32_t ip;
26 };
27
28 PROTO::PROTO(const std::string & server,
29              uint16_t port,
30              uint16_t localPort,
31              int to)
32     : timeout(to),
33       running(false),
34       stopped(true)
35 {
36 uint32_t ip = inet_addr(server.c_str());
37 if (ip == INADDR_NONE)
38     {
39     struct hostent * hePtr = gethostbyname(server.c_str());
40     if (hePtr)
41         {
42         ip = *((uint32_t *)hePtr->h_addr_list[0]);
43         }
44     else
45         {
46         errorStr = "Unknown host: '";
47         errorStr += server;
48         errorStr += "'";
49         printfd(__FILE__, "PROTO::PROTO() - %s\n", errorStr.c_str());
50         throw std::runtime_error(errorStr);
51         }
52     }
53
54 localAddr.sin_family = AF_INET;
55 localAddr.sin_port = htons(localPort);
56 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
57
58 serverAddr.sin_family = AF_INET;
59 serverAddr.sin_port = htons(port);
60 serverAddr.sin_addr.s_addr = ip;
61
62 unsigned char key[IA_PASSWD_LEN];
63 memset(key, 0, IA_PASSWD_LEN);
64 strncpy(reinterpret_cast<char *>(key), "pr7Hhen", 8);
65 Blowfish_Init(&ctx, key, IA_PASSWD_LEN);
66
67 processors["CONN_SYN_ACK"] = &PROTO::CONN_SYN_ACK_Proc;
68 processors["ALIVE_SYN"] = &PROTO::ALIVE_SYN_Proc;
69 processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
70 processors["FIN"] = &PROTO::FIN_Proc;
71 processors["INFO"] = &PROTO::INFO_Proc;
72 // ERR_Proc will be handled explicitly
73
74 pthread_mutex_init(&mutex, NULL);
75 }
76
77 PROTO::~PROTO()
78 {
79 pthread_mutex_destroy(&mutex);
80 }
81
82 void * PROTO::Runner(void * data)
83 {
84 sigset_t signalSet;
85 sigfillset(&signalSet);
86 pthread_sigmask(SIG_BLOCK, &signalSet, NULL);
87
88 PROTO * protoPtr = static_cast<PROTO *>(data);
89 protoPtr->Run();
90 return NULL;
91 }
92
93 bool PROTO::Start()
94 {
95 stopped = false;
96 running = true;
97 if (pthread_create(&tid, NULL, &Runner, this))
98     {
99     errorStr = "Failed to create listening thread: '";
100     errorStr += strerror(errno);
101     errorStr += "'";
102     printfd(__FILE__, "PROTO::Start() - %s\n", errorStr.c_str());
103     return false;
104     }
105 return true;
106 }
107
108 bool PROTO::Stop()
109 {
110 running = false;
111 int time = 0;
112 while (!stopped && time < timeout)
113     {
114     struct timespec ts = {1, 0};
115     nanosleep(&ts, NULL);
116     ++time;
117     }
118 if (!stopped)
119     {
120     errorStr = "Failed to stop listening thread - timed out";
121     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
122     return false;
123     }
124 if (pthread_join(tid, NULL))
125     {
126     errorStr = "Failed to join listening thread after stop: '";
127     errorStr += strerror(errno);
128     errorStr += "'";
129     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
130     return false;
131     }
132 return true;
133 }
134
135 void PROTO::AddUser(const USER & user, bool connect)
136 {
137 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
138 users.push_back(std::make_pair(user.GetIP(), user));
139 users.back().second.InitNetwork();
140
141 struct pollfd pfd;
142 pfd.fd = users.back().second.GetSocket();
143 pfd.events = POLLIN;
144 pfd.revents = 0;
145 pollFds.push_back(pfd);
146
147 if (connect)
148     {
149     RealConnect(&users.back().second);
150     }
151 }
152
153 bool PROTO::Connect(uint32_t ip)
154 {
155 std::list<std::pair<uint32_t, USER> >::iterator it;
156 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
157 it = std::find_if(users.begin(), users.end(), HasIP(ip));
158 if (it == users.end())
159     return false;
160
161 // Do something
162
163 return RealConnect(&it->second);
164 }
165
166 bool PROTO::Disconnect(uint32_t ip)
167 {
168 std::list<std::pair<uint32_t, USER> >::iterator it;
169 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
170 it = std::find_if(users.begin(), users.end(), HasIP(ip));
171 if (it == users.end())
172     return false;
173
174 // Do something
175
176 return RealDisconnect(&it->second);
177 }
178
179 void PROTO::Run()
180 {
181 while (running)
182     {
183     int res;
184         {
185         STG_LOCKER lock(&mutex, __FILE__, __LINE__);
186         res = poll(&pollFds.front(), pollFds.size(), timeout);
187         }
188     if (res < 0)
189         break;
190     if (!running)
191         break;
192     if (res)
193         {
194         printfd(__FILE__, "PROTO::Run() - events: %d\n", res);
195         RecvPacket();
196         }
197     else
198         {
199         CheckTimeouts();
200         }
201     }
202
203 stopped = true;
204 }
205
206 void PROTO::CheckTimeouts()
207 {
208 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
209 std::list<std::pair<uint32_t, USER> >::iterator it;
210 for (it = users.begin(); it != users.end(); ++it)
211     {
212     int delta = difftime(time(NULL), it->second.GetPhaseChangeTime());
213     if ((it->second.GetPhase() == 3) &&
214         (delta > it->second.GetUserTimeout()))
215         {
216         printfd(__FILE__, "PROTO::CheckTimeouts() - user alive timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetUserTimeout());
217         it->second.SetPhase(1);
218         }
219     if ((it->second.GetPhase() == 2) &&
220         (delta > it->second.GetAliveTimeout()))
221         {
222         printfd(__FILE__, "PROTO::CheckTimeouts() - user connect timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetAliveTimeout());
223         it->second.SetPhase(1);
224         }
225     }
226 }
227
228 bool PROTO::RecvPacket()
229 {
230 bool result = true;
231 std::vector<struct pollfd>::iterator it;
232 std::list<std::pair<uint32_t, USER> >::iterator userIt;
233 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
234 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
235     {
236     if (it->revents)
237         {
238         it->revents = 0;
239         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
240         struct sockaddr_in addr;
241         socklen_t fromLen = sizeof(addr);
242         char buffer[2048];
243         int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &fromLen);
244
245         if (res == -1)
246             {
247             result = false;
248             continue;
249             }
250
251         result = result && HandlePacket(buffer, res, &(userIt->second));
252         }
253     }
254
255 return result;
256 }
257
258 bool PROTO::HandlePacket(const char * buffer, size_t length, USER * user)
259 {
260 if (!strncmp(buffer + 4 + sizeof(HDR_8), "ERR", 3))
261     {
262     return ERR_Proc(buffer, user);
263     }
264
265 for (size_t i = 0; i < length / 8; i++)
266     Blowfish_Decrypt(user->GetCtx(),
267                      (uint32_t *)(buffer + i * 8),
268                      (uint32_t *)(buffer + i * 8 + 4));
269
270 std::string packetName(buffer + 12);
271
272 std::map<std::string, PacketProcessor>::const_iterator it;
273 it = processors.find(packetName);
274 if (it != processors.end())
275     return (this->*it->second)(buffer, user);
276
277 printfd(__FILE__, "PROTO::HandlePacket() - invalid packet signature: '%s'\n", packetName.c_str());
278
279 return false;
280 }
281
282 bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
283 {
284 const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
285
286 uint32_t rnd = packet->rnd;
287 uint32_t userTimeout = packet->userTimeOut;
288 uint32_t aliveTimeout = packet->aliveDelay;
289
290 #ifdef ARCH_BE
291 SwapBytes(rnd);
292 SwapBytes(userTimeout);
293 SwapBytes(aliveDelay);
294 #endif
295
296 if (user->GetPhase() != 2)
297     {
298     errorStr = "Unexpected CONN_SYN_ACK";
299     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
300     return false;
301     }
302
303 user->SetPhase(3);
304 user->SetAliveTimeout(aliveTimeout);
305 user->SetUserTimeout(userTimeout);
306 user->SetRnd(rnd);
307
308 Send_CONN_ACK(user);
309
310 printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str());
311
312 return true;
313 }
314
315 bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
316 {
317 const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
318
319 uint32_t rnd = packet->rnd;
320
321 #ifdef ARCH_BE
322 SwapBytes(rnd);
323 #endif
324
325 if (user->GetPhase() != 3)
326     {
327     errorStr = "Unexpected ALIVE_SYN";
328     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
329     return false;
330     }
331
332 user->SetPhase(3);
333 user->SetRnd(rnd); // Set new rnd value for ALIVE_ACK
334
335 Send_ALIVE_ACK(user);
336
337 return true;
338 }
339
340 bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
341 {
342 const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
343
344 uint32_t rnd = packet->rnd;
345
346 #ifdef ARCH_BE
347 SwapBytes(rnd);
348 #endif
349
350 if (user->GetPhase() != 4)
351     {
352     errorStr = "Unexpected DISCONN_SYN_ACK";
353     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
354     return false;
355     }
356
357 if (user->GetRnd() + 1 != rnd)
358     {
359     errorStr = "Wrong control value at DISCONN_SYN_ACK";
360     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
361     }
362
363 user->SetPhase(5);
364 user->SetRnd(rnd);
365
366 Send_DISCONN_ACK(user);
367
368 return true;
369 }
370
371 bool PROTO::FIN_Proc(const void * buffer, USER * user)
372 {
373 if (user->GetPhase() != 5)
374     {
375     errorStr = "Unexpected FIN";
376     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
377     return false;
378     }
379
380 user->SetPhase(1);
381
382 return true;
383 }
384
385 bool PROTO::INFO_Proc(const void * buffer, USER * user)
386 {
387 //const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
388
389 return true;
390 }
391
392 bool PROTO::ERR_Proc(const void * buffer, USER * user)
393 {
394 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
395 const char * ptr = static_cast<const char *>(buffer);
396
397 //uint32_t len = packet->len;
398
399 #ifdef ARCH_BE
400 //SwapBytes(len);
401 #endif
402
403 user->SetPhase(1); //TODO: Check
404 /*KOIToWin((const char*)err.text, &messageText);
405 if (pErrorCb != NULL)
406     pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
407 phaseTime = GetTickCount();
408 codeError = IA_SERVER_ERROR;*/
409
410 return true;
411 }
412
413 bool PROTO::Send_CONN_SYN(USER * user)
414 {
415 CONN_SYN_8 packet;
416
417 packet.len = sizeof(packet);
418
419 #ifdef ARCH_BE
420 SwapBytes(packet.len);
421 #endif
422
423 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
424 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
425 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
426 packet.dirs = 0xFFffFFff;
427
428 return SendPacket(&packet, sizeof(packet), user);
429 }
430
431 bool PROTO::Send_CONN_ACK(USER * user)
432 {
433 CONN_ACK_8 packet;
434
435 packet.len = sizeof(packet);
436 packet.rnd = user->IncRnd();
437
438 #ifdef ARCH_BE
439 SwapBytes(packet.len);
440 SwapBytes(packet.rnd);
441 #endif
442
443 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
444 strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
445
446 return SendPacket(&packet, sizeof(packet), user);
447 }
448
449 bool PROTO::Send_ALIVE_ACK(USER * user)
450 {
451 ALIVE_ACK_8 packet;
452
453 packet.len = sizeof(packet);
454 packet.rnd = user->IncRnd();
455
456 #ifdef ARCH_BE
457 SwapBytes(packet.len);
458 SwapBytes(packet.rnd);
459 #endif
460
461 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
462 strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
463
464 return SendPacket(&packet, sizeof(packet), user);
465 }
466
467 bool PROTO::Send_DISCONN_SYN(USER * user)
468 {
469 DISCONN_SYN_8 packet;
470
471 packet.len = sizeof(packet);
472
473 #ifdef ARCH_BE
474 SwapBytes(packet.len);
475 #endif
476
477 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
478 strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
479 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
480
481 return SendPacket(&packet, sizeof(packet), user);
482 }
483
484 bool PROTO::Send_DISCONN_ACK(USER * user)
485 {
486 DISCONN_ACK_8 packet;
487
488 packet.len = sizeof(packet);
489 packet.rnd = user->IncRnd();
490
491 #ifdef ARCH_BE
492 SwapBytes(packet.len);
493 SwapBytes(packet.rnd);
494 #endif
495
496 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
497 strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
498
499 return SendPacket(&packet, sizeof(packet), user);
500 }
501
502 bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
503 {
504 HDR_8 hdr;
505
506 assert(length < 2048 && "Packet length must not exceed 2048 bytes");
507
508 strncpy((char *)hdr.magic, IA_ID, sizeof(hdr.magic));
509 hdr.protoVer[0] = 0;
510 hdr.protoVer[1] = 8; // IA_PROTO_VER
511
512 unsigned char buffer[2048];
513 memset(buffer, 0, sizeof(buffer));
514 memcpy(buffer, packet, length);
515 memcpy(buffer, &hdr, sizeof(hdr));
516
517 size_t offset = sizeof(HDR_8);
518 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
519     {
520     Blowfish_Encrypt(&ctx,
521                      (uint32_t *)(buffer + offset + i * 8),
522                      (uint32_t *)(buffer + offset + i * 8 + 4));
523     }
524
525 offset += IA_LOGIN_LEN;
526 size_t encLen = (length - IA_LOGIN_LEN) / 8;
527 for (size_t i = 0; i < encLen; i++)
528     {
529     Blowfish_Encrypt(user->GetCtx(),
530                      (uint32_t*)(buffer + offset + i * 8),
531                      (uint32_t*)(buffer + offset + i * 8 + 4));
532     }
533
534 int res = sendto(user->GetSocket(), buffer, length, 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
535
536 if (res < 0)
537     {
538     errorStr = "Failed to send packet: '";
539     errorStr += strerror(errno);
540     errorStr += "'";
541     printfd(__FILE__, "PROTO::SendPacket() - %s, fd: %d\n", errorStr.c_str(), user->GetSocket());
542     return false;
543     }
544
545 if (res < length)
546     {
547     errorStr = "Packet sent partially";
548     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
549     return false;
550     }
551
552 return true;
553 }
554
555 bool PROTO::RealConnect(USER * user)
556 {
557 if (user->GetPhase() != 1 &&
558     user->GetPhase() != 5)
559     {
560     errorStr = "Unexpected connect";
561     printfd(__FILE__, "PROTO::RealConnect() - wrong phase: %d\n", user->GetPhase());
562     }
563 user->SetPhase(2);
564
565 return Send_CONN_SYN(user);
566 }
567
568 bool PROTO::RealDisconnect(USER * user)
569 {
570 if (user->GetPhase() != 3)
571     {
572     errorStr = "Unexpected disconnect";
573     printfd(__FILE__, "PROTO::RealDisconnect() - wrong phase: %d\n", user->GetPhase());
574     }
575 user->SetPhase(4);
576
577 return Send_DISCONN_SYN(user);
578 }