]> git.stg.codes - stg.git/blob - projects/sgauthstress/proto.cpp
Fix some types and add UserCount method to PROTO
[stg.git] / projects / sgauthstress / proto.cpp
1 #include <netdb.h>
2 #include <arpa/inet.h>
3
4 #include <cerrno>
5 #include <cstring>
6 #include <cassert>
7 #include <stdexcept>
8
9 #include "stg/common.h"
10 #include "stg/ia_packets.h"
11
12 #include "proto.h"
13
14 PROTO::PROTO(const std::string & server,
15              uint16_t port,
16              uint16_t localPort,
17              int to)
18     : timeout(to),
19       running(false),
20       stopped(true)
21 {
22 uint32_t ip = inet_addr(server.c_str());
23 if (ip == INADDR_NONE)
24     {
25     struct hostent * hePtr = gethostbyname(server.c_str());
26     if (hePtr)
27         {
28         ip = *((uint32_t *)hePtr->h_addr_list[0]);
29         }
30     else
31         {
32         errorStr = "Unknown host: '";
33         errorStr += server;
34         errorStr += "'";
35         printfd(__FILE__, "PROTO::PROTO() - %s\n", errorStr.c_str());
36         throw std::runtime_error(errorStr);
37         }
38     }
39
40 localAddr.sin_family = AF_INET;
41 localAddr.sin_port = htons(localPort);
42 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
43
44 serverAddr.sin_family = AF_INET;
45 serverAddr.sin_port = htons(port);
46 serverAddr.sin_addr.s_addr = ip;
47
48 unsigned char key[IA_PASSWD_LEN];
49 memset(key, 0, IA_PASSWD_LEN);
50 strncpy(reinterpret_cast<char *>(key), "pr7Hhen", 8);
51 Blowfish_Init(&ctx, key, IA_PASSWD_LEN);
52
53 processors["CONN_SYN_ACK"] = &PROTO::CONN_SYN_ACK_Proc;
54 processors["ALIVE_SYN"] = &PROTO::ALIVE_SYN_Proc;
55 processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
56 processors["FIN"] = &PROTO::FIN_Proc;
57 processors["INFO"] = &PROTO::INFO_Proc;
58 // ERR_Proc will be handled explicitly
59 }
60
61 PROTO::~PROTO()
62 {
63 }
64
65 void * PROTO::Runner(void * data)
66 {
67 PROTO * protoPtr = static_cast<PROTO *>(data);
68 protoPtr->Run();
69 return NULL;
70 }
71
72 bool PROTO::Start()
73 {
74 stopped = false;
75 running = true;
76 if (pthread_create(&tid, NULL, &Runner, NULL))
77     {
78     errorStr = "Failed to create listening thread: '";
79     errorStr += strerror(errno);
80     errorStr += "'";
81     printfd(__FILE__, "PROTO::Start() - %s\n", errorStr.c_str());
82     return false;
83     }
84 return true;
85 }
86
87 bool PROTO::Stop()
88 {
89 running = false;
90 int time = 0;
91 while (!stopped && time < timeout)
92     {
93     struct timespec ts = {1, 0};
94     nanosleep(&ts, NULL);
95     ++time;
96     }
97 if (!stopped)
98     {
99     errorStr = "Failed to stop listening thread - timed out";
100     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
101     return false;
102     }
103 if (pthread_join(tid, NULL))
104     {
105     errorStr = "Failed to join listening thread after stop: '";
106     errorStr += strerror(errno);
107     errorStr += "'";
108     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
109     return false;
110     }
111 return true;
112 }
113
114 void PROTO::AddUser(const USER & user)
115 {
116     users.push_back(std::make_pair(user.GetIP(), user));
117     struct pollfd pfd;
118     pfd.fd = user.GetSocket();
119     pfd.events = POLLIN;
120     pfd.revents = 0;
121     pollFds.push_back(pfd);
122 }
123
124 bool PROTO::Connect(uint32_t ip)
125 {
126 /*std::vector<std::pair<uint32_t, USER> >::const_iterator it;
127 it = users.find(ip);
128 if (it == users.end())
129     return false;*/
130
131 // Do something
132
133 return true;
134 }
135
136 bool PROTO::Disconnect(uint32_t ip)
137 {
138 /*std::vector<std::pair<uint32_t, USER> >::const_iterator it;
139 it = users.find(ip);
140 if (it == users.end())
141     return false;*/
142
143 // Do something
144
145 return true;
146 }
147
148 void PROTO::Run()
149 {
150 while (running)
151     {
152     int res = poll(&pollFds.front(), pollFds.size(), timeout);
153     if (res < 0)
154         break;
155     if (!running)
156         break;
157     if (res)
158         RecvPacket();
159     }
160
161 stopped = true;
162 }
163
164 bool PROTO::RecvPacket()
165 {
166 bool result = true;
167 std::vector<struct pollfd>::iterator it;
168 std::vector<std::pair<uint32_t, USER> >::iterator userIt;
169 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
170     {
171     if (it->revents)
172         {
173         it->revents = 0;
174         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
175         struct sockaddr_in addr;
176         socklen_t fromLen = sizeof(addr);
177         char buffer[2048];
178         int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &fromLen);
179
180         if (res == -1)
181             {
182             result = false;
183             continue;
184             }
185
186         result = result && HandlePacket(buffer, &(userIt->second));
187         }
188     }
189
190 return result;
191 }
192
193 bool PROTO::HandlePacket(const char * buffer, USER * user)
194 {
195 if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR"))
196     {
197     return ERR_Proc(buffer, user);
198     }
199
200 std::string packetName(buffer + 12);
201 std::map<std::string, PacketProcessor>::const_iterator it;
202 it = processors.find(packetName);
203 if (it != processors.end())
204     return (this->*it->second)(buffer, user);
205
206 printfd(__FILE__, "PROTO::HandlePacket() - invalid packet signature: '%s'\n", packetName.c_str());
207
208 return false;
209 }
210
211 bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
212 {
213 const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
214
215 uint32_t rnd = packet->rnd;
216 uint32_t userTimeout = packet->userTimeOut;
217 uint32_t aliveTimeout = packet->aliveDelay;
218
219 #ifdef ARCH_BE
220 SwapBytes(rnd);
221 SwapBytes(userTimeout);
222 SwapBytes(aliveDelay);
223 #endif
224
225 Send_CONN_ACK(user);
226
227 if (user->GetPhase() != 2)
228     {
229     errorStr = "Unexpected CONN_SYN_ACK";
230     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
231     }
232
233 user->SetPhase(3);
234 user->SetAliveTimeout(aliveTimeout);
235 user->SetUserTimeout(userTimeout);
236 user->SetRnd(rnd);
237
238 return true;
239 }
240
241 bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
242 {
243 const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
244
245 uint32_t rnd = packet->rnd;
246
247 #ifdef ARCH_BE
248 SwapBytes(rnd);
249 #endif
250
251 if (user->GetPhase() != 3)
252     {
253     errorStr = "Unexpected ALIVE_SYN";
254     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
255     }
256
257 if (user->GetRnd() + 1 != rnd)
258     {
259     errorStr = "Wrong control value at ALIVE_SYN";
260     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
261     }
262
263 user->SetPhase(3);
264 user->SetRnd(rnd);
265
266 Send_ALIVE_ACK(user);
267
268 return true;
269 }
270
271 bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
272 {
273 const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
274
275 uint32_t rnd = packet->rnd;
276
277 #ifdef ARCH_BE
278 SwapBytes(rnd);
279 #endif
280
281 if (user->GetPhase() != 4)
282     {
283     errorStr = "Unexpected DISCONN_SYN_ACK";
284     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
285     }
286
287 if (user->GetRnd() + 1 != rnd)
288     {
289     errorStr = "Wrong control value at DISCONN_SYN_ACK";
290     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
291     }
292
293 user->SetPhase(5);
294 user->SetRnd(rnd);
295
296 Send_DISCONN_ACK(user);
297
298 return true;
299 }
300
301 bool PROTO::FIN_Proc(const void * buffer, USER * user)
302 {
303 if (user->GetPhase() != 5)
304     {
305     errorStr = "Unexpected FIN";
306     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
307     }
308
309 user->SetPhase(1);
310
311 return true;
312 }
313
314 bool PROTO::INFO_Proc(const void * buffer, USER * user)
315 {
316 //const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
317
318 return true;
319 }
320
321 bool PROTO::ERR_Proc(const void * buffer, USER * user)
322 {
323 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
324 const char * ptr = static_cast<const char *>(buffer);
325
326 for (size_t i = 0; i < sizeof(ERR_8) / 8; i++)
327     Blowfish_Decrypt(user->GetCtx(), (uint32_t *)(ptr + i * 8), (uint32_t *)(ptr + i * 8 + 4));
328
329 //uint32_t len = packet->len;
330
331 #ifdef ARCH_BE
332 //SwapBytes(len);
333 #endif
334
335 user->SetPhase(1); //TODO: Check
336 /*KOIToWin((const char*)err.text, &messageText);
337 if (pErrorCb != NULL)
338     pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
339 phaseTime = GetTickCount();
340 codeError = IA_SERVER_ERROR;*/
341
342 return true;
343 }
344
345 bool PROTO::Send_CONN_SYN(USER * user)
346 {
347 CONN_SYN_8 packet;
348
349 packet.len = sizeof(packet);
350
351 #ifdef ARCH_BE
352 SwapBytes(packet.len);
353 #endif
354
355 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
356 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
357 packet.dirs = 0xFFffFFff;
358
359 return SendPacket(&packet, sizeof(packet), user);
360 }
361
362 bool PROTO::Send_CONN_ACK(USER * user)
363 {
364 CONN_ACK_8 packet;
365
366 packet.len = sizeof(packet);
367 packet.rnd = user->IncRnd();
368
369 #ifdef ARCH_BE
370 SwapBytes(packet.len);
371 SwapBytes(packet.rnd);
372 #endif
373
374 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
375 strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
376
377 return SendPacket(&packet, sizeof(packet), user);
378 }
379
380 bool PROTO::Send_ALIVE_ACK(USER * user)
381 {
382 ALIVE_ACK_8 packet;
383
384 packet.len = sizeof(packet);
385 packet.rnd = user->IncRnd();
386
387 #ifdef ARCH_BE
388 SwapBytes(packet.len);
389 SwapBytes(packet.rnd);
390 #endif
391
392 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
393 strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
394
395 return SendPacket(&packet, sizeof(packet), user);
396 }
397
398 bool PROTO::Send_DISCONN_SYN(USER * user)
399 {
400 DISCONN_SYN_8 packet;
401
402 packet.len = sizeof(packet);
403
404 #ifdef ARCH_BE
405 SwapBytes(packet.len);
406 #endif
407
408 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
409 strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
410 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
411
412 return SendPacket(&packet, sizeof(packet), user);
413 }
414
415 bool PROTO::Send_DISCONN_ACK(USER * user)
416 {
417 DISCONN_ACK_8 packet;
418
419 packet.len = sizeof(packet);
420 packet.rnd = user->IncRnd();
421
422 #ifdef ARCH_BE
423 SwapBytes(packet.len);
424 SwapBytes(packet.rnd);
425 #endif
426
427 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
428 strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
429
430 return SendPacket(&packet, sizeof(packet), user);
431 }
432
433 bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
434 {
435 HDR_8 hdr;
436
437 assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes");
438
439 strncpy((char *)hdr.magic, IA_ID, 6);
440 hdr.protoVer[0] = 0;
441 hdr.protoVer[1] = 8; // IA_PROTO_VER
442
443 unsigned char buffer[2048];
444 memcpy(buffer, &hdr, sizeof(hdr));
445 memcpy(buffer + sizeof(hdr), packet, length);
446
447 size_t offset = sizeof(HDR_8);
448 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
449     {
450     Blowfish_Encrypt(&ctx,
451                      (uint32_t *)(buffer + offset + i * 8),
452                      (uint32_t *)(buffer + offset + i * 8 + 4));
453     }
454
455 offset += IA_LOGIN_LEN;
456 size_t encLen = (length - IA_LOGIN_LEN) / 8;
457 for (size_t i = 0; i < encLen; i++)
458     {
459     Blowfish_Encrypt(user->GetCtx(),
460                      (uint32_t*)(buffer + offset + i * 8),
461                      (uint32_t*)(buffer + offset + i * 8 + 4));
462     }
463
464 int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
465
466 if (res < 0)
467     {
468     errorStr = "Failed to send packet: '";
469     errorStr += strerror(errno);
470     errorStr += "'";
471     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
472     return false;
473     }
474
475 if (res < sizeof(buffer))
476     {
477     errorStr = "Packet sent partially";
478     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
479     return false;
480     }
481
482 return true;
483 }