]> git.stg.codes - stg.git/commitdiff
Locks added, rnd logic for ALIVE changed, CONN_ACK bug fixed
authorMaxim Mamontov <faust@gts.dp.ua>
Wed, 11 May 2011 11:18:56 +0000 (14:18 +0300)
committerMaxim Mamontov <faust@gts.dp.ua>
Wed, 11 May 2011 11:18:56 +0000 (14:18 +0300)
projects/sgauthstress/proto.cpp

index d4624594ebf006b10ed20f355759eb1398c348e9..dd4984caca545bf5d0243aefa067fab8f8ea47ad 100644 (file)
@@ -9,6 +9,7 @@
 
 #include "stg/common.h"
 #include "stg/ia_packets.h"
 
 #include "stg/common.h"
 #include "stg/ia_packets.h"
+#include "stg/locker.h"
 
 #include "proto.h"
 
 
 #include "proto.h"
 
@@ -65,10 +66,13 @@ processors["DISCONN_SYN_ACK"] = &PROTO::DISCONN_SYN_ACK_Proc;
 processors["FIN"] = &PROTO::FIN_Proc;
 processors["INFO"] = &PROTO::INFO_Proc;
 // ERR_Proc will be handled explicitly
 processors["FIN"] = &PROTO::FIN_Proc;
 processors["INFO"] = &PROTO::INFO_Proc;
 // ERR_Proc will be handled explicitly
+
+pthread_mutex_init(&mutex, NULL);
 }
 
 PROTO::~PROTO()
 {
 }
 
 PROTO::~PROTO()
 {
+pthread_mutex_destroy(&mutex);
 }
 
 void * PROTO::Runner(void * data)
 }
 
 void * PROTO::Runner(void * data)
