]> git.stg.codes - stg.git/blob - projects/sgauthstress/proto.cpp
7c547daf8462caeb086d4e38cd54241560a5293a
[stg.git] / projects / sgauthstress / proto.cpp
1 #include <netdb.h>
2 #include <arpa/inet.h>
3
4 #include <cerrno>
5 #include <cstring>
6 #include <cassert>
7 #include <stdexcept>
8
9 #include "stg/common.h"
10
11 #include "proto.h"
12
13 PROTO::PROTO(const std::string & server,
14              uint16_t port,
15              uint16_t localPort,
16              int to)
17     : running(false),
18       stopped(true),
19       timeout(to)
20 {
21 uint32_t ip = inet_addr(server.c_str());
22 if (ip == INADDR_NONE)
23     {
24     struct hostent * hePtr = gethostbyname(server.c_str());
25     if (hePtr)
26         {
27         ip = *((uint32_t *)hePtr->h_addr_list[0]);
28         }
29     else
30         {
31         errorStr = "Unknown host: '";
32         errorStr += server;
33         errorStr += "'";
34         printfd(__FILE__, "PROTO::PROTO() - %s\n", errorStr.c_str());
35         throw std::runtime_error(errorStr);
36         }
37     }
38
39 localAddr.sin_family = AF_INET;
40 localAddr.sin_port = htons(localPort);
41 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
42
43 serverAddr.sin_family = AF_INET;
44 serverAddr.sin_port = htons(port);
45 serverAddr.sin_addr.s_addr = ip;
46
47 unsigned char key[IA_PASSWD_LEN];
48 memset(key, 0, IA_PASSWD_LEN);
49 strncpy(reinterpret_cast<char *>(key), "pr7Hhen", 8);
50 Blowfish_Init(&ctx, key, IA_PASSWD_LEN);
51
52 processors["CONN_SYN_ACK"] = &PROTO::CONN_SYN_ACK_Proc;
53 processors["ALIVE_SYN"] = &PROTO::ALIVE_SYN_Proc;
54 processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
55 processors["FIN"] = &PROTO::FIN_Proc;
56 processors["INFO"] = &PROTO::INFO_Proc;
57 // ERR_Proc will be handled explicitly
58 }
59
60 PROTO::~PROTO()
61 {
62 }
63
64 void * PROTO::Runner(void * data)
65 {
66 PROTO * protoPtr = static_cast<PROTO *>(data);
67 protoPtr->Run();
68 }
69
70 bool PROTO::Start()
71 {
72 stopped = false;
73 running = true;
74 if (pthread_create(&tid, NULL, &Runner, NULL))
75     {
76     errorStr = "Failed to create listening thread: '";
77     errorStr += strerror(errno);
78     errorStr += "'";
79     printfd(__FILE__, "PROTO::Start() - %s\n", errorStr.c_str());
80     return false;
81     }
82 return true;
83 }
84
85 bool PROTO::Stop()
86 {
87 running = false;
88 int time = 0;
89 while (!stopped && time < timeout)
90     {
91     struct timespec ts = {1, 0};
92     nanosleep(&ts, NULL);
93     ++time;
94     }
95 if (!stopped)
96     {
97     errorStr = "Failed to stop listening thread - timed out";
98     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
99     return false;
100     }
101 if (pthread_join(tid, NULL))
102     {
103     errorStr = "Failed to join listening thread after stop: '";
104     errorStr += strerror(errno);
105     errorStr += "'";
106     printfd(__FILE__, "PROTO::Stop() - %s\n", errorStr.c_str());
107     return false;
108     }
109 return true;
110 }
111
112 void PROTO::AddUser(const USER & user)
113 {
114     users.insert(std::make_pair(user.GetIP(), user));
115     struct pollfd pfd;
116     pfd.fd = user.GetSocket();
117     pfd.events = POLLIN;
118     pfd.revents = 0;
119     pollFds.push_back(pfd);
120 }
121
122 bool PROTO::Connect(uint32_t ip)
123 {
124 std::map<uint32_t, USER>::const_iterator it;
125 it = users.find(ip);
126 if (it == users.end())
127     return false;
128
129 // Do something
130
131 return true;
132 }
133
134 bool PROTO::Disconnect(uint32_t ip)
135 {
136 std::map<uint32_t, USER>::const_iterator it;
137 it = users.find(ip);
138 if (it == users.end())
139     return false;
140
141 // Do something
142
143 return true;
144 }
145
146 void PROTO::Run()
147 {
148 while (running)
149     {
150     int res = poll(&pollFds.front(), pollFds.size(), timeout);
151     if (res < 0)
152         break;
153     if (!running)
154         break;
155     if (res)
156         RecvPacket();
157     }
158
159 stopped = true;
160 }
161
162 bool PROTO::RecvPacket()
163 {
164 bool result = true;
165 std::vector<struct pollfd>::iterator it;
166 std::map<uint32_t, USER>::iterator userIt(users.begin());
167 for (it = pollFds.begin(); it != pollFds.end(); ++it)
168     {
169     if (it->revents)
170         {
171         it->revents = 0;
172         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
173         struct sockaddr_in addr;
174         socklen_t fromLen = sizeof(addr);
175         char buffer[2048];
176         int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr*)&addr, &fromLen);
177
178         if (res == -1)
179             {
180             result = false;
181             ++userIt;
182             continue;
183             }
184
185         result = result && HandlePacket(buffer, &(userIt->second));
186         }
187     ++userIt;
188     }
189
190 return result;
191 }
192
193 bool PROTO::HandlePacket(const char * buffer, USER * user)
194 {
195 if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR"))
196     {
197     return ERR_Proc(buffer, user);
198     }
199
200 std::string packetName(buffer + 12);
201 std::map<std::string, PacketProcessor>::const_iterator it;
202 it = processors.find(packetName);
203 if (it != processors.end())
204     return (this->*it->second)(buffer, user);
205
206 return false;
207 }
208
209 bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
210 {
211 const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
212
213 uint32_t rnd = packet->rnd;
214 uint32_t userTimeout = packet->userTimeOut;
215 uint32_t aliveTimeout = packet->aliveDelay;
216
217 #ifdef ARCH_BE
218 SwapBytes(rnd);
219 SwapBytes(userTimeout);
220 SwapBytes(aliveDelay);
221 #endif
222
223 Send_CONN_ACK(user);
224
225 if (user->GetPhase() != 2)
226     {
227     errorStr = "Unexpected CONN_SYN_ACK";
228     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
229     }
230
231 user->SetPhase(3);
232 user->SetAliveTimeout(aliveTimeout);
233 user->SetUserTimeout(userTimeout);
234 user->SetRnd(rnd);
235
236 return true;
237 }
238
239 bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
240 {
241 const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
242
243 uint32_t rnd = packet->rnd;
244
245 #ifdef ARCH_BE
246 SwapBytes(rnd);
247 #endif
248
249 if (user->GetPhase() != 3)
250     {
251     errorStr = "Unexpected ALIVE_SYN";
252     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
253     }
254
255 if (user->GetRnd() + 1 != rnd)
256     {
257     errorStr = "Wrong control value at ALIVE_SYN";
258     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
259     }
260
261 user->SetPhase(3);
262 user->SetRnd(rnd);
263
264 Send_ALIVE_ACK(user);
265
266 return true;
267 }
268
269 bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
270 {
271 const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
272
273 uint32_t rnd = packet->rnd;
274
275 #ifdef ARCH_BE
276 SwapBytes(rnd);
277 #endif
278
279 if (user->GetPhase() != 4)
280     {
281     errorStr = "Unexpected DISCONN_SYN_ACK";
282     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
283     }
284
285 if (user->GetRnd() + 1 != rnd)
286     {
287     errorStr = "Wrong control value at DISCONN_SYN_ACK";
288     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
289     }
290
291 user->SetPhase(5);
292 user->SetRnd(rnd);
293
294 Send_DISCONN_ACK(user);
295
296 return true;
297 }
298
299 bool PROTO::FIN_Proc(const void * buffer, USER * user)
300 {
301 if (user->GetPhase() != 5)
302     {
303     errorStr = "Unexpected FIN";
304     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
305     }
306
307 user->SetPhase(1);
308
309 return true;
310 }
311
312 bool PROTO::INFO_Proc(const void * buffer, USER * user)
313 {
314 //const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
315
316 return true;
317 }
318
319 bool PROTO::ERR_Proc(const void * buffer, USER * user)
320 {
321 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
322
323 for (int i = 0; i < len/8; i++)
324     Blowfish_Decrypt(&ctxPass, (uint32_t*)(buffer + i*8), (uint32_t*)(buffer + i*8 + 4));
325
326 //uint32_t len = packet->len;
327
328 #ifdef ARCH_BE
329 //SwapBytes(len);
330 #endif
331
332 user->SetPhase(1); //TODO: Check
333 /*KOIToWin((const char*)err.text, &messageText);
334 if (pErrorCb != NULL)
335     pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
336 phaseTime = GetTickCount();
337 codeError = IA_SERVER_ERROR;*/
338
339 return true;
340 }
341
342 bool PROTO::Send_CONN_SYN(USER * user)
343 {
344 CONN_SYN_8 packet;
345
346 packet.len = sizeof(packet);
347
348 #ifdef ARCH_BE
349 SwapBytes(packet.len);
350 #endif
351
352 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
353 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
354 packet.dirs = 0xFFffFFff;
355
356 return SendPacket(&packet, sizeof(packet), user);
357 }
358
359 bool PROTO::Send_CONN_ACK(USER * user)
360 {
361 CONN_ACK_8 packet;
362
363 packet.len = sizeof(packet);
364 packet.rnd = user->IncRnd();
365
366 #ifdef ARCH_BE
367 SwapBytes(packet.len);
368 SwapBytes(packet.rnd);
369 #endif
370
371 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
372 strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
373
374 return SendPacket(&packet, sizeof(packet), user);
375 }
376
377 bool PROTO::Send_ALIVE_ACK(USER * user)
378 {
379 ALIVE_ACK_8 packet;
380
381 packet.len = sizeof(packet);
382 packet.rnd = user->IncRnd();
383
384 #ifdef ARCH_BE
385 SwapBytes(packet.len);
386 SwapBytes(packet.rnd);
387 #endif
388
389 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
390 strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
391
392 return SendPacket(&packet, sizeof(packet), user);
393 }
394
395 bool PROTO::Send_DISCONN_SYN(USER * user)
396 {
397 DISCONN_SYN_8 packet;
398
399 packet.len = sizeof(packet);
400
401 #ifdef ARCH_BE
402 SwapBytes(packet.len);
403 #endif
404
405 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
406 strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
407 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
408
409 return SendPacket(&packet, sizeof(packet), user);
410 }
411
412 bool PROTO::Send_DISCONN_ACK(USER * user)
413 {
414 DISCONN_ACK_8 packet;
415
416 packet.len = sizeof(packet);
417 packet.rnd = user->IncRnd();
418
419 #ifdef ARCH_BE
420 SwapBytes(packet.len);
421 SwapBytes(packet.rnd);
422 #endif
423
424 strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
425 strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
426
427 return SendPacket(&packet, sizeof(packet), user);
428 }
429
430 bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
431 {
432 HDR_8 hdr;
433
434 assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes");
435
436 strncpy((char *)hdr.magic, IA_ID, 6);
437 hdr.protoVer[0] = 0;
438 hdr.protoVer[1] = IA_PROTO_VER;
439
440 unsigned char buffer[2048];
441 memcpy(buffer, &hdr, sizeof(hdr));
442 memcpy(buffer + sizeof(hdr), packet, length);
443
444 size_t offset = sizeof(HDR_8);
445 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
446     {
447     Blowfish_Encrypt(&ctx,
448                      (uint32_t *)(buffer + offset + i * 8),
449                      (uint32_t *)(buffer + offset + i * 8 + 4));
450     }
451
452 offset += IA_LOGIN_LEN;
453 size_t encLen = (length - IA_LOGIN_LEN) / 8;
454 for (size_t i = 0; i < encLen; i++)
455     {
456     Blowfish_Encrypt(user->GetCtx(),
457                      (uint32_t*)(buffer + offset + i * 8),
458                      (uint32_t*)(buffer + offset + i * 8 + 4));
459     }
460
461 int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
462
463 if (res < 0)
464     {
465     errorStr = "Failed to send packet: '";
466     errorStr += strerror(errno);
467     errorStr += "'";
468     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
469     return false;
470     }
471
472 if (res < sizeof(buffer))
473     {
474     errorStr = "Packet sent partially";
475     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
476     return false;
477     }
478
479 return true;
480 }