]> git.stg.codes - stg.git/blob - projects/sgauthstress/proto.cpp
Add proper header to use recvfrom/sendto in sgauthstress
[stg.git] / projects / sgauthstress / proto.cpp
1 #include <sys/types.h>
2 #include <sys/socket.h>
3 #include <netdb.h>
4 #include <arpa/inet.h>
5
6 #include <csignal>
7 #include <cerrno>
8 #include <cstring>
9 #include <cassert>
10 #include <stdexcept>
11 #include <algorithm>
12
13 #include "stg/common.h"
14 #include "stg/ia_packets.h"
15 #include "stg/locker.h"
16
17 #include "proto.h"
18
19 class HasIP : public std::unary_function<std::pair<uint32_t, USER>, bool> {
20     public:
21         explicit HasIP(uint32_t i) : ip(i) {}
22         bool operator()(const std::pair<uint32_t, USER> & value) { return value.first == ip; }
23     private:
24         uint32_t ip;
25 };
26
27 PROTO::PROTO(const std::string & server,
28              uint16_t port,
29              uint16_t localPort,
30              int to)
31     : timeout(to),
32       running(false),
33       stopped(true)
34 {
35 uint32_t ip = inet_addr(server.c_str());
36 if (ip == INADDR_NONE)
37     {
38     struct hostent * hePtr = gethostbyname(server.c_str());
39     if (hePtr)
40         {
41         ip = *((uint32_t *)hePtr->h_addr_list[0]);
42         }
43     else
44         {
45         errorStr = "Unknown host: '";
46         errorStr += server;
47         errorStr += "'";
48         printfd(__FILE__, "PROTO::PROTO() - %s\n", errorStr.c_str());
49         throw std::runtime_error(errorStr);
50         }
51     }
52
53 localAddr.sin_family = AF_INET;
54 localAddr.sin_port = htons(localPort);
55 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
56
57 serverAddr.sin_family = AF_INET;
58 serverAddr.sin_port = htons(port);
59 serverAddr.sin_addr.s_addr = ip;
60
61 unsigned char key[IA_PASSWD_LEN];
62 memset(key, 0, IA_PASSWD_LEN);
63 strncpy(reinterpret_cast<char *>(key), "pr7Hhen", 8);
64 Blowfish_Init(&ctx, key, IA_PASSWD_LEN);
65
66 processors["CONN_SYN_ACK"] = &PROTO::CONN_SYN_ACK_Proc;
67 processors["ALIVE_SYN"] = &PROTO::ALIVE_SYN_Proc;
68 processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
69 processors["FIN"] = &PROTO::FIN_Proc;
70 processors["INFO"] = &PROTO::INFO_Proc;
71 // ERR_Proc will be handled explicitly
72
73 pthread_mutex_init(&mutex, NULL);
74 }
75
76 PROTO::~PROTO()
77 {
78 pthread_mutex_destroy(&mutex);
79 }
80
81 void * PROTO::Runner(void * data)
82 {
83 sigset_t signalSet;
84 sigfillset(&signalSet);
85 pthread_sigmask(SIG_BLOCK, &signalSet, NULL);
86
87 PROTO * protoPtr = static_cast<PROTO *>(data);
88 protoPtr->Run();
89 return NULL;
90 }
91
92 bool PROTO::Start()
93 {
94 stopped = false;
95 running = true;
96 if (pthread_create(&tid, NULL, &Runner, this))
97     {
98     errorStr = "Failed to create listening thread: '";
99     errorStr += strerror(errno);
100     errorStr += "'";
101     printfd(__FILE__, "PROTO::Start() - %s\n", errorStr.c_str());
102     return false;
103     }
104 return true;
105 }
106
107 bool PROTO::Stop()
108 {
109 running = false;
110 int time = 0;
111 while (!stopped && time < timeout)
112     {
113     struct timespec ts = {1, 0};
114     nanosleep(&ts, NULL);
115     ++time;
116     }
117 if (!stopped)
118     {
119     errorStr = "Failed to stop listening thread - timed out";
120     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
121     return false;
122     }
123 if (pthread_join(tid, NULL))
124     {
125     errorStr = "Failed to join listening thread after stop: '";
126     errorStr += strerror(errno);
127     errorStr += "'";
128     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
129     return false;
130     }
131 return true;
132 }
133
134 void PROTO::AddUser(const USER & user, bool connect)
135 {
136 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
137 users.push_back(std::make_pair(user.GetIP(), user));
138 users.back().second.InitNetwork();
139
140 struct pollfd pfd;
141 pfd.fd = users.back().second.GetSocket();
142 pfd.events = POLLIN;
143 pfd.revents = 0;
144 pollFds.push_back(pfd);
145
146 if (connect)
147     {
148     RealConnect(&users.back().second);
149     }
150 }
151
152 bool PROTO::Connect(uint32_t ip)
153 {
154 std::list<std::pair<uint32_t, USER> >::iterator it;
155 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
156 it = std::find_if(users.begin(), users.end(), HasIP(ip));
157 if (it == users.end())
158     return false;
159
160 // Do something
161
162 return RealConnect(&it->second);
163 }
164
165 bool PROTO::Disconnect(uint32_t ip)
166 {
167 std::list<std::pair<uint32_t, USER> >::iterator it;
168 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
169 it = std::find_if(users.begin(), users.end(), HasIP(ip));
170 if (it == users.end())
171     return false;
172
173 // Do something
174
175 return RealDisconnect(&it->second);
176 }
177
178 void PROTO::Run()
179 {
180 while (running)
181     {
182     int res;
183         {
184         STG_LOCKER lock(&mutex, __FILE__, __LINE__);
185         res = poll(&pollFds.front(), pollFds.size(), timeout);
186         }
187     if (res < 0)
188         break;
189     if (!running)
190         break;
191     if (res)
192         {
193         printfd(__FILE__, "PROTO::Run() - events: %d\n", res);
194         RecvPacket();
195         }
196     else
197         {
198         CheckTimeouts();
199         }
200     }
201
202 stopped = true;
203 }
204
205 void PROTO::CheckTimeouts()
206 {
207 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
208 std::list<std::pair<uint32_t, USER> >::iterator it;
209 for (it = users.begin(); it != users.end(); ++it)
210     {
211     int delta = difftime(time(NULL), it->second.GetPhaseChangeTime());
212     if ((it->second.GetPhase() == 3) &&
213         (delta > it->second.GetUserTimeout()))
214         {
215         printfd(__FILE__, "PROTO::CheckTimeouts() - user alive timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetUserTimeout());
216         it->second.SetPhase(1);
217         }
218     if ((it->second.GetPhase() == 2) &&
219         (delta > it->second.GetAliveTimeout()))
220         {
221         printfd(__FILE__, "PROTO::CheckTimeouts() - user connect timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetAliveTimeout());
222         it->second.SetPhase(1);
223         }
224     }
225 }
226
227 bool PROTO::RecvPacket()
228 {
229 bool result = true;
230 std::vector<struct pollfd>::iterator it;
231 std::list<std::pair<uint32_t, USER> >::iterator userIt;
232 STG_LOCKER lock(&mutex, __FILE__, __LINE__);
233 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
234     {
235     if (it->revents)
236         {
237         it->revents = 0;
238         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
239         struct sockaddr_in addr;
240         socklen_t fromLen = sizeof(addr);
241         char buffer[2048];
242         int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &fromLen);
243
244         if (res == -1)
245             {
246             result = false;
247             continue;
248             }
249
250         result = result && HandlePacket(buffer, res, &(userIt->second));
251         }
252     }
253
254 return result;
255 }
256
257 bool PROTO::HandlePacket(const char * buffer, size_t length, USER * user)
258 {
259 if (!strncmp(buffer + 4 + sizeof(HDR_8), "ERR", 3))
260     {
261     return ERR_Proc(buffer, user);
262     }
263
264 for (size_t i = 0; i < length / 8; i++)
265     Blowfish_Decrypt(user->GetCtx(),
266                      (uint32_t *)(buffer + i * 8),
267                      (uint32_t *)(buffer + i * 8 + 4));
268
269 std::string packetName(buffer + 12);
270
271 std::map<std::string, PacketProcessor>::const_iterator it;
272 it = processors.find(packetName);
273 if (it != processors.end())
274     return (this->*it->second)(buffer, user);
275
276 printfd(__FILE__, "PROTO::HandlePacket() - invalid packet signature: '%s'\n", packetName.c_str());
277
278 return false;
279 }
280
281 bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
282 {
283 const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
284
285 uint32_t rnd = packet->rnd;
286 uint32_t userTimeout = packet->userTimeOut;
287 uint32_t aliveTimeout = packet->aliveDelay;
288
289 #ifdef ARCH_BE
290 SwapBytes(rnd);
291 SwapBytes(userTimeout);
292 SwapBytes(aliveDelay);
293 #endif
294
295 if (user->GetPhase() != 2)
296     {
297     errorStr = "Unexpected CONN_SYN_ACK";
298     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
299     return false;
300     }
301
302 user->SetPhase(3);
303 user->SetAliveTimeout(aliveTimeout);
304 user->SetUserTimeout(userTimeout);
305 user->SetRnd(rnd);
306
307 Send_CONN_ACK(user);
308
309 printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str());
310
311 return true;
312 }
313
314 bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
315 {
316 const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
317
318 uint32_t rnd = packet->rnd;
319
320 #ifdef ARCH_BE
321 SwapBytes(rnd);
322 #endif
323
324 if (user->GetPhase() != 3)
325     {
326     errorStr = "Unexpected ALIVE_SYN";
327     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
328     return false;
329     }
330
331 user->SetPhase(3);
332 user->SetRnd(rnd); // Set new rnd value for ALIVE_ACK
333
334 Send_ALIVE_ACK(user);
335
336 return true;
337 }
338
339 bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
340 {
341 const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
342
343 uint32_t rnd = packet->rnd;
344
345 #ifdef ARCH_BE
346 SwapBytes(rnd);
347 #endif
348
349 if (user->GetPhase() != 4)
350     {
351     errorStr = "Unexpected DISCONN_SYN_ACK";
352     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
353     return false;
354     }
355
356 if (user->GetRnd() + 1 != rnd)
357     {
358     errorStr = "Wrong control value at DISCONN_SYN_ACK";
359     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
360     }
361
362 user->SetPhase(5);
363 user->SetRnd(rnd);
364
365 Send_DISCONN_ACK(user);
366
367 return true;
368 }
369
370 bool PROTO::FIN_Proc(const void * buffer, USER * user)
371 {
372 if (user->GetPhase() != 5)
373     {
374     errorStr = "Unexpected FIN";
375     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
376     return false;
377     }
378
379 user->SetPhase(1);
380
381 return true;
382 }
383
384 bool PROTO::INFO_Proc(const void * buffer, USER * user)
385 {
386 //const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
387
388 return true;
389 }
390
391 bool PROTO::ERR_Proc(const void * buffer, USER * user)
392 {
393 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
394 const char * ptr = static_cast<const char *>(buffer);
395
396 //uint32_t len = packet->len;
397
398 #ifdef ARCH_BE
399 //SwapBytes(len);
400 #endif
401
402 user->SetPhase(1); //TODO: Check
403 /*KOIToWin((const char*)err.text, &messageText);
404 if (pErrorCb != NULL)
405     pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
406 phaseTime = GetTickCount();
407 codeError = IA_SERVER_ERROR;*/
408
409 return true;
410 }
411
412 bool PROTO::Send_CONN_SYN(USER * user)
413 {
414 CONN_SYN_8 packet;
415
416 packet.len = sizeof(packet);
417
418 #ifdef ARCH_BE
419 SwapBytes(packet.len);
420 #endif
421
422 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
423 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
424 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
425 packet.dirs = 0xFFffFFff;
426
427 return SendPacket(&packet, sizeof(packet), user);
428 }
429
430 bool PROTO::Send_CONN_ACK(USER * user)
431 {
432 CONN_ACK_8 packet;
433
434 packet.len = sizeof(packet);
435 packet.rnd = user->IncRnd();
436
437 #ifdef ARCH_BE
438 SwapBytes(packet.len);
439 SwapBytes(packet.rnd);
440 #endif
441
442 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
443 strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
444
445 return SendPacket(&packet, sizeof(packet), user);
446 }
447
448 bool PROTO::Send_ALIVE_ACK(USER * user)
449 {
450 ALIVE_ACK_8 packet;
451
452 packet.len = sizeof(packet);
453 packet.rnd = user->IncRnd();
454
455 #ifdef ARCH_BE
456 SwapBytes(packet.len);
457 SwapBytes(packet.rnd);
458 #endif
459
460 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
461 strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
462
463 return SendPacket(&packet, sizeof(packet), user);
464 }
465
466 bool PROTO::Send_DISCONN_SYN(USER * user)
467 {
468 DISCONN_SYN_8 packet;
469
470 packet.len = sizeof(packet);
471
472 #ifdef ARCH_BE
473 SwapBytes(packet.len);
474 #endif
475
476 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
477 strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
478 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
479
480 return SendPacket(&packet, sizeof(packet), user);
481 }
482
483 bool PROTO::Send_DISCONN_ACK(USER * user)
484 {
485 DISCONN_ACK_8 packet;
486
487 packet.len = sizeof(packet);
488 packet.rnd = user->IncRnd();
489
490 #ifdef ARCH_BE
491 SwapBytes(packet.len);
492 SwapBytes(packet.rnd);
493 #endif
494
495 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
496 strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
497
498 return SendPacket(&packet, sizeof(packet), user);
499 }
500
501 bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
502 {
503 HDR_8 hdr;
504
505 assert(length < 2048 && "Packet length must not exceed 2048 bytes");
506
507 strncpy((char *)hdr.magic, IA_ID, sizeof(hdr.magic));
508 hdr.protoVer[0] = 0;
509 hdr.protoVer[1] = 8; // IA_PROTO_VER
510
511 unsigned char buffer[2048];
512 memset(buffer, 0, sizeof(buffer));
513 memcpy(buffer, packet, length);
514 memcpy(buffer, &hdr, sizeof(hdr));
515
516 size_t offset = sizeof(HDR_8);
517 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
518     {
519     Blowfish_Encrypt(&ctx,
520                      (uint32_t *)(buffer + offset + i * 8),
521                      (uint32_t *)(buffer + offset + i * 8 + 4));
522     }
523
524 offset += IA_LOGIN_LEN;
525 size_t encLen = (length - IA_LOGIN_LEN) / 8;
526 for (size_t i = 0; i < encLen; i++)
527     {
528     Blowfish_Encrypt(user->GetCtx(),
529                      (uint32_t*)(buffer + offset + i * 8),
530                      (uint32_t*)(buffer + offset + i * 8 + 4));
531     }
532
533 int res = sendto(user->GetSocket(), buffer, length, 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
534
535 if (res < 0)
536     {
537     errorStr = "Failed to send packet: '";
538     errorStr += strerror(errno);
539     errorStr += "'";
540     printfd(__FILE__, "PROTO::SendPacket() - %s, fd: %d\n", errorStr.c_str(), user->GetSocket());
541     return false;
542     }
543
544 if (res < length)
545     {
546     errorStr = "Packet sent partially";
547     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
548     return false;
549     }
550
551 return true;
552 }
553
554 bool PROTO::RealConnect(USER * user)
555 {
556 if (user->GetPhase() != 1 &&
557     user->GetPhase() != 5)
558     {
559     errorStr = "Unexpected connect";
560     printfd(__FILE__, "PROTO::RealConnect() - wrong phase: %d\n", user->GetPhase());
561     }
562 user->SetPhase(2);
563
564 return Send_CONN_SYN(user);
565 }
566
567 bool PROTO::RealDisconnect(USER * user)
568 {
569 if (user->GetPhase() != 3)
570     {
571     errorStr = "Unexpected disconnect";
572     printfd(__FILE__, "PROTO::RealDisconnect() - wrong phase: %d\n", user->GetPhase());
573     }
574 user->SetPhase(4);
575
576 return Send_DISCONN_SYN(user);
577 }