@@ -82,7 +86,7 @@ bool PROTO::Start()
 {
 stopped = false;
 running = true;
 {
 stopped = false;
 running = true;
-if (pthread_create(&tid, NULL, &Runner, NULL))
+if (pthread_create(&tid, NULL, &Runner, this))
     {
     errorStr = "Failed to create listening thread: '";
     errorStr += strerror(errno);
     {
     errorStr = "Failed to create listening thread: '";
     errorStr += strerror(errno);
@@ -122,15 +126,16 @@ return true;
 
 void PROTO::AddUser(const USER & user, bool connect)
 {
 
 void PROTO::AddUser(const USER & user, bool connect)
 {
+STG_LOCKER lock(&mutex, __FILE__, __LINE__);
 users.push_back(std::make_pair(user.GetIP(), user));
 users.push_back(std::make_pair(user.GetIP(), user));
+users.back().second.InitNetwork();
+
 struct pollfd pfd;
 struct pollfd pfd;
-pfd.fd = user.GetSocket();
+pfd.fd = users.back().second.GetSocket();
 pfd.events = POLLIN;
 pfd.revents = 0;
 pollFds.push_back(pfd);
 
 pfd.events = POLLIN;
 pfd.revents = 0;
 pollFds.push_back(pfd);
 
-users.back().second.InitNetwork();
-
 if (connect)
     {
     RealConnect(&users.back().second);
 if (connect)
     {
     RealConnect(&users.back().second);
@@ -139,7 +144,8 @@ if (connect)
 
 bool PROTO::Connect(uint32_t ip)
 {
 
 bool PROTO::Connect(uint32_t ip)
 {
-std::vector<std::pair<uint32_t, USER> >::iterator it;
+std::list<std::pair<uint32_t, USER> >::iterator it;
+STG_LOCKER lock(&mutex, __FILE__, __LINE__);
 it = std::find_if(users.begin(), users.end(), HasIP(ip));
 if (it == users.end())
     return false;
 it = std::find_if(users.begin(), users.end(), HasIP(ip));
 if (it == users.end())
     return false;
@@ -151,7 +157,8 @@ return RealConnect(&it->second);
 
 bool PROTO::Disconnect(uint32_t ip)
 {
 
 bool PROTO::Disconnect(uint32_t ip)
 {
-std::vector<std::pair<uint32_t, USER> >::iterator it;
+std::list<std::pair<uint32_t, USER> >::iterator it;
+STG_LOCKER lock(&mutex, __FILE__, __LINE__);
 it = std::find_if(users.begin(), users.end(), HasIP(ip));
 if (it == users.end())
     return false;
 it = std::find_if(users.begin(), users.end(), HasIP(ip));
 if (it == users.end())
     return false;
@@ -171,7 +178,10 @@ while (running)
     if (!running)
         break;
     if (res)
     if (!running)
         break;
     if (res)
+        {
+        printfd(__FILE__, "PROTO::Run() - events: %d\n", res);
         RecvPacket();
         RecvPacket();
+        }
     }
 
 stopped = true;
     }
 
 stopped = true;
@@ -181,12 +191,14 @@ bool PROTO::RecvPacket()
 {
 bool result = true;
 std::vector<struct pollfd>::iterator it;
 {
 bool result = true;
 std::vector<struct pollfd>::iterator it;
-std::vector<std::pair<uint32_t, USER> >::iterator userIt;
+std::list<std::pair<uint32_t, USER> >::iterator userIt;
+STG_LOCKER lock(&mutex, __FILE__, __LINE__);
 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
     {
     if (it->revents)
         {
         it->revents = 0;
 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
     {
     if (it->revents)
         {
         it->revents = 0;
+        printfd(__FILE__, "PROTO::RecvPacket() - pollfd: %d, socket: %d\n", it->fd, userIt->second.GetSocket());
         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
         struct sockaddr_in addr;
         socklen_t fromLen = sizeof(addr);
         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
         struct sockaddr_in addr;
         socklen_t fromLen = sizeof(addr);
@@ -199,21 +211,27 @@ for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt
             continue;
             }
 
             continue;
             }
 
-        result = result && HandlePacket(buffer, &(userIt->second));
+        result = result && HandlePacket(buffer, res, &(userIt->second));
         }
     }
 
 return result;
 }
 
         }
     }
 
 return result;
 }
 
-bool PROTO::HandlePacket(const char * buffer, USER * user)
+bool PROTO::HandlePacket(const char * buffer, size_t length, USER * user)
 {
 {
-if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR"))
+if (!strncmp(buffer + 4 + sizeof(HDR_8), "ERR", 3))
     {
     return ERR_Proc(buffer, user);
     }
 
     {
     return ERR_Proc(buffer, user);
     }
 
+for (size_t i = 0; i < length / 8; i++)
+    Blowfish_Decrypt(user->GetCtx(),
+                     (uint32_t *)(buffer + i * 8),
+                     (uint32_t *)(buffer + i * 8 + 4));
+
 std::string packetName(buffer + 12);
 std::string packetName(buffer + 12);
+
 std::map<std::string, PacketProcessor>::const_iterator it;
 it = processors.find(packetName);
 if (it != processors.end())
 std::map<std::string, PacketProcessor>::const_iterator it;
 it = processors.find(packetName);
 if (it != processors.end())
@@ -238,8 +256,6 @@ SwapBytes(userTimeout);
 SwapBytes(aliveDelay);
 #endif
 
 SwapBytes(aliveDelay);
 #endif
 
-Send_CONN_ACK(user);
-
 if (user->GetPhase() != 2)
     {
     errorStr = "Unexpected CONN_SYN_ACK";
 if (user->GetPhase() != 2)
     {
     errorStr = "Unexpected CONN_SYN_ACK";
@@ -251,6 +267,8 @@ user->SetAliveTimeout(aliveTimeout);
 user->SetUserTimeout(userTimeout);
 user->SetRnd(rnd);
 
 user->SetUserTimeout(userTimeout);
 user->SetRnd(rnd);
 
+Send_CONN_ACK(user);
+
 printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str());
 
 return true;
 printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - user '%s' successfully logged in from IP %s\n", user->GetLogin().c_str(), inet_ntostring(user->GetIP()).c_str());
 
 return true;
@@ -272,14 +290,8 @@ if (user->GetPhase() != 3)
     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
     }
 
     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
     }
 
-if (user->GetRnd() + 1 != rnd)
-    {
-    errorStr = "Wrong control value at ALIVE_SYN";
-    printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
-    }
-
 user->SetPhase(3);
 user->SetPhase(3);
-user->SetRnd(rnd);
+user->SetRnd(rnd); // Set new rnd value for ALIVE_ACK
 
 Send_ALIVE_ACK(user);
 
 
 Send_ALIVE_ACK(user);
 
@@ -341,9 +353,6 @@ bool PROTO::ERR_Proc(const void * buffer, USER * user)
 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
 const char * ptr = static_cast<const char *>(buffer);
 
 const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
 const char * ptr = static_cast<const char *>(buffer);
 
-for (size_t i = 0; i < sizeof(ERR_8) / 8; i++)
-    Blowfish_Decrypt(user->GetCtx(), (uint32_t *)(ptr + i * 8), (uint32_t *)(ptr + i * 8 + 4));
-
 //uint32_t len = packet->len;
 
 #ifdef ARCH_BE
 //uint32_t len = packet->len;
 
 #ifdef ARCH_BE
@@ -370,6 +379,7 @@ packet.len = sizeof(packet);
 SwapBytes(packet.len);
 #endif
 
 SwapBytes(packet.len);
 #endif
 
+strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
 packet.dirs = 0xFFffFFff;
 strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
 strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
 packet.dirs = 0xFFffFFff;
@@ -452,15 +462,16 @@ bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
 {
 HDR_8 hdr;
 
 {
 HDR_8 hdr;
 
-assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes");
+assert(length < 2048 && "Packet length must not exceed 2048 bytes");
 
 
-strncpy((char *)hdr.magic, IA_ID, 6);
+strncpy((char *)hdr.magic, IA_ID, sizeof(hdr.magic));
 hdr.protoVer[0] = 0;
 hdr.protoVer[1] = 8; // IA_PROTO_VER
 
 unsigned char buffer[2048];
 hdr.protoVer[0] = 0;
 hdr.protoVer[1] = 8; // IA_PROTO_VER
 
 unsigned char buffer[2048];
+memset(buffer, 0, sizeof(buffer));
+memcpy(buffer, packet, length);
 memcpy(buffer, &hdr, sizeof(hdr));
 memcpy(buffer, &hdr, sizeof(hdr));
-memcpy(buffer + sizeof(hdr), packet, length);
 
 size_t offset = sizeof(HDR_8);
 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
 
 size_t offset = sizeof(HDR_8);
 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
@@ -479,7 +490,7 @@ for (size_t i = 0; i < encLen; i++)
                      (uint32_t*)(buffer + offset + i * 8 + 4));
     }
 
                      (uint32_t*)(buffer + offset + i * 8 + 4));
     }
 
-int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
+int res = sendto(user->GetSocket(), buffer, length, 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
 
 if (res < 0)
     {
 
 if (res < 0)
     {
@@ -490,7 +501,7 @@ if (res < 0)
     return false;
     }
 
     return false;
     }
 
-if (res < sizeof(buffer))
+if (res < length)
     {
     errorStr = "Packet sent partially";
     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
     {
     errorStr = "Packet sent partially";
     printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());