]> git.stg.codes - stg.git/blob - projects/sgauthstress/proto.cpp
A map of ip-to-user replaced with vector of pairs
[stg.git] / projects / sgauthstress / proto.cpp
1 #include <netdb.h>
2 #include <arpa/inet.h>
3
4 #include <cerrno>
5 #include <cstring>
6 #include <cassert>
7 #include <stdexcept>
8
9 #include "stg/common.h"
10 #include "stg/ia_packets.h"
11
12 #include "proto.h"
13
14 PROTO::PROTO(const std::string & server,
15              uint16_t port,
16              uint16_t localPort,
17              int to)
18     : running(false),
19       stopped(true),
20       timeout(to)
21 {
22 uint32_t ip = inet_addr(server.c_str());
23 if (ip == INADDR_NONE)
24     {
25     struct hostent * hePtr = gethostbyname(server.c_str());
26     if (hePtr)
27         {
28         ip = *((uint32_t *)hePtr->h_addr_list[0]);
29         }
30     else
31         {
32         errorStr = "Unknown host: '";
33         errorStr += server;
34         errorStr += "'";
35         printfd(__FILE__, "PROTO::PROTO() - %s\n", errorStr.c_str());
36         throw std::runtime_error(errorStr);
37         }
38     }
39
40 localAddr.sin_family = AF_INET;
41 localAddr.sin_port = htons(localPort);
42 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
43
44 serverAddr.sin_family = AF_INET;
45 serverAddr.sin_port = htons(port);
46 serverAddr.sin_addr.s_addr = ip;
47
48 unsigned char key[IA_PASSWD_LEN];
49 memset(key, 0, IA_PASSWD_LEN);
50 strncpy(reinterpret_cast<char *>(key), "pr7Hhen", 8);
51 Blowfish_Init(&ctx, key, IA_PASSWD_LEN);
52
53 processors["CONN_SYN_ACK"] = &PROTO::CONN_SYN_ACK_Proc;
54 processors["ALIVE_SYN"] = &PROTO::ALIVE_SYN_Proc;
55 processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
56 processors["FIN"] = &PROTO::FIN_Proc;
57 processors["INFO"] = &PROTO::INFO_Proc;
58 // ERR_Proc will be handled explicitly
59 }
60
61 PROTO::~PROTO()
62 {
63 }
64
65 void * PROTO::Runner(void * data)
66 {
67 PROTO * protoPtr = static_cast<PROTO *>(data);
68 protoPtr->Run();
69 }
70
71 bool PROTO::Start()
72 {
73 stopped = false;
74 running = true;
75 if (pthread_create(&tid, NULL, &Runner, NULL))
76     {
77     errorStr = "Failed to create listening thread: '";
78     errorStr += strerror(errno);
79     errorStr += "'";
80     printfd(__FILE__, "PROTO::Start() - %s\n", errorStr.c_str());
81     return false;
82     }
83 return true;
84 }
85
86 bool PROTO::Stop()
87 {
88 running = false;
89 int time = 0;
90 while (!stopped && time < timeout)
91     {
92     struct timespec ts = {1, 0};
93     nanosleep(&ts, NULL);
94     ++time;
95     }
96 if (!stopped)
97     {
98     errorStr = "Failed to stop listening thread - timed out";
99     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
100     return false;
101     }
102 if (pthread_join(tid, NULL))
103     {
104     errorStr = "Failed to join listening thread after stop: '";
105     errorStr += strerror(errno);
106     errorStr += "'";
107     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
108     return false;
109     }
110 return true;
111 }
112
113 void PROTO::AddUser(const USER & user)
114 {
115     users.push_back(std::make_pair(user.GetIP(), user));
116     struct pollfd pfd;
117     pfd.fd = user.GetSocket();
118     pfd.events = POLLIN;
119     pfd.revents = 0;
120     pollFds.push_back(pfd);
121 }
122
123 bool PROTO::Connect(uint32_t ip)
124 {
125 /*std::vector<std::pair<uint32_t, USER> >::const_iterator it;
126 it = users.find(ip);
127 if (it == users.end())
128     return false;*/
129
130 // Do something
131
132 return true;
133 }
134
135 bool PROTO::Disconnect(uint32_t ip)
136 {
137 /*std::vector<std::pair<uint32_t, USER> >::const_iterator it;
138 it = users.find(ip);
139 if (it == users.end())
140     return false;*/
141
142 // Do something
143
144 return true;
145 }
146
147 void PROTO::Run()
148 {
149 while (running)
150     {
151     int res = poll(&pollFds.front(), pollFds.size(), timeout);
152     if (res < 0)
153         break;
154     if (!running)
155         break;
156     if (res)
157         RecvPacket();
158     }
159
160 stopped = true;
161 }
162
163 bool PROTO::RecvPacket()
164 {
165 bool result = true;
166 std::vector<struct pollfd>::iterator it;
167 std::vector<std::pair<uint32_t, USER> >::iterator userIt;
168 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
169     {
170     if (it->revents)
171         {
172         it->revents = 0;
173         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
174         struct sockaddr_in addr;
175         socklen_t fromLen = sizeof(addr);
176         char buffer[2048];
177         int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&addr, &fromLen);
178
179         if (res == -1)
180             {
181             result = false;
182             continue;
183             }
184
185         result = result && HandlePacket(buffer, &(userIt->second));
186         }
187     }
188
189 return result;
190 }
191
192 bool PROTO::HandlePacket(const char * buffer, USER * user)
193 {
194 if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR"))
195     {
196     return ERR_Proc(buffer, user);
197     }
198
199 std::string packetName(buffer + 12);
200 std::map<std::string, PacketProcessor>::const_iterator it;
201 it = processors.find(packetName);
202 if (it != processors.end())
203     return (this->*it->second)(buffer, user);
204
205 return false;
206 }
207
208 bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
209 {
210 const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
211
212 uint32_t rnd = packet->rnd;
213 uint32_t userTimeout = packet->userTimeOut;
214 uint32_t aliveTimeout = packet->aliveDelay;
215
216 #ifdef ARCH_BE
217 SwapBytes(rnd);
218 SwapBytes(userTimeout);
219 SwapBytes(aliveDelay);
220 #endif
221
222 Send_CONN_ACK(user);
223
224 if (user->GetPhase() != 2)
225     {
226     errorStr = "Unexpected CONN_SYN_ACK";
227     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
228     }
229
230 user->SetPhase(3);
231 user->SetAliveTimeout(aliveTimeout);
232 user->SetUserTimeout(userTimeout);
233 user->SetRnd(rnd);
234
235 return true;
236 }
237
238 bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
239 {
240 const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
241
242 uint32_t rnd = packet->rnd;
243
244 #ifdef ARCH_BE
245 SwapBytes(rnd);
246 #endif
247
248 if (user->GetPhase() != 3)
249     {
250     errorStr = "Unexpected ALIVE_SYN";
251     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
252     }
253
254 if (user->GetRnd() + 1 != rnd)
255     {
256     errorStr = "Wrong control value at ALIVE_SYN";
257     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
258     }
259
260 user->SetPhase(3);
261 user->SetRnd(rnd);
262
263 Send_ALIVE_ACK(user);
264
265 return true;
266 }
267
268 bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
269 {
270 const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
271
272 uint32_t rnd = packet->rnd;
273
274 #ifdef ARCH_BE
275 SwapBytes(rnd);
276 #endif
277
278 if (user->GetPhase() != 4)
279     {
280     errorStr = "Unexpected DISCONN_SYN_ACK";
281     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
282     }
283
284 if (user->GetRnd() + 1 != rnd)
285     {
286     errorStr = "Wrong control value at DISCONN_SYN_ACK";
287     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
288     }
289
290 user->SetPhase(5);
291 user->SetRnd(rnd);
292
293 Send_DISCONN_ACK(user);
294
295 return true;
296 }
297
298 bool PROTO::FIN_Proc(const void * buffer, USER * user)
299 {
300 if (user->GetPhase() != 5)
301     {
302     errorStr = "Unexpected FIN";
303     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
304     }
305
306 user->SetPhase(1);
307
308 return true;
309 }
310
311 bool PROTO::INFO_Proc(const void * buffer, USER * user)
312 {
313 //const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
314
315 return true;
316 }
317
318 bool PROTO::ERR_Proc(const void * buffer, USER * user)
319 {
320 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
321 const char * ptr = static_cast<const char *>(buffer);
322
323 for (int i = 0; i < sizeof(packet) / 8; i++)
324     Blowfish_Decrypt(user->GetCtx(), (uint32_t *)(ptr + i * 8), (uint32_t *)(ptr + i * 8 + 4));
325
326 //uint32_t len = packet->len;
327
328 #ifdef ARCH_BE
329 //SwapBytes(len);
330 #endif
331
332 user->SetPhase(1); //TODO: Check
333 /*KOIToWin((const char*)err.text, &messageText);
334 if (pErrorCb != NULL)
335     pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
336 phaseTime = GetTickCount();
337 codeError = IA_SERVER_ERROR;*/
338
339 return true;
340 }
341
342 bool PROTO::Send_CONN_SYN(USER * user)
343 {
344 CONN_SYN_8 packet;
345
346 packet.len = sizeof(packet);
347
348 #ifdef ARCH_BE
349 SwapBytes(packet.len);
350 #endif
351
352 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
353 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
354 packet.dirs = 0xFFffFFff;
355
356 return SendPacket(&packet, sizeof(packet), user);
357 }
358
359 bool PROTO::Send_CONN_ACK(USER * user)
360 {
361 CONN_ACK_8 packet;
362
363 packet.len = sizeof(packet);
364 packet.rnd = user->IncRnd();
365
366 #ifdef ARCH_BE
367 SwapBytes(packet.len);
368 SwapBytes(packet.rnd);
369 #endif
370
371 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
372 strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
373
374 return SendPacket(&packet, sizeof(packet), user);
375 }
376
377 bool PROTO::Send_ALIVE_ACK(USER * user)
378 {
379 ALIVE_ACK_8 packet;
380
381 packet.len = sizeof(packet);
382 packet.rnd = user->IncRnd();
383
384 #ifdef ARCH_BE
385 SwapBytes(packet.len);
386 SwapBytes(packet.rnd);
387 #endif
388
389 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
390 strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
391
392 return SendPacket(&packet, sizeof(packet), user);
393 }
394
395 bool PROTO::Send_DISCONN_SYN(USER * user)
396 {
397 DISCONN_SYN_8 packet;
398
399 packet.len = sizeof(packet);
400
401 #ifdef ARCH_BE
402 SwapBytes(packet.len);
403 #endif
404
405 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
406 strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
407 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
408
409 return SendPacket(&packet, sizeof(packet), user);
410 }
411
412 bool PROTO::Send_DISCONN_ACK(USER * user)
413 {
414 DISCONN_ACK_8 packet;
415
416 packet.len = sizeof(packet);
417 packet.rnd = user->IncRnd();
418
419 #ifdef ARCH_BE
420 SwapBytes(packet.len);
421 SwapBytes(packet.rnd);
422 #endif
423
424 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
425 strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
426
427 return SendPacket(&packet, sizeof(packet), user);
428 }
429
430 bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
431 {
432 HDR_8 hdr;
433
434 assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes");
435
436 strncpy((char *)hdr.magic, IA_ID, 6);
437 hdr.protoVer[0] = 0;
438 hdr.protoVer[1] = 8; // IA_PROTO_VER
439
440 unsigned char buffer[2048];
441 memcpy(buffer, &hdr, sizeof(hdr));
442 memcpy(buffer + sizeof(hdr), packet, length);
443
444 size_t offset = sizeof(HDR_8);
445 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
446     {
447     Blowfish_Encrypt(&ctx,
448                      (uint32_t *)(buffer + offset + i * 8),
449                      (uint32_t *)(buffer + offset + i * 8 + 4));
450     }
451
452 offset += IA_LOGIN_LEN;
453 size_t encLen = (length - IA_LOGIN_LEN) / 8;
454 for (size_t i = 0; i < encLen; i++)
455     {
456     Blowfish_Encrypt(user->GetCtx(),
457                      (uint32_t*)(buffer + offset + i * 8),
458                      (uint32_t*)(buffer + offset + i * 8 + 4));
459     }
460
461 int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
462
463 if (res < 0)
464     {
465     errorStr = "Failed to send packet: '";
466     errorStr += strerror(errno);
467     errorStr += "'";
468     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
469     return false;
470     }
471
472 if (res < sizeof(buffer))
473     {
474     errorStr = "Packet sent partially";
475     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
476     return false;
477     }
478
479 return true;
480 }