X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/3cc6c36fcf3f0c6449debeb56e53c2ede03efc26..b4b5d092c4bca0ff33fdf384ed692b36f38d879e:/projects/sgauthstress/proto.cpp?ds=inline diff --git a/projects/sgauthstress/proto.cpp b/projects/sgauthstress/proto.cpp index d4624594..3150002d 100644 --- a/projects/sgauthstress/proto.cpp +++ b/projects/sgauthstress/proto.cpp @@ -1,6 +1,10 @@ +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> #include <netdb.h> #include <arpa/inet.h> +#include <csignal> #include <cerrno> #include <cstring> #include <cassert> @@ -9,6 +13,7 @@ #include "stg/common.h" #include "stg/ia_packets.h" +#include "stg/locker.h" #include "proto.h" @@ -65,14 +70,21 @@ processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc; processors["FIN"] = &PROTO::FIN_Proc; processors["INFO"] = &PROTO::INFO_Proc; // ERR_Proc will be handled explicitly + +pthread_mutex_init(&mutex, NULL); } PROTO::~PROTO() { +pthread_mutex_destroy(&mutex); } void * PROTO::Runner(void * data) { +sigset_t signalSet; +sigfillset(&signalSet); +pthread_sigmask(SIG_BLOCK, &signalSet, NULL); + PROTO * protoPtr = static_cast<PROTO *>(data); protoPtr->Run(); return NULL; @@ -82,7 +94,7 @@ bool PROTO::Start() { stopped = false; running = true; -if (pthread_create(&tid, NULL, &Runner, NULL)) +if (pthread_create(&tid, NULL, &Runner, this)) { errorStr = "Failed to create listening thread: '"; errorStr += strerror(errno); @@ -122,15 +134,16 @@ return true; void PROTO::AddUser(const USER & user, bool connect) { +STG_LOCKER lock(&mutex); users.push_back(std::make_pair(user.GetIP(), user)); +users.back().second.InitNetwork(); + struct pollfd pfd; -pfd.fd = user.GetSocket(); +pfd.fd = users.back().second.GetSocket(); pfd.events = POLLIN; pfd.revents = 0; pollFds.push_back(pfd); -users.back().second.InitNetwork(); - if (connect) { RealConnect(&users.back().second); @@ -139,7 +152,8 @@ if (connect) bool PROTO::Connect(uint32_t ip) { -std::vector<std::pair<uint32_t, USER> >::iterator it; +std::list<std::pair<uint32_t, USER> >::iterator it; +STG_LOCKER lock(&mutex); it = std::find_if(users.begin(), users.end(), HasIP(ip)); if (it == users.end()) return false; @@ -151,7 +165,8 @@ return RealConnect(&it->second); bool PROTO::Disconnect(uint32_t ip) { -std::vector<std::pair<uint32_t, USER> >::iterator it; +std::list<std::pair<uint32_t, USER> >::iterator it; +STG_LOCKER lock(&mutex); it = std::find_if(users.begin(), users.end(), HasIP(ip)); if (it == users.end()) return false; @@ -165,23 +180,57 @@ void PROTO::Run() { while (running) { - int res = poll(&pollFds.front(), pollFds.size(), timeout); + int res; + { + STG_LOCKER lock(&mutex); + res = poll(&pollFds.front(), pollFds.size(), timeout); + } if (res < 0) break; if (!running) break; if (res) + { + printfd(__FILE__, "PROTO::Run() - events: %d\n", res); RecvPacket(); + } + else + { + CheckTimeouts(); + } } stopped = true; } +void PROTO::CheckTimeouts() +{ +STG_LOCKER lock(&mutex); +std::list<std::pair<uint32_t, USER> >::iterator it; +for (it = users.begin(); it != users.end(); ++it) + { + int delta = difftime(time(NULL), it->second.GetPhaseChangeTime()); + if ((it->second.GetPhase() == 3) && + (delta > it->second.GetUserTimeout())) + { + printfd(__FILE__, "PROTO::CheckTimeouts() - user alive timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetUserTimeout()); + it->second.SetPhase(1); + } + if ((it->second.GetPhase() == 2) && + (delta > it->second.GetAliveTimeout())) + { + printfd(__FILE__, "PROTO::CheckTimeouts() - user connect timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetAliveTimeout()); + it->second.SetPhase(1); + } + } +} + bool PROTO::RecvPacket() { bool result = true; std::vector<struct pollfd>::iterator it; -std::vector<std::pair<uint32_t, USER> >::iterator userIt; +std::list<std::pair<uint32_t, USER> >::iterator userIt; +STG_LOCKER lock(&mutex); for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt) { if (it->revents) @@ -199,21 +248,27 @@ for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt continue; } - result = result && HandlePacket(buffer, &(userIt->second)); + result = result && HandlePacket(buffer, res, &(userIt->second)); } } return result; } -bool PROTO::HandlePacket(const char * buffer, USER * user) +bool PROTO::HandlePacket(const char * buffer, size_t length, USER * user) { -if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR")) +if (!strncmp(buffer + 4 + sizeof(HDR_8), "ERR", 3)) { return ERR_Proc(buffer, user); } +for (size_t i = 0; i < length / 8; i++) + Blowfish_Decrypt(user->GetCtx(), + (uint32_t *)(buffer + i * 8), + (uint32_t *)(buffer + i * 8 + 4)); + std::string packetName(buffer + 12); + std::map<std::string, PacketProcessor>::const_iterator it; it = processors.find(packetName); if (it != processors.end()) @@ -238,12 +293,11 @@ SwapBytes(userTimeout); SwapBytes(aliveDelay); #endif -Send_CONN_ACK(user); - if (user->GetPhase() != 2) { errorStr = "Unexpected CONN_SYN_ACK"; printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase()); + return false; } user->SetPhase(3); @@ -251,6 +305,8 @@ user->SetAliveTimeout(aliveTimeout); user->SetUserTimeout(userTimeout); user->SetRnd(rnd); +Send_CONN_ACK(user); + printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str()); return true; @@ -270,16 +326,11 @@ if (user->GetPhase() != 3) { errorStr = "Unexpected ALIVE_SYN"; printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase()); - } - -if (user->GetRnd() + 1 != rnd) - { - errorStr = "Wrong control value at ALIVE_SYN"; - printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1); + return false; } user->SetPhase(3); -user->SetRnd(rnd); +user->SetRnd(rnd); // Set new rnd value for ALIVE_ACK Send_ALIVE_ACK(user); @@ -300,6 +351,7 @@ if (user->GetPhase() != 4) { errorStr = "Unexpected DISCONN_SYN_ACK"; printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase()); + return false; } if (user->GetRnd() + 1 != rnd) @@ -322,6 +374,7 @@ if (user->GetPhase() != 5) { errorStr = "Unexpected FIN"; printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase()); + return false; } user->SetPhase(1); @@ -336,14 +389,8 @@ bool PROTO::INFO_Proc(const void * buffer, USER * user) return true; } -bool PROTO::ERR_Proc(const void * buffer, USER * user) +bool PROTO::ERR_Proc(const void * /*buffer*/, USER * user) { -const ERR_8 * packet = static_cast<const ERR_8 *>(buffer); -const char * ptr = static_cast<const char *>(buffer); - -for (size_t i = 0; i < sizeof(ERR_8) / 8; i++) - Blowfish_Decrypt(user->GetCtx(), (uint32_t *)(ptr + i * 8), (uint32_t *)(ptr + i * 8 + 4)); - //uint32_t len = packet->len; #ifdef ARCH_BE @@ -370,6 +417,7 @@ packet.len = sizeof(packet); SwapBytes(packet.len); #endif +strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS)); strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type)); strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login)); packet.dirs = 0xFFffFFff; @@ -452,15 +500,15 @@ bool PROTO::SendPacket(const void * packet, size_t length, USER * user) { HDR_8 hdr; -assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes"); +assert(length < 2048 && "Packet length must not exceed 2048 bytes"); -strncpy((char *)hdr.magic, IA_ID, 6); +strncpy((char *)hdr.magic, IA_ID, sizeof(hdr.magic)); hdr.protoVer[0] = 0; hdr.protoVer[1] = 8; // IA_PROTO_VER unsigned char buffer[2048]; +memcpy(buffer, packet, length); memcpy(buffer, &hdr, sizeof(hdr)); -memcpy(buffer + sizeof(hdr), packet, length); size_t offset = sizeof(HDR_8); for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++) @@ -479,7 +527,7 @@ for (size_t i = 0; i < encLen; i++) (uint32_t*)(buffer + offset + i * 8 + 4)); } -int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr)); +int res = sendto(user->GetSocket(), buffer, length, 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr)); if (res < 0) { @@ -490,7 +538,7 @@ if (res < 0) return false; } -if (res < sizeof(buffer)) +if (res < length) { errorStr = "Packet sent partially"; printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());