X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/7897474c8a3fb975cc0bcda45e40c47d58959ea6..refs/heads/ticket52:/projects/sgauthstress/proto.cpp?ds=inline diff --git a/projects/sgauthstress/proto.cpp b/projects/sgauthstress/proto.cpp index b89681f6..3150002d 100644 --- a/projects/sgauthstress/proto.cpp +++ b/projects/sgauthstress/proto.cpp @@ -1,16 +1,30 @@ +#include +#include +#include #include #include +#include #include #include #include #include +#include #include "stg/common.h" #include "stg/ia_packets.h" +#include "stg/locker.h" #include "proto.h" +class HasIP : public std::unary_function, bool> { + public: + explicit HasIP(uint32_t i) : ip(i) {} + bool operator()(const std::pair & value) { return value.first == ip; } + private: + uint32_t ip; +}; + PROTO::PROTO(const std::string & server, uint16_t port, uint16_t localPort, @@ -56,14 +70,21 @@ processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc; processors["FIN"] = &PROTO::FIN_Proc; processors["INFO"] = &PROTO::INFO_Proc; // ERR_Proc will be handled explicitly + +pthread_mutex_init(&mutex, NULL); } PROTO::~PROTO() { +pthread_mutex_destroy(&mutex); } void * PROTO::Runner(void * data) { +sigset_t signalSet; +sigfillset(&signalSet); +pthread_sigmask(SIG_BLOCK, &signalSet, NULL); + PROTO * protoPtr = static_cast(data); protoPtr->Run(); return NULL; @@ -73,7 +94,7 @@ bool PROTO::Start() { stopped = false; running = true; -if (pthread_create(&tid, NULL, &Runner, NULL)) +if (pthread_create(&tid, NULL, &Runner, this)) { errorStr = "Failed to create listening thread: '"; errorStr += strerror(errno); @@ -111,61 +132,105 @@ if (pthread_join(tid, NULL)) return true; } -void PROTO::AddUser(const USER & user) +void PROTO::AddUser(const USER & user, bool connect) { - users.push_back(std::make_pair(user.GetIP(), user)); - struct pollfd pfd; - pfd.fd = user.GetSocket(); - pfd.events = POLLIN; - pfd.revents = 0; - pollFds.push_back(pfd); +STG_LOCKER lock(&mutex); +users.push_back(std::make_pair(user.GetIP(), user)); +users.back().second.InitNetwork(); + +struct pollfd pfd; +pfd.fd = users.back().second.GetSocket(); +pfd.events = POLLIN; +pfd.revents = 0; +pollFds.push_back(pfd); + +if (connect) + { + RealConnect(&users.back().second); + } } bool PROTO::Connect(uint32_t ip) { -/*std::vector >::const_iterator it; -it = users.find(ip); +std::list >::iterator it; +STG_LOCKER lock(&mutex); +it = std::find_if(users.begin(), users.end(), HasIP(ip)); if (it == users.end()) - return false;*/ + return false; // Do something -return true; +return RealConnect(&it->second); } bool PROTO::Disconnect(uint32_t ip) { -/*std::vector >::const_iterator it; -it = users.find(ip); +std::list >::iterator it; +STG_LOCKER lock(&mutex); +it = std::find_if(users.begin(), users.end(), HasIP(ip)); if (it == users.end()) - return false;*/ + return false; // Do something -return true; +return RealDisconnect(&it->second); } void PROTO::Run() { while (running) { - int res = poll(&pollFds.front(), pollFds.size(), timeout); + int res; + { + STG_LOCKER lock(&mutex); + res = poll(&pollFds.front(), pollFds.size(), timeout); + } if (res < 0) break; if (!running) break; if (res) + { + printfd(__FILE__, "PROTO::Run() - events: %d\n", res); RecvPacket(); + } + else + { + CheckTimeouts(); + } } stopped = true; } +void PROTO::CheckTimeouts() +{ +STG_LOCKER lock(&mutex); +std::list >::iterator it; +for (it = users.begin(); it != users.end(); ++it) + { + int delta = difftime(time(NULL), it->second.GetPhaseChangeTime()); + if ((it->second.GetPhase() == 3) && + (delta > it->second.GetUserTimeout())) + { + printfd(__FILE__, "PROTO::CheckTimeouts() - user alive timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetUserTimeout()); + it->second.SetPhase(1); + } + if ((it->second.GetPhase() == 2) && + (delta > it->second.GetAliveTimeout())) + { + printfd(__FILE__, "PROTO::CheckTimeouts() - user connect timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetAliveTimeout()); + it->second.SetPhase(1); + } + } +} + bool PROTO::RecvPacket() { bool result = true; std::vector::iterator it; -std::vector >::iterator userIt; +std::list >::iterator userIt; +STG_LOCKER lock(&mutex); for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt) { if (it->revents) @@ -183,21 +248,27 @@ for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt continue; } - result = result && HandlePacket(buffer, &(userIt->second)); + result = result && HandlePacket(buffer, res, &(userIt->second)); } } return result; } -bool PROTO::HandlePacket(const char * buffer, USER * user) +bool PROTO::HandlePacket(const char * buffer, size_t length, USER * user) { -if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR")) +if (!strncmp(buffer + 4 + sizeof(HDR_8), "ERR", 3)) { return ERR_Proc(buffer, user); } +for (size_t i = 0; i < length / 8; i++) + Blowfish_Decrypt(user->GetCtx(), + (uint32_t *)(buffer + i * 8), + (uint32_t *)(buffer + i * 8 + 4)); + std::string packetName(buffer + 12); + std::map::const_iterator it; it = processors.find(packetName); if (it != processors.end()) @@ -222,12 +293,11 @@ SwapBytes(userTimeout); SwapBytes(aliveDelay); #endif -Send_CONN_ACK(user); - if (user->GetPhase() != 2) { errorStr = "Unexpected CONN_SYN_ACK"; printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase()); + return false; } user->SetPhase(3); @@ -235,6 +305,10 @@ user->SetAliveTimeout(aliveTimeout); user->SetUserTimeout(userTimeout); user->SetRnd(rnd); +Send_CONN_ACK(user); + +printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str()); + return true; } @@ -252,16 +326,11 @@ if (user->GetPhase() != 3) { errorStr = "Unexpected ALIVE_SYN"; printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase()); - } - -if (user->GetRnd() + 1 != rnd) - { - errorStr = "Wrong control value at ALIVE_SYN"; - printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1); + return false; } user->SetPhase(3); -user->SetRnd(rnd); +user->SetRnd(rnd); // Set new rnd value for ALIVE_ACK Send_ALIVE_ACK(user); @@ -282,6 +351,7 @@ if (user->GetPhase() != 4) { errorStr = "Unexpected DISCONN_SYN_ACK"; printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase()); + return false; } if (user->GetRnd() + 1 != rnd) @@ -304,6 +374,7 @@ if (user->GetPhase() != 5) { errorStr = "Unexpected FIN"; printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase()); + return false; } user->SetPhase(1); @@ -318,14 +389,8 @@ bool PROTO::INFO_Proc(const void * buffer, USER * user) return true; } -bool PROTO::ERR_Proc(const void * buffer, USER * user) +bool PROTO::ERR_Proc(const void * /*buffer*/, USER * user) { -const ERR_8 * packet = static_cast(buffer); -const char * ptr = static_cast(buffer); - -for (size_t i = 0; i < sizeof(ERR_8) / 8; i++) - Blowfish_Decrypt(user->GetCtx(), (uint32_t *)(ptr + i * 8), (uint32_t *)(ptr + i * 8 + 4)); - //uint32_t len = packet->len; #ifdef ARCH_BE @@ -352,6 +417,7 @@ packet.len = sizeof(packet); SwapBytes(packet.len); #endif +strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS)); strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type)); strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login)); packet.dirs = 0xFFffFFff; @@ -434,15 +500,15 @@ bool PROTO::SendPacket(const void * packet, size_t length, USER * user) { HDR_8 hdr; -assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes"); +assert(length < 2048 && "Packet length must not exceed 2048 bytes"); -strncpy((char *)hdr.magic, IA_ID, 6); +strncpy((char *)hdr.magic, IA_ID, sizeof(hdr.magic)); hdr.protoVer[0] = 0; hdr.protoVer[1] = 8; // IA_PROTO_VER unsigned char buffer[2048]; +memcpy(buffer, packet, length); memcpy(buffer, &hdr, sizeof(hdr)); -memcpy(buffer + sizeof(hdr), packet, length); size_t offset = sizeof(HDR_8); for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++) @@ -461,18 +527,18 @@ for (size_t i = 0; i < encLen; i++) (uint32_t*)(buffer + offset + i * 8 + 4)); } -int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr)); +int res = sendto(user->GetSocket(), buffer, length, 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr)); if (res < 0) { errorStr = "Failed to send packet: '"; errorStr += strerror(errno); errorStr += "'"; - printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str()); + printfd(__FILE__, "PROTO::SendPacket() - %s, fd: %d\n", errorStr.c_str(), user->GetSocket()); return false; } -if (res < sizeof(buffer)) +if (res < length) { errorStr = "Packet sent partially"; printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str()); @@ -481,3 +547,28 @@ if (res < sizeof(buffer)) return true; } + +bool PROTO::RealConnect(USER * user) +{ +if (user->GetPhase() != 1 && + user->GetPhase() != 5) + { + errorStr = "Unexpected connect"; + printfd(__FILE__, "PROTO::RealConnect() - wrong phase: %d\n", user->GetPhase()); + } +user->SetPhase(2); + +return Send_CONN_SYN(user); +} + +bool PROTO::RealDisconnect(USER * user) +{ +if (user->GetPhase() != 3) + { + errorStr = "Unexpected disconnect"; + printfd(__FILE__, "PROTO::RealDisconnect() - wrong phase: %d\n", user->GetPhase()); + } +user->SetPhase(4); + +return Send_DISCONN_SYN(user); +}