]> git.stg.codes - stg.git/commitdiff
Another part of PROTO implemented
authorMaxim Mamontov <faust@gts.dp.ua>
Fri, 6 May 2011 15:00:10 +0000 (18:00 +0300)
committerMaxim Mamontov <faust@gts.dp.ua>
Fri, 6 May 2011 15:00:10 +0000 (18:00 +0300)
Senders added
Processors changed
Send/Recv implemeted
Implementation changed from select to poll
Control number logic implemeted
Phase checking implemented

projects/sgauthstress/proto.cpp
projects/sgauthstress/proto.h

index 06a67d8f600ee85651618ed163d1a3e61e337ddb..7c547daf8462caeb086d4e38cd54241560a5293a 100644 (file)
@@ -3,40 +3,13 @@
 
 #include <cerrno>
 #include <cstring>
 
 #include <cerrno>
 #include <cstring>
+#include <cassert>
 #include <stdexcept>
 
 #include "stg/common.h"
 
 #include "proto.h"
 
 #include <stdexcept>
 
 #include "stg/common.h"
 
 #include "proto.h"
 
-int WaitPacket(int sd, int timeout)
-{
-fd_set rfds;
-FD_ZERO(&rfds);
-FD_SET(sd, &rfds);
-
-struct timeval tv;
-tv.tv_sec = timeout;
-tv.tv_usec = 0;
-
-int res = select(sd + 1, &rfds, NULL, NULL, &tv);
-if (res == -1) // Error
-    {
-    if (errno != EINTR)
-        {
-        printfd(__FILE__, "Error on select: '%s'\n", strerror(errno));
-        }
-    return -1;
-    }
-
-if (res == 0) // Timeout
-    {
-    return 0;
-    }
-
-return 1;
-}
-
 PROTO::PROTO(const std::string & server,
              uint16_t port,
              uint16_t localPort,
 PROTO::PROTO(const std::string & server,
              uint16_t port,
              uint16_t localPort,
@@ -63,8 +36,6 @@ if (ip == INADDR_NONE)
         }
     }
 
         }
     }
 
-sock = socket(AF_INET, SOCK_DGRAM, 0);
-
 localAddr.sin_family = AF_INET;
 localAddr.sin_port = htons(localPort);
 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
 localAddr.sin_family = AF_INET;
 localAddr.sin_port = htons(localPort);
 localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
@@ -88,7 +59,6 @@ processors["INFO"] = &PROTO::INFO_Proc;
 
 PROTO::~PROTO()
 {
 
 PROTO::~PROTO()
 {
-close(sock);
 }
 
 void * PROTO::Runner(void * data)
 }
 
 void * PROTO::Runner(void * data)
@@ -139,10 +109,20 @@ if (pthread_join(tid, NULL))
 return true;
 }
 
 return true;
 }
 
-bool PROTO::Connect(const std::string & login)
+void PROTO::AddUser(const USER & user)
 {
 {
-std::map<std::string, USER>::const_iterator it;
-it = users.find(login);
+    users.insert(std::make_pair(user.GetIP(), user));
+    struct pollfd pfd;
+    pfd.fd = user.GetSocket();
+    pfd.events = POLLIN;
+    pfd.revents = 0;
+    pollFds.push_back(pfd);
+}
+
+bool PROTO::Connect(uint32_t ip)
+{
+std::map<uint32_t, USER>::const_iterator it;
+it = users.find(ip);
 if (it == users.end())
     return false;
 
 if (it == users.end())
     return false;
 
@@ -151,10 +131,10 @@ if (it == users.end())
 return true;
 }
 
 return true;
 }
 
