]> git.stg.codes - stg.git/blob - projects/sgauthstress/proto.cpp
Implement connecting and disconnecting users
[stg.git] / projects / sgauthstress / proto.cpp
1 #include <netdb.h>
2 #include <arpa/inet.h>
3
4 #include <cerrno>
5 #include <cstring>
6 #include <cassert>
7 #include <stdexcept>
8 #include <algorithm>
9
10 #include "stg/common.h"
11 #include "stg/ia_packets.h"
12
13 #include "proto.h"
14
15 class HasIP : public std::unary_function<std::pair<uint32_t, USER>, bool> {
16     public:
17         explicit HasIP(uint32_t i) : ip(i) {}
18         bool operator()(const std::pair<uint32_t, USER> & value) { return value.first == ip; }
19     private:
20         uint32_t ip;
21 };
22
23 PROTO::PROTO(const std::string & server,
24              uint16_t port,
25              uint16_t localPort,
26              int to)
27     : timeout(to),
28       running(false),
29       stopped(true)
30 {
31 uint32_t ip = inet_addr(server.c_str());
32 if (ip == INADDR_NONE)
33     {
34     struct hostent * hePtr = gethostbyname(server.c_str());
35     if (hePtr)
36         {
37         ip = *((uint32_t *)hePtr->h_addr_list[0]);
38         }
39     else
40         {
41         errorStr = "Unknown host: '";
42         errorStr += server;
43         errorStr += "'";
44         printfd(__FILE__, "PROTO::PROTO() - %s\n", errorStr.c_str());
45         throw std::runtime_error(errorStr);
46         }
47     }
48
49 localAddr.sin_family = AF_INET;
50 localAddr.sin_port = htons(localPort);
51 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
52
53 serverAddr.sin_family = AF_INET;
54 serverAddr.sin_port = htons(port);
55 serverAddr.sin_addr.s_addr = ip;
56
57 unsigned char key[IA_PASSWD_LEN];
58 memset(key, 0, IA_PASSWD_LEN);
59 strncpy(reinterpret_cast<char *>(key), "pr7Hhen", 8);
60 Blowfish_Init(&ctx, key, IA_PASSWD_LEN);
61
62 processors["CONN_SYN_ACK"] = &PROTO::CONN_SYN_ACK_Proc;
63 processors["ALIVE_SYN"] = &PROTO::ALIVE_SYN_Proc;
64 processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
65 processors["FIN"] = &PROTO::FIN_Proc;
66 processors["INFO"] = &PROTO::INFO_Proc;
67 // ERR_Proc will be handled explicitly
68 }
69
70 PROTO::~PROTO()
71 {
72 }
73
74 void * PROTO::Runner(void * data)
75 {
76 PROTO * protoPtr = static_cast<PROTO *>(data);
77 protoPtr->Run();
78 return NULL;
79 }
80
81 bool PROTO::Start()
82 {
83 stopped = false;
84 running = true;
85 if (pthread_create(&tid, NULL, &Runner, NULL))
86     {
87     errorStr = "Failed to create listening thread: '";
88     errorStr += strerror(errno);
89     errorStr += "'";
90     printfd(__FILE__, "PROTO::Start() - %s\n", errorStr.c_str());
91     return false;
92     }
93 return true;
94 }
95
96 bool PROTO::Stop()
97 {
98 running = false;
99 int time = 0;
100 while (!stopped && time < timeout)
101     {
102     struct timespec ts = {1, 0};
103     nanosleep(&ts, NULL);
104     ++time;
105     }
106 if (!stopped)
107     {
108     errorStr = "Failed to stop listening thread - timed out";
109     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
110     return false;
111     }
112 if (pthread_join(tid, NULL))
113     {
114     errorStr = "Failed to join listening thread after stop: '";
115     errorStr += strerror(errno);
116     errorStr += "'";
117     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
118     return false;
119     }
120 return true;
121 }
122
123 void PROTO::AddUser(const USER & user, bool connect)
124 {
125 users.push_back(std::make_pair(user.GetIP(), user));
126 struct pollfd pfd;
127 pfd.fd = user.GetSocket();
128 pfd.events = POLLIN;
129 pfd.revents = 0;
130 pollFds.push_back(pfd);
131
132 users.back().second.InitNetwork();
133
134 if (connect)
135     {
136     RealConnect(&users.back().second);
137     }
138 }
139
140 bool PROTO::Connect(uint32_t ip)
141 {
142 std::vector<std::pair<uint32_t, USER> >::iterator it;
143 it = std::find_if(users.begin(), users.end(), HasIP(ip));
144 if (it == users.end())
145     return false;
146
147 // Do something
148
149 return RealConnect(&it->second);
150 }
151
152 bool PROTO::Disconnect(uint32_t ip)
153 {
154 std::vector<std::pair<uint32_t, USER> >::iterator it;
155 it = std::find_if(users.begin(), users.end(), HasIP(ip));
156 if (it == users.end())
157     return false;
158
159 // Do something
160
161 return RealDisconnect(&it->second);
162 }
163
164 void PROTO::Run()
165 {
166 while (running)
167     {
168     int res = poll(&pollFds.front(), pollFds.size(), timeout);
169     if (res < 0)
170         break;
171     if (!running)
172         break;
173     if (res)
174         RecvPacket();
175     }
176
177 stopped = true;
178 }
179
180 bool PROTO::RecvPacket()
181 {
182 bool result = true;
183 std::vector<struct pollfd>::iterator it;
184 std::vector<std::pair<uint32_t, USER> >::iterator userIt;
185 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
186     {
187     if (it->revents)
188         {
189         it->revents = 0;
190         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
191         struct sockaddr_in addr;
192         socklen_t fromLen = sizeof(addr);
193         char buffer[2048];
194         int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &fromLen);
195
196         if (res == -1)
197             {
198             result = false;
199             continue;
200             }
201
202         result = result && HandlePacket(buffer, &(userIt->second));
203         }
204     }
205
206 return result;
207 }
208
209 bool PROTO::HandlePacket(const char * buffer, USER * user)
210 {
211 if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR"))
212     {
213     return ERR_Proc(buffer, user);
214     }
215
216 std::string packetName(buffer + 12);
217 std::map<std::string, PacketProcessor>::const_iterator it;
218 it = processors.find(packetName);
219 if (it != processors.end())
220     return (this->*it->second)(buffer, user);
221
222 printfd(__FILE__, "PROTO::HandlePacket() - invalid packet signature: '%s'\n", packetName.c_str());
223
224 return false;
225 }
226
227 bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
228 {
229 const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
230
231 uint32_t rnd = packet->rnd;
232 uint32_t userTimeout = packet->userTimeOut;
233 uint32_t aliveTimeout = packet->aliveDelay;
234
235 #ifdef ARCH_BE
236 SwapBytes(rnd);
237 SwapBytes(userTimeout);
238 SwapBytes(aliveDelay);
239 #endif
240
241 Send_CONN_ACK(user);
242
243 if (user->GetPhase() != 2)
244     {
245     errorStr = "Unexpected CONN_SYN_ACK";
246     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
247     }
248
249 user->SetPhase(3);
250 user->SetAliveTimeout(aliveTimeout);
251 user->SetUserTimeout(userTimeout);
252 user->SetRnd(rnd);
253
254 printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str());
255
256 return true;
257 }
258
259 bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
260 {
261 const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
262
263 uint32_t rnd = packet->rnd;
264
265 #ifdef ARCH_BE
266 SwapBytes(rnd);
267 #endif
268
269 if (user->GetPhase() != 3)
270     {
271     errorStr = "Unexpected ALIVE_SYN";
272     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
273     }
274
275 if (user->GetRnd() + 1 != rnd)
276     {
277     errorStr = "Wrong control value at ALIVE_SYN";
278     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
279     }
280
281 user->SetPhase(3);
282 user->SetRnd(rnd);
283
284 Send_ALIVE_ACK(user);
285
286 return true;
287 }
288
289 bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
290 {
291 const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
292
293 uint32_t rnd = packet->rnd;
294
295 #ifdef ARCH_BE
296 SwapBytes(rnd);
297 #endif
298
299 if (user->GetPhase() != 4)
300     {
301     errorStr = "Unexpected DISCONN_SYN_ACK";
302     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
303     }
304
305 if (user->GetRnd() + 1 != rnd)
306     {
307     errorStr = "Wrong control value at DISCONN_SYN_ACK";
308     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
309     }
310
311 user->SetPhase(5);
312 user->SetRnd(rnd);
313
314 Send_DISCONN_ACK(user);
315
316 return true;
317 }
318
319 bool PROTO::FIN_Proc(const void * buffer, USER * user)
320 {
321 if (user->GetPhase() != 5)
322     {
323     errorStr = "Unexpected FIN";
324     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
325     }
326
327 user->SetPhase(1);
328
329 return true;
330 }
331
332 bool PROTO::INFO_Proc(const void * buffer, USER * user)
333 {
334 //const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
335
336 return true;
337 }
338
339 bool PROTO::ERR_Proc(const void * buffer, USER * user)
340 {
341 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
342 const char * ptr = static_cast<const char *>(buffer);
343
344 for (size_t i = 0; i < sizeof(ERR_8) / 8; i++)
345     Blowfish_Decrypt(user->GetCtx(), (uint32_t *)(ptr + i * 8), (uint32_t *)(ptr + i * 8 + 4));
346
347 //uint32_t len = packet->len;
348
349 #ifdef ARCH_BE
350 //SwapBytes(len);
351 #endif
352
353 user->SetPhase(1); //TODO: Check
354 /*KOIToWin((const char*)err.text, &messageText);
355 if (pErrorCb != NULL)
356     pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
357 phaseTime = GetTickCount();
358 codeError = IA_SERVER_ERROR;*/
359
360 return true;
361 }
362
363 bool PROTO::Send_CONN_SYN(USER * user)
364 {
365 CONN_SYN_8 packet;
366
367 packet.len = sizeof(packet);
368
369 #ifdef ARCH_BE
370 SwapBytes(packet.len);
371 #endif
372
373 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
374 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
375 packet.dirs = 0xFFffFFff;
376
377 return SendPacket(&packet, sizeof(packet), user);
378 }
379
380 bool PROTO::Send_CONN_ACK(USER * user)
381 {
382 CONN_ACK_8 packet;
383
384 packet.len = sizeof(packet);
385 packet.rnd = user->IncRnd();
386
387 #ifdef ARCH_BE
388 SwapBytes(packet.len);
389 SwapBytes(packet.rnd);
390 #endif
391
392 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
393 strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
394
395 return SendPacket(&packet, sizeof(packet), user);
396 }
397
398 bool PROTO::Send_ALIVE_ACK(USER * user)
399 {
400 ALIVE_ACK_8 packet;
401
402 packet.len = sizeof(packet);
403 packet.rnd = user->IncRnd();
404
405 #ifdef ARCH_BE
406 SwapBytes(packet.len);
407 SwapBytes(packet.rnd);
408 #endif
409
410 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
411 strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
412
413 return SendPacket(&packet, sizeof(packet), user);
414 }
415
416 bool PROTO::Send_DISCONN_SYN(USER * user)
417 {
418 DISCONN_SYN_8 packet;
419
420 packet.len = sizeof(packet);
421
422 #ifdef ARCH_BE
423 SwapBytes(packet.len);
424 #endif
425
426 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
427 strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
428 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
429
430 return SendPacket(&packet, sizeof(packet), user);
431 }
432
433 bool PROTO::Send_DISCONN_ACK(USER * user)
434 {
435 DISCONN_ACK_8 packet;
436
437 packet.len = sizeof(packet);
438 packet.rnd = user->IncRnd();
439
440 #ifdef ARCH_BE
441 SwapBytes(packet.len);
442 SwapBytes(packet.rnd);
443 #endif
444
445 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
446 strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
447
448 return SendPacket(&packet, sizeof(packet), user);
449 }
450
451 bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
452 {
453 HDR_8 hdr;
454
455 assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes");
456
457 strncpy((char *)hdr.magic, IA_ID, 6);
458 hdr.protoVer[0] = 0;
459 hdr.protoVer[1] = 8; // IA_PROTO_VER
460
461 unsigned char buffer[2048];
462 memcpy(buffer, &hdr, sizeof(hdr));
463 memcpy(buffer + sizeof(hdr), packet, length);
464
465 size_t offset = sizeof(HDR_8);
466 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
467     {
468     Blowfish_Encrypt(&ctx,
469                      (uint32_t *)(buffer + offset + i * 8),
470                      (uint32_t *)(buffer + offset + i * 8 + 4));
471     }
472
473 offset += IA_LOGIN_LEN;
474 size_t encLen = (length - IA_LOGIN_LEN) / 8;
475 for (size_t i = 0; i < encLen; i++)
476     {
477     Blowfish_Encrypt(user->GetCtx(),
478                      (uint32_t*)(buffer + offset + i * 8),
479                      (uint32_t*)(buffer + offset + i * 8 + 4));
480     }
481
482 int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
483
484 if (res < 0)
485     {
486     errorStr = "Failed to send packet: '";
487     errorStr += strerror(errno);
488     errorStr += "'";
489     printfd(__FILE__, "PROTO::SendPacket() - %s, fd: %d\n", errorStr.c_str(), user->GetSocket());
490     return false;
491     }
492
493 if (res < sizeof(buffer))
494     {
495     errorStr = "Packet sent partially";
496     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
497     return false;
498     }
499
500 return true;
501 }
502
503 bool PROTO::RealConnect(USER * user)
504 {
505 if (user->GetPhase() != 1 &&
506     user->GetPhase() != 5)
507     {
508     errorStr = "Unexpected connect";
509     printfd(__FILE__, "PROTO::RealConnect() - wrong phase: %d\n", user->GetPhase());
510     }
511 user->SetPhase(2);
512
513 return Send_CONN_SYN(user);
514 }
515
516 bool PROTO::RealDisconnect(USER * user)
517 {
518 if (user->GetPhase() != 3)
519     {
520     errorStr = "Unexpected disconnect";
521     printfd(__FILE__, "PROTO::RealDisconnect() - wrong phase: %d\n", user->GetPhase());
522     }
523 user->SetPhase(4);
524
525 return Send_DISCONN_SYN(user);
526 }