]> git.stg.codes - stg.git/blobdiff - projects/sgauthstress/proto.cpp
Merge branch 'stg-2.409' into stg-2.409-radius
[stg.git] / projects / sgauthstress / proto.cpp
index dd4984caca545bf5d0243aefa067fab8f8ea47ad..d1114f17aa105667e703ae1775870ea3f29a5b0f 100644 (file)
@@ -1,6 +1,10 @@
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
 #include <netdb.h>
 #include <arpa/inet.h>
 
+#include <csignal>
 #include <cerrno>
 #include <cstring>
 #include <cassert>
@@ -77,6 +81,10 @@ pthread_mutex_destroy(&mutex);
 
 void * PROTO::Runner(void * data)
 {
+sigset_t signalSet;
+sigfillset(&signalSet);
+pthread_sigmask(SIG_BLOCK, &signalSet, NULL);
+
 PROTO * protoPtr = static_cast<PROTO *>(data);
 protoPtr->Run();
 return NULL;
@@ -126,7 +134,7 @@ return true;
 
 void PROTO::AddUser(const USER & user, bool connect)
 {
-STG_LOCKER lock(&mutex, __FILE__, __LINE__);
+STG_LOCKER lock(&mutex);
 users.push_back(std::make_pair(user.GetIP(), user));
 users.back().second.InitNetwork();
 
@@ -145,7 +153,7 @@ if (connect)
 bool PROTO::Connect(uint32_t ip)
 {
 std::list<std::pair<uint32_t, USER> >::iterator it;
-STG_LOCKER lock(&mutex, __FILE__, __LINE__);
+STG_LOCKER lock(&mutex);
 it = std::find_if(users.begin(), users.end(), HasIP(ip));
 if (it == users.end())
     return false;
@@ -158,7 +166,7 @@ return RealConnect(&it->second);
 bool PROTO::Disconnect(uint32_t ip)
 {
 std::list<std::pair<uint32_t, USER> >::iterator it;
-STG_LOCKER lock(&mutex, __FILE__, __LINE__);
+STG_LOCKER lock(&mutex);
 it = std::find_if(users.begin(), users.end(), HasIP(ip));
 if (it == users.end())
     return false;
@@ -172,7 +180,11 @@ void PROTO::Run()
 {
 while (running)
     {
-    int res = poll(&pollFds.front(), pollFds.size(), timeout);
+    int res;
+        {
+        STG_LOCKER lock(&mutex);
+        res = poll(&pollFds.front(), pollFds.size(), timeout);
+        }
     if (res < 0)
         break;
     if (!running)
@@ -182,23 +194,48 @@ while (running)
         printfd(__FILE__, "PROTO::Run() - events: %d\n", res);
         RecvPacket();
         }
+    else
+        {
+        CheckTimeouts();
+        }
     }
 
 stopped = true;
 }
 
+void PROTO::CheckTimeouts()
+{
+STG_LOCKER lock(&mutex);
+std::list<std::pair<uint32_t, USER> >::iterator it;
+for (it = users.begin(); it != users.end(); ++it)
+    {
+    int delta = difftime(time(NULL), it->second.GetPhaseChangeTime());
+    if ((it->second.GetPhase() == 3) &&
+        (delta > it->second.GetUserTimeout()))
+        {
+        printfd(__FILE__, "PROTO::CheckTimeouts() - user alive timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetUserTimeout());
+        it->second.SetPhase(1);
+        }
+    if ((it->second.GetPhase() == 2) &&
+        (delta > it->second.GetAliveTimeout()))
+        {
+        printfd(__FILE__, "PROTO::CheckTimeouts() - user connect timeout (ip: %s, login: '%s', delta: %d > %d)\n", inet_ntostring(it->second.GetIP()).c_str(), it->second.GetLogin().c_str(), delta, it->second.GetAliveTimeout());
+        it->second.SetPhase(1);
+        }
+    }
+}
+
 bool PROTO::RecvPacket()
 {
 bool result = true;
 std::vector<struct pollfd>::iterator it;
 std::list<std::pair<uint32_t, USER> >::iterator userIt;
-STG_LOCKER lock(&mutex, __FILE__, __LINE__);
+STG_LOCKER lock(&mutex);
 for (it = pollFds.begin(), userIt = users.begin(); it != pollFds.end() && userIt != users.end(); ++it, ++userIt)
     {
     if (it->revents)
         {
         it->revents = 0;
-        printfd(__FILE__, "PROTO::RecvPacket() - pollfd: %d, socket: %d\n", it->fd, userIt->second.GetSocket());
         assert(it->fd == userIt->second.GetSocket() && "File descriptors from poll fds and users must be syncked");
         struct sockaddr_in addr;
         socklen_t fromLen = sizeof(addr);
@@ -260,6 +297,7 @@ if (user->GetPhase() != 2)
     {
     errorStr = "Unexpected CONN_SYN_ACK";
     printfd(__FILE__, "PROTO::CONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
+    return false;
     }
 
 user->SetPhase(3);
@@ -288,6 +326,7 @@ if (user->GetPhase() != 3)
     {
     errorStr = "Unexpected ALIVE_SYN";
     printfd(__FILE__, "PROTO::ALIVE_SYN_Proc() - wrong phase: %d\n", user->GetPhase());
+    return false;
     }
 
 user->SetPhase(3);
@@ -312,6 +351,7 @@ if (user->GetPhase() != 4)
     {
     errorStr = "Unexpected DISCONN_SYN_ACK";
     printfd(__FILE__, "PROTO::DISCONN_SYN_ACK_Proc() - wrong phase: %d\n", user->GetPhase());
+    return false;
     }
 
 if (user->GetRnd() + 1 != rnd)
@@ -334,6 +374,7 @@ if (user->GetPhase() != 5)
     {
     errorStr = "Unexpected FIN";
     printfd(__FILE__, "PROTO::FIN_Proc() - wrong phase: %d\n", user->GetPhase());
+    return false;
     }
 
 user->SetPhase(1);
@@ -348,11 +389,8 @@ bool PROTO::INFO_Proc(const void * buffer, USER * user)
 return true;
 }
 
-bool PROTO::ERR_Proc(const void * buffer, USER * user)
+bool PROTO::ERR_Proc(const void * /*buffer*/, USER * user)
 {
-const ERR_8 * packet = static_cast<const ERR_8 *>(buffer);
-const char * ptr = static_cast<const char *>(buffer);
-
 //uint32_t len = packet->len;
 
 #ifdef ARCH_BE
@@ -469,9 +507,8 @@ hdr.protoVer[0] = 0;
 hdr.protoVer[1] = 8; // IA_PROTO_VER
 
 unsigned char buffer[2048];
-memset(buffer, 0, sizeof(buffer));
-memcpy(buffer, packet, length);
 memcpy(buffer, &hdr, sizeof(hdr));
+memcpy(buffer + sizeof(hdr), packet, length);
 
 size_t offset = sizeof(HDR_8);
 for (size_t i = 0; i < IA_LOGIN_LEN / 8; i++)