-bool PROTO::Disconnect(const std::string & login)
+bool PROTO::Disconnect(uint32_t ip)
 {
 {
-std::map<std::string, USER>::const_iterator it;
-it = users.find(login);
+std::map<uint32_t, USER>::const_iterator it;
+it = users.find(ip);
 if (it == users.end())
     return false;
 
 if (it == users.end())
     return false;
 
@@ -167,7 +147,7 @@ void PROTO::Run()
 {
 while (running)
     {
 {
 while (running)
     {
-    int res = WaitPacket(sock, timeout);
+    int res = poll(&pollFds.front(), pollFds.size(), timeout);
     if (res < 0)
         break;
     if (!running)
     if (res < 0)
         break;
     if (!running)
@@ -181,29 +161,320 @@ stopped = true;
 
 bool PROTO::RecvPacket()
 {
 
 bool PROTO::RecvPacket()
 {
-struct sockaddr_in addr;
-socklen_t fromLen = sizeof(addr);
-char buffer[2048];
-int res = recvfrom(sock, buffer, sizeof(buffer), 0, (struct sockaddr*)&addr, &fromLen);
-
-if (res == -1)
-    return res;
+bool result = true;
+std::vector<struct pollfd>::iterator it;
+std::map<uint32_t, USER>::iterator userIt(users.begin());
+for (it = pollFds.begin(); it != pollFds.end(); ++it)
+    {
+    if (it->revents)
+        {
+        it->revents = 0;
+        assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
+        struct sockaddr_in addr;
+        socklen_t fromLen = sizeof(addr);
+        char buffer[2048];
+        int res = recvfrom(userIt->second.GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr*)&addr, &fromLen);
+
+        if (res == -1)
+            {
+            result = false;
+            ++userIt;
+            continue;
+            }
+
+        result = result && HandlePacket(buffer, &(userIt->second));
+        }
+    ++userIt;
+    }
 
 
-return HandlePacket(buffer);
+return result;
 }
 
 }
 
-bool PROTO::HandlePacket(char * buffer)
+bool PROTO::HandlePacket(const char * buffer, USER * user)
 {
 if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR"))
     {
 {
 if (strcmp(buffer + 4 + sizeof(HDR_8), "ERR"))
     {
-    return ERR_Proc(buffer);
+    return ERR_Proc(buffer, user);
     }
 
 std::string packetName(buffer + 12);
 std::map<std::string, PacketProcessor>::const_iterator it;
 it = processors.find(packetName);
 if (it != processors.end())
     }
 
 std::string packetName(buffer + 12);
 std::map<std::string, PacketProcessor>::const_iterator it;
 it = processors.find(packetName);
 if (it != processors.end())
-    return (this->*it->second)(buffer);
+    return (this->*it->second)(buffer, user);
 
 return false;
 }
 
 return false;
 }
+
+bool PROTO::CONN_SYN_ACK_Proc(const void * buffer, USER * user)
+{
+const CONN_SYN_ACK_8 * packet = static_cast<const CONN_SYN_ACK_8 *>(buffer);
+
+uint32_t rnd = packet->rnd;
+uint32_t userTimeout = packet->userTimeOut;
+uint32_t aliveTimeout = packet->aliveDelay;
+
+#ifdef ARCH_BE
+SwapBytes(rnd);
+SwapBytes(userTimeout);
+SwapBytes(aliveDelay);
+#endif
+
+Send_CONN_ACK(user);
+
+if (user->GetPhase() != 2)
+    {
+    errorStr = "Unexpected CONN_SYN_ACK";
+    printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
+    }
+
+user->SetPhase(3);
+user->SetAliveTimeout(aliveTimeout);
+user->SetUserTimeout(userTimeout);
+user->SetRnd(rnd);
+
+return true;
+}
+
+bool PROTO::ALIVE_SYN_Proc(const void * buffer, USER * user)
+{
+const ALIVE_SYN_8 * packet = static_cast<const ALIVE_SYN_8 *>(buffer);
+
+uint32_t rnd = packet->rnd;
+
+#ifdef ARCH_BE
+SwapBytes(rnd);
+#endif
+
+if (user->GetPhase() != 3)
+    {
+    errorStr = "Unexpected ALIVE_SYN";
+    printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
+    }
+
+if (user->GetRnd() + 1 != rnd)
+    {
+    errorStr = "Wrong control value at ALIVE_SYN";
+    printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
+    }
+
+user->SetPhase(3);
+user->SetRnd(rnd);
+
+Send_ALIVE_ACK(user);
+
+return true;
+}
+
+bool PROTO::DISCONN_SYN_ACK_Proc(const void * buffer, USER * user)
+{
+const DISCONN_SYN_ACK_8 * packet = static_cast<const DISCONN_SYN_ACK_8 *>(buffer);
+
+uint32_t rnd = packet->rnd;
+
+#ifdef ARCH_BE
+SwapBytes(rnd);
+#endif
+
+if (user->GetPhase() != 4)
+    {
+    errorStr = "Unexpected DISCONN_SYN_ACK";
+    printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
+    }
+
+if (user->GetRnd() + 1 != rnd)
+    {
+    errorStr = "Wrong control value at DISCONN_SYN_ACK";
+    printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong control value: %d, expected: %d\n", rnd, user->GetRnd() + 1);
+    }
+
+user->SetPhase(5);
+user->SetRnd(rnd);
+
+Send_DISCONN_ACK(user);
+
+return true;
+}
+
+bool PROTO::FIN_Proc(const void * buffer, USER * user)
+{
+if (user->GetPhase() != 5)
+    {
+    errorStr = "Unexpected FIN";
+    printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
+    }
+
+user->SetPhase(1);
+
+return true;
+}
+
+bool PROTO::INFO_Proc(const void * buffer, USER * user)
+{
+//const INFO_8 * packet = static_cast<const INFO_8 *>(buffer);
+
+return true;
+}
+
+bool PROTO::ERR_Proc(const void * buffer, USER * user)
+{
+const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
+
+for (int i = 0; i < len/8; i++)
+    Blowfish_Decrypt(&ctxPass, (uint32_t*)(buffer + i*8), (uint32_t*)(buffer + i*8 + 4));
+
+//uint32_t len = packet->len;
+
+#ifdef ARCH_BE
+//SwapBytes(len);
+#endif
+
+user->SetPhase(1); //TODO: Check
+/*KOIToWin((const char*)err.text, &messageText);
+if (pErrorCb != NULL)
+    pErrorCb(messageText, IA_SERVER_ERROR, errorCbData);
+phaseTime = GetTickCount();
+codeError = IA_SERVER_ERROR;*/
+
+return true;
+}
+
+bool PROTO::Send_CONN_SYN(USER * user)
+{
+CONN_SYN_8 packet;
+
+packet.len = sizeof(packet);
+
+#ifdef ARCH_BE
+SwapBytes(packet.len);
+#endif
+
+strncpy((char *)packet.type, "CONN_SYN", sizeof(packet.type));
+strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
+packet.dirs = 0xFFffFFff;
+
+return SendPacket(&packet, sizeof(packet), user);
+}
+
+bool PROTO::Send_CONN_ACK(USER * user)
+{
+CONN_ACK_8 packet;
+
+packet.len = sizeof(packet);
+packet.rnd = user->IncRnd();
+
+#ifdef ARCH_BE
+SwapBytes(packet.len);
+SwapBytes(packet.rnd);
+#endif
+
+strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
+strncpy((char *)packet.type, "CONN_ACK", sizeof(packet.type));
+
+return SendPacket(&packet, sizeof(packet), user);
+}
+
+bool PROTO::Send_ALIVE_ACK(USER * user)
+{
+ALIVE_ACK_8 packet;
+
+packet.len = sizeof(packet);
+packet.rnd = user->IncRnd();
+
+#ifdef ARCH_BE
+SwapBytes(packet.len);
+SwapBytes(packet.rnd);
+#endif
+
+strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
+strncpy((char *)packet.type, "ALIVE_ACK", sizeof(packet.type));
+
+return SendPacket(&packet, sizeof(packet), user);
+}
+
+bool PROTO::Send_DISCONN_SYN(USER * user)
+{
+DISCONN_SYN_8 packet;
+
+packet.len = sizeof(packet);
+
+#ifdef ARCH_BE
+SwapBytes(packet.len);
+#endif
+
+strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
+strncpy((char *)packet.type, "DISCONN_SYN", sizeof(packet.type));
+strncpy((char *)packet.login, user->GetLogin().c_str(), sizeof(packet.login));
+
+return SendPacket(&packet, sizeof(packet), user);
+}
+
+bool PROTO::Send_DISCONN_ACK(USER * user)
+{
+DISCONN_ACK_8 packet;
+
+packet.len = sizeof(packet);
+packet.rnd = user->IncRnd();
+
+#ifdef ARCH_BE
+SwapBytes(packet.len);
+SwapBytes(packet.rnd);
+#endif
+
+strncpy((char *)packet.loginS, user->GetLogin().c_str(), sizeof(packet.loginS));
+strncpy((char *)packet.type, "DISCONN_ACK", sizeof(packet.type));
+
+return SendPacket(&packet, sizeof(packet), user);
+}
+
+bool PROTO::SendPacket(const void * packet, size_t length, USER * user)
+{
+HDR_8 hdr;
+
+assert(sizeof(hdr) + length < 2048 && "Packet length must not exceed 2048 bytes");
+
+strncpy((char *)hdr.magic, IA_ID, 6);
+hdr.protoVer[0] = 0;
+hdr.protoVer[1] = IA_PROTO_VER;
+
+unsigned char buffer[2048];
+memcpy(buffer, &hdr, sizeof(hdr));
+memcpy(buffer + sizeof(hdr), packet, length);
+
+size_t offset = sizeof(HDR_8);
+for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)
+    {
+    Blowfish_Encrypt(&ctx,
+                     (uint32_t *)(buffer + offset + i * 8),
+                     (uint32_t *)(buffer + offset + i * 8 + 4));
+    }
+
+offset += IA_LOGIN_LEN;
+size_t encLen = (length - IA_LOGIN_LEN) / 8;
+for (size_t i = 0; i < encLen; i++)
+    {
+    Blowfish_Encrypt(user->GetCtx(),
+                     (uint32_t*)(buffer + offset + i * 8),
+                     (uint32_t*)(buffer + offset + i * 8 + 4));
+    }
+
+int res = sendto(user->GetSocket(), buffer, sizeof(buffer), 0, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
+
+if (res < 0)
+    {
+    errorStr = "Failed to send packet: '";
+    errorStr += strerror(errno);
+    errorStr += "'";
+    printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
+    return false;
+    }
+
+if (res < sizeof(buffer))
+    {
+    errorStr = "Packet sent partially";
+    printfd(__FILE__, "PROTO::SendPacket() - %s\n", errorStr.c_str());
+    return false;
+    }
+
+return true;
+}
index f478c0e1eb4133dad60a6294a1300457938e10b9..879e7a35b11a49e3a367842a7de31705bb6429ac 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <netinet/ip.h>
 #include <pthread.h>
 
 #include <netinet/ip.h>
 #include <pthread.h>
+#include <poll.h>
 
 #include <string>
 #include <map>
 
 #include <string>
 #include <map>
@@ -14,7 +15,7 @@
 
 class PROTO;
 
 
 class PROTO;
 
-typedef bool (PROTO::*PacketProcessor)(char *);
+typedef bool (PROTO::*PacketProcessor)(const void *, USER *);
 
 class PROTO {
     public:
 
 class PROTO {
     public:
@@ -29,16 +30,18 @@ class PROTO {
 
         const std::string GetStrError() const { return errorStr; }
 
 
         const std::string GetStrError() const { return errorStr; }
 
-        bool Connect(const std::string & login);
-        bool Disconnect(const std::string & login);
+        void AddUser(const USER & user);
+
+        bool Connect(uint32_t ip);
+        bool Disconnect(uint32_t ip);
     private:
     private:
-        int sock;
         BLOWFISH_CTX ctx;
         struct sockaddr_in localAddr;
         struct sockaddr_in serverAddr;
         int timeout;
 
         BLOWFISH_CTX ctx;
         struct sockaddr_in localAddr;
         struct sockaddr_in serverAddr;
         int timeout;
 
-        std::map<std::string, USER> users;
+        std::map<uint32_t, USER> users;
+        std::vector<struct pollfd> pollFds;
 
         bool running;
         bool stopped;
 
         bool running;
         bool stopped;
@@ -53,14 +56,21 @@ class PROTO {
 
         void Run();
         bool RecvPacket();
 
         void Run();
         bool RecvPacket();
-        bool HandlePacket(char * buffer);
-
-        bool CONN_SYN_ACK_Proc(char * buffer);
-        bool ALIVE_SYN_Proc(char * buffer);
-        bool DISCONN_SYN_ACK_Proc(char * buffer);
-        bool FIN_Proc(char * buffer);
-        bool INFO_Proc(char * buffer);
-        bool ERR_Proc(char * buffer);
+        bool SendPacket(const void * buffer, size_t length, USER * user);
+        bool HandlePacket(const char * buffer, USER * user);
+
+        bool CONN_SYN_ACK_Proc(const void * buffer, USER * user);
+        bool ALIVE_SYN_Proc(const void * buffer, USER * user);
+        bool DISCONN_SYN_ACK_Proc(const void * buffer, USER * user);
+        bool FIN_Proc(const void * buffer, USER * user);
+        bool INFO_Proc(const void * buffer, USER * user);
+        bool ERR_Proc(const void * buffer, USER * user);
+
+        bool Send_CONN_SYN(USER * user);
+        bool Send_CONN_ACK(USER * user);
+        bool Send_DISCONN_SYN(USER * user);
+        bool Send_DISCONN_ACK(USER * user);
+        bool Send_ALIVE_ACK(USER * user);
 };
 
 #endif
 };
 
 #endif