X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/3cc6c36fcf3f0c6449debeb56e53c2ede03efc26..c088c3e07ab17e33165fa41fe6175a25bc03d4da:/projects/sgauthstress/proto.cpp diff --git a/projects/sgauthstress/proto.cpp b/projects/sgauthstress/proto.cpp index d4624594..dd4984ca 100644 --- a/projects/sgauthstress/proto.cpp +++ b/projects/sgauthstress/proto.cpp @@ -9,6 +9,7 @@ #include "stg/common.h" #include "stg/ia_packets.h" +#include "stg/locker.h" #include "proto.h" @@ -65,10 +66,13 @@ processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc; processors["FIN"] = &PROTO::FIN_Proc; processors["INFO"] = &PROTO::INFO_Proc; // ERR_Proc will be handled explicitly + +pthread_mutex_init(&mutex, NULL); } PROTO::~PROTO() { +pthread_mutex_destroy(&mutex); } void * PROTO::Runner(void * data) @@ -82,7 +86,7 @@ bool PROTO::Start() { stopped = false; running = true; -if (pthread_create(&tid, NULL, &Runner, NULL)) +if (pthread_create(&tid, NULL, &Runner, this)) { errorStr = "Failed to create listening thread: '"; errorStr += strerror(errno); @@ -122,15 +126,16 @@ return true; void PROTO::AddUser(const USER & user, bool connect) { +STG_LOCKER lock(&mutex, __FILE__, __LINE__); users.push_back(std::make_pair(user.GetIP(), user)); +users.back().second.InitNetwork(); + struct pollfd pfd; -pfd.fd = user.GetSocket(); +pfd.fd = users.back().second.GetSocket(); pfd.events = POLLIN; pfd.revents = 0; pollFds.push_back(pfd); -users.back().second.InitNetwork(); - if (connect) { RealConnect(&users.back().second); @@ -139,7 +144,8 @@ if (connect) bool PROTO::Connect(uint32_t ip) { -std::vector >::iterator it; +std::list >::iterator it; +STG_LOCKER lock(&mutex, __FILE__, __LINE__); it = std::find_if(users.begin(), users.end(), HasIP(ip)); if (it == users.end()) return false; @@ -151,7 +157,8 @@ return RealConnect(&it->second); bool PROTO::Disconnect(uint32_t ip) { -std::vector >::iterator it; +std::list >::iterator it; +STG_LOCKER lock(&mutex, __FILE__, __LINE__); it = std::find_if(users.begin(), users.end(), HasIP(ip)); if (it == users.end()) return false; @@ -171,7 +178,10 @@ while (running) if (!running) break; if (res) + { + printfd(__FILE__, "PROTO::Run() - events: %d\n", res); RecvPacket(); + } } stopped = true; @@ -181,12 +191,14 @@ bool PROTO::RecvPacket() { bool result = true; std::vector::iterator it; -std::vector >::iterator userIt; +std::list >::iterator userIt; +STG_LOCKER lock(&mutex, __FILE__, __LINE__); for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt) { if (it->revents) { it->revents = 0; + printfd(__FILE__, "PROTO::RecvPacket() - pollfd: %d, socket: %d\n", it->fd, userIt->second.GetSocket()); assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked"); struct sockaddr_in addr; socklen_t fromLen = sizeof(addr); @@ -199,21 +211,27 @@ for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt continue; } - result = result && HandlePacket(buffer, &(userIt->second)); + result = result && HandlePacket(buffer, res, &(userIt->second)); } } return result; } -bool PROTO::HandlePacket(const char * buffer, USER * user) +bool PROTO::HandlePacket(const char * buffer, size_t length, USER * user) { -if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR")) +if (!strncmp(buffer + 4 + sizeof(HDR_8), "ERR", 3)) { return ERR_Proc(buffer, user); } +for (size_t i = 0; i < length / 8; i++) + Blowfish_Decrypt(user->GetCtx(), + (uint32_t *)(buffer + i * 8), + (uint32_t *)(buffer + i * 8 + 4)); + std::string packetName(buffer + 12); + std::map::const_iterator it; it = processors.find(packetName); if (it != processors.end()) @@ -238,8 +256,6 @@ SwapBytes(userTimeout); SwapBytes(aliveDelay); #endif -Send_CONN_ACK(user); - if (user->GetPhase() != 2) { errorStr = "Unexpected CONN_SYN_ACK"; @@ -251,6 +267,8 @@ user->SetAliveTimeout(aliveTimeout); user->SetUserTimeout(userTimeout); user->SetRnd(rnd); +Send_CONN_ACK(user); + printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str()); return true; @@ -272,14 +290,8 @@ if (user->GetPhase() != 3) printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase()); } -if (user->GetRnd() + 1 != rnd) - { - errorStr = "Wrong control value at ALIVE_SYN"; - printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1); - } - user->SetPhase(3); -user->SetRnd(rnd); +user->SetRnd(rnd); // Set new rnd value for ALIVE_ACK Send_ALIVE_ACK(user); @@ -341,9 +353,6 @@ bool PROTO::ERR_Proc(const void * buffer, USER * user) const ERR_8 * packet = static_cast(buffer); const char * ptr = static_cast(buffer); -for (size_t i = 0; i < sizeof(ERR_8) / 8; i++) - Blowfish_Decrypt(user->GetCtx(), (uint32_t *)(ptr + i * 8), (uint32_t *)(ptr + i * 8 + 4)); - //uint32_t len = packet->len; #ifdef ARCH_BE @@ -370,6 +379,7 @@ packet.len = sizeof(packet); SwapBytes(packet.len); #endif +strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS)); strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type)); strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login)); packet.dirs = 0xFFffFFff; @@ -452,15 +462,16 @@ bool PROTO::SendPacket(const void * packet, size_t length, USER * user) { HDR_8 hdr; -assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes"); +assert(length < 2048 && "Packet length must not exceed 2048 bytes"); -strncpy((char *)hdr.magic, IA_ID, 6); +strncpy((char *)hdr.magic, IA_ID, sizeof(hdr.magic)); hdr.protoVer[0] = 0; hdr.protoVer[1] = 8; // IA_PROTO_VER unsigned char buffer[2048]; +memset(buffer, 0, sizeof(buffer)); +memcpy(buffer, packet, length); memcpy(buffer, &hdr, sizeof(hdr)); -memcpy(buffer + sizeof(hdr), packet, length); size_t offset = sizeof(HDR_8); for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++) @@ -479,7 +490,7 @@ for (size_t i = 0; i < encLen; i++) (uint32_t*)(buffer + offset + i * 8 + 4)); } -int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr)); +int res = sendto(user->GetSocket(), buffer, length, 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr)); if (res < 0) { @@ -490,7 +501,7 @@ if (res < 0) return false; } -if (res < sizeof(buffer)) +if (res < length) { errorStr = "Packet sent partially"; printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());