]> git.stg.codes - stg.git/blobdiff - projects/stargazer/plugins/other/radius/conn.cpp
Allowed to pass mutex by reference into STG_LOCKER.
[stg.git] / projects / stargazer / plugins / other / radius / conn.cpp
index 392624a6ef4846dd8d4c424baca068638e4ee4a2..c0270e7834303e2a7146b3cee37543ab565463e1 100644 (file)
 #include <yajl/yajl_gen.h>
 
 #include <map>
 #include <yajl/yajl_gen.h>
 
 #include <map>
+#include <stdexcept>
 #include <cstring>
 #include <cerrno>
 
 #include <unistd.h>
 #include <cstring>
 #include <cerrno>
 
 #include <unistd.h>
+#include <sys/types.h>
+#include <sys/socket.h>
 
 using STG::Conn;
 using STG::Config;
 
 using STG::Conn;
 using STG::Config;
@@ -50,8 +53,8 @@ using STG::JSON::StringGen;
 namespace
 {
 
 namespace
 {
 
-double CONN_TIMEOUT = 5;
-double PING_TIMEOUT = 1;
+double CONN_TIMEOUT = 60;
+double PING_TIMEOUT = 10;
 
 enum Packet
 {
 
 enum Packet
 {
@@ -105,10 +108,12 @@ class StageParser : public EnumParser<Stage>
 class TopParser : public NodeParser
 {
     public:
 class TopParser : public NodeParser
 {
     public:
-        TopParser()
+        typedef void (*Callback) (void* /*data*/);
+        TopParser(Callback callback, void* data)
             : m_packetParser(this, m_packet, m_packetStr),
               m_stageParser(this, m_stage, m_stageStr),
             : m_packetParser(this, m_packet, m_packetStr),
               m_stageParser(this, m_stage, m_stageStr),
-              m_pairsParser(this, m_data)
+              m_pairsParser(this, m_data),
+              m_callback(callback), m_callbackData(data)
         {}
 
         virtual NodeParser* parseStartMap() { return this; }
         {}
 
         virtual NodeParser* parseStartMap() { return this; }
@@ -125,7 +130,7 @@ class TopParser : public NodeParser
 
             return this;
         }
 
             return this;
         }
-        virtual NodeParser* parseEndMap() { return this; }
+        virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; }
 
         const std::string& packetStr() const { return m_packetStr; }
         Packet packet() const { return m_packet; }
 
         const std::string& packetStr() const { return m_packetStr; }
         Packet packet() const { return m_packet; }
@@ -143,12 +148,18 @@ class TopParser : public NodeParser
         PacketParser m_packetParser;
         StageParser m_stageParser;
         PairsParser m_pairsParser;
         PacketParser m_packetParser;
         StageParser m_stageParser;
         PairsParser m_pairsParser;
+
+        Callback m_callback;
+        void* m_callbackData;
 };
 
 class ProtoParser : public Parser
 {
     public:
 };
 
 class ProtoParser : public Parser
 {
     public:
-        ProtoParser() : Parser( &m_topParser ) {}
+        ProtoParser(TopParser::Callback callback, void* data)
+            : Parser( &m_topParser ),
+              m_topParser(callback, data)
+        {}
 
         const std::string& packetStr() const { return m_topParser.packetStr(); }
         Packet packet() const { return m_topParser.packet(); }
 
         const std::string& packetStr() const { return m_topParser.packetStr(); }
         Packet packet() const { return m_topParser.packet(); }
@@ -182,6 +193,11 @@ class PacketGen : public Gen
             m_gen.add(key, map);
             return *this;
         }
             m_gen.add(key, map);
             return *this;
         }
+        PacketGen& add(const std::string& key, MapGen& map)
+        {
+            m_gen.add(key, map);
+            return *this;
+        }
     private:
         MapGen m_gen;
         StringGen m_type;
     private:
         MapGen m_gen;
         StringGen m_type;
@@ -213,10 +229,27 @@ class Conn::Impl
         time_t m_lastActivity;
         ProtoParser m_parser;
 
         time_t m_lastActivity;
         ProtoParser m_parser;
 
-        bool process();
-        bool processPing();
-        bool processPong();
-        bool processData();
+        const Config::Pairs& stagePairs(Config::Pairs Config::Section::* pairs) const
+        {
+            switch (m_parser.stage())
+            {
+                case AUTHORIZE: return m_config.autz.*pairs;
+                case AUTHENTICATE: return m_config.auth.*pairs;
+                case POSTAUTH: return m_config.postauth.*pairs;
+                case PREACCT: return m_config.preacct.*pairs;
+                case ACCOUNTING: return m_config.acct.*pairs;
+            }
+            throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
+        }
+
+        const Config::Pairs& match() const { return stagePairs(&Config::Section::match); }
+        const Config::Pairs& modify() const { return stagePairs(&Config::Section::modify); }
+        const Config::Pairs& reply() const { return stagePairs(&Config::Section::reply); }
+
+        static void process(void* data);
+        void processPing();
+        void processPong();
+        void processData();
         bool answer(const USER& user);
         bool answerNo();
         bool sendPing();
         bool answer(const USER& user);
         bool answerNo();
         bool sendPing();
@@ -262,7 +295,8 @@ Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int
       m_remote(remote),
       m_ok(true),
       m_lastPing(time(NULL)),
       m_remote(remote),
       m_ok(true),
       m_lastPing(time(NULL)),
-      m_lastActivity(m_lastPing)
+      m_lastActivity(m_lastPing),
+      m_parser(&Conn::Impl::process, this)
 {
 }
 
 {
 }
 
@@ -281,16 +315,12 @@ bool Conn::Impl::read()
         m_ok = false;
         return false;
     }
         m_ok = false;
         return false;
     }
+    printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
     m_lastActivity = time(NULL);
     if (res == 0)
     {
     m_lastActivity = time(NULL);
     if (res == 0)
     {
-        if (!m_parser.done())
-        {
-            m_ok = false;
-            m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
-            return false;
-        }
-        return process();
+        m_ok = false;
+        return true;
     }
     return m_parser.append(buffer.data(), res);
 }
     }
     return m_parser.append(buffer.data(), res);
 }
@@ -300,95 +330,114 @@ bool Conn::Impl::tick()
     time_t now = time(NULL);
     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
     {
     time_t now = time(NULL);
     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
     {
+        int delta = difftime(now, m_lastActivity);
+        printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta);
         m_logger("Connection to " + m_remote + " timed out.");
         m_ok = false;
         return false;
     }
     if (difftime(now, m_lastPing) > PING_TIMEOUT)
         m_logger("Connection to " + m_remote + " timed out.");
         m_ok = false;
         return false;
     }
     if (difftime(now, m_lastPing) > PING_TIMEOUT)
+    {
+        int delta = difftime(now, m_lastPing);
+        printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta);
         sendPing();
         sendPing();
+    }
     return true;
 }
 
     return true;
 }
 
-bool Conn::Impl::process()
+void Conn::Impl::process(void* data)
 {
 {
-    switch (m_parser.packet())
+    Impl& impl = *static_cast<Impl*>(data);
+    try
+    {
+        switch (impl.m_parser.packet())
+        {
+            case PING:
+                impl.processPing();
+                return;
+            case PONG:
+                impl.processPong();
+                return;
+            case DATA:
+                impl.processData();
+                return;
+        }
+    }
+    catch (const std::exception& ex)
     {
     {
-        case PING:
-            return processPing();
-        case PONG:
-            return processPong();
-        case DATA:
-            return processData();
+        printfd(__FILE__, "Processing error. %s", ex.what());
+        impl.m_logger("Processing error. %s", ex.what());
     }
     }
-    m_logger("Received invalid packet type: " + m_parser.packetStr());
-    return false;
+    printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
+    impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
 }
 
 }
 
-bool Conn::Impl::processPing()
+void Conn::Impl::processPing()
 {
 {
-    return sendPong();
+    printfd(__FILE__, "Got ping. Sending pong...\n");
+    sendPong();
 }
 
 }
 
-bool Conn::Impl::processPong()
+void Conn::Impl::processPong()
 {
 {
+    printfd(__FILE__, "Got pong.\n");
     m_lastActivity = time(NULL);
     m_lastActivity = time(NULL);
-    return true;
 }
 
 }
 
-bool Conn::Impl::processData()
+void Conn::Impl::processData()
 {
 {
+    printfd(__FILE__, "Got data.\n");
     int handle = m_users.OpenSearch();
 
     USER_PTR user = NULL;
     int handle = m_users.OpenSearch();
 
     USER_PTR user = NULL;
-    bool match = true;
-    while (m_users.SearchNext(handle, &user))
+    bool matched = false;
+    while (m_users.SearchNext(handle, &user) == 0)
     {
         if (user == NULL)
             continue;
 
     {
         if (user == NULL)
             continue;
 
-        match = true;
-        for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it)
+        matched = true;
+        for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
         {
             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
             if (pos == m_parser.data().end())
             {
         {
             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
             if (pos == m_parser.data().end())
             {
-                match = false;
+                matched = false;
                 break;
             }
             if (user->GetParamValue(it->second) != pos->second)
             {
                 break;
             }
             if (user->GetParamValue(it->second) != pos->second)
             {
-                match = false;
+                matched = false;
                 break;
             }
         }
                 break;
             }
         }
-        if (!match)
+        if (!matched)
             continue;
         answer(*user);
         break;
     }
 
             continue;
         answer(*user);
         break;
     }
 
-    if (!match)
+    if (!matched)
         answerNo();
 
     m_users.CloseSearch(handle);
         answerNo();
 
     m_users.CloseSearch(handle);
-
-    return true;
 }
 
 bool Conn::Impl::answer(const USER& user)
 {
 }
 
 bool Conn::Impl::answer(const USER& user)
 {
-    boost::scoped_ptr<MapGen> reply(new MapGen);
-    for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it)
-        reply->add(it->first, new StringGen(user.GetParamValue(it->second)));
+    printfd(__FILE__, "Got match. Sending answer...\n");
+    MapGen replyData;
+    for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
+        replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
 
 
-    boost::scoped_ptr<MapGen> modify(new MapGen);
-    for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it)
-        modify->add(it->first, new StringGen(user.GetParamValue(it->second)));
+    MapGen modifyData;
+    for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
+        modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
 
     PacketGen gen("data");
     gen.add("result", "ok")
 
     PacketGen gen("data");
     gen.add("result", "ok")
-       .add("reply", reply.get())
-       .add("modify", modify.get());
+       .add("reply", replyData)
+       .add("modify", modifyData);
 
     m_lastPing = time(NULL);
 
 
     m_lastPing = time(NULL);
 
@@ -397,8 +446,9 @@ bool Conn::Impl::answer(const USER& user)
 
 bool Conn::Impl::answerNo()
 {
 
 bool Conn::Impl::answerNo()
 {
+    printfd(__FILE__, "No match. Sending answer...\n");
     PacketGen gen("data");
     PacketGen gen("data");
-    gen.add("result", "ok");
+    gen.add("result", "no");
 
     m_lastPing = time(NULL);
 
 
     m_lastPing = time(NULL);
 
@@ -425,10 +475,12 @@ bool Conn::Impl::sendPong()
 
 bool Conn::Impl::write(void* data, const char* buf, size_t size)
 {
 
 bool Conn::Impl::write(void* data, const char* buf, size_t size)
 {
+    std::string json(buf, size);
+    printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str());
     Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
     while (size > 0)
     {
     Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
     while (size > 0)
     {
-        ssize_t res = ::write(conn.m_sock, buf, size);
+        ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL);
         if (res < 0)
         {
             conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
         if (res < 0)
         {
             conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));