]> git.stg.codes - stg.git/blobdiff - projects/rlm_stg/conn.cpp
Merge branch 'stg-2.409-radius'
[stg.git] / projects / rlm_stg / conn.cpp
index 3ee720d19f2fae0b904b75266db225c2cf9fb3d1..3589a9012d3c7d0a5f93a9c074b3ff5236a33538 100644 (file)
@@ -21,6 +21,7 @@
 #include "conn.h"
 
 #include "radlog.h"
+#include "stgpair.h"
 
 #include "stg/json_parser.h"
 #include "stg/json_generator.h"
@@ -89,6 +90,7 @@ enum Packet
 
 std::map<std::string, Packet> packetCodes;
 std::map<std::string, bool> resultCodes;
+std::map<std::string, int> returnCodes;
 
 class PacketParser : public EnumParser<Packet>
 {
@@ -117,6 +119,26 @@ class ResultParser : public EnumParser<bool>
         }
 };
 
+class ReturnCodeParser : public EnumParser<int>
+{
+    public:
+        ReturnCodeParser(NodeParser* next, int& returnCode, std::string& returnCodeStr)
+            : EnumParser(next, returnCode, returnCodeStr, returnCodes)
+        {
+            if (!returnCodes.empty())
+                return;
+            returnCodes["reject"]   = STG_REJECT;
+            returnCodes["fail"]     = STG_FAIL;
+            returnCodes["ok"]       = STG_OK;
+            returnCodes["handled"]  = STG_HANDLED;
+            returnCodes["invalid"]  = STG_INVALID;
+            returnCodes["userlock"] = STG_USERLOCK;
+            returnCodes["notfound"] = STG_NOTFOUND;
+            returnCodes["noop"]     = STG_NOOP;
+            returnCodes["updated"]  = STG_UPDATED;
+        }
+};
+
 class TopParser : public NodeParser
 {
     public:
@@ -124,8 +146,10 @@ class TopParser : public NodeParser
         TopParser(Callback callback, void* data)
             : m_packet(PING),
               m_result(false),
+              m_returnCode(STG_REJECT),
               m_packetParser(this, m_packet, m_packetStr),
               m_resultParser(this, m_result, m_resultStr),
+              m_returnCodeParser(this, m_returnCode, m_returnCodeStr),
               m_replyParser(this, m_reply),
               m_modifyParser(this, m_modify),
               m_callback(callback), m_data(data)
@@ -144,6 +168,8 @@ class TopParser : public NodeParser
                 return &m_replyParser;
             else if (key == "modify")
                 return &m_modifyParser;
+            else if (key == "return_code")
+                return &m_returnCodeParser;
 
             return this;
         }
@@ -153,6 +179,8 @@ class TopParser : public NodeParser
         Packet packet() const { return m_packet; }
         const std::string& resultStr() const { return m_resultStr; }
         bool result() const { return m_result; }
+        const std::string& returnCodeStr() const { return m_returnCodeStr; }
+        int returnCode() const { return m_returnCode; }
         const PairsParser::Pairs& reply() const { return m_reply; }
         const PairsParser::Pairs& modify() const { return m_modify; }
 
@@ -161,11 +189,14 @@ class TopParser : public NodeParser
         Packet m_packet;
         std::string m_resultStr;
         bool m_result;
+        std::string m_returnCodeStr;
+        int m_returnCode;
         PairsParser::Pairs m_reply;
         PairsParser::Pairs m_modify;
 
         PacketParser m_packetParser;
         ResultParser m_resultParser;
+        ReturnCodeParser m_returnCodeParser;
         PairsParser m_replyParser;
         PairsParser m_modifyParser;
 
@@ -185,6 +216,8 @@ class ProtoParser : public Parser
         Packet packet() const { return m_topParser.packet(); }
         const std::string& resultStr() const { return m_topParser.resultStr(); }
         bool result() const { return m_topParser.result(); }
+        const std::string& returnCodeStr() const { return m_topParser.returnCodeStr(); }
+        int returnCode() const { return m_topParser.returnCode(); }
         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
 
@@ -261,6 +294,8 @@ private:
 
     void runImpl();
 
+    bool start();
+
     int connect();
     int connectTCP();
     int connectUNIX();
@@ -348,11 +383,7 @@ Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
       m_parser(&Conn::Impl::process, this),
       m_connected(true)
 {
-    RadLog("Created connection.");
     pthread_mutex_init(&m_mutex, NULL);
-    int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
-    if (res != 0)
-        throw Error("Failed to create thread: " + std::string(strerror(errno)));
 }
 
 Conn::Impl::~Impl()
@@ -361,7 +392,6 @@ Conn::Impl::~Impl()
     shutdown(m_sock, SHUT_RDWR);
     close(m_sock);
     pthread_mutex_destroy(&m_mutex);
-    RadLog("Deleted connection.");
 }
 
 bool Conn::Impl::stop()
@@ -388,6 +418,9 @@ bool Conn::Impl::stop()
 
 bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
+    if (!m_running)
+        if (!start())
+            return false;
     MapGen map;
     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
         map.add(it->first, new StringGen(it->second));
@@ -409,8 +442,6 @@ void Conn::Impl::runImpl()
 {
     m_running = true;
 
-    RadLog("Run connection.");
-
     while (m_running) {
         fd_set fds;
 
@@ -421,9 +452,7 @@ void Conn::Impl::runImpl()
         tv.tv_sec = 0;
         tv.tv_usec = 500000;
 
-        RadLog("Starting 'select'.");
         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
-        RadLog("'select' result: %d.", res);
         if (res < 0)
         {
             if (errno == EINTR)
@@ -440,21 +469,25 @@ void Conn::Impl::runImpl()
 
         if (res > 0)
         {
-            RadLog("Got %d fds.", res);
             if (FD_ISSET(m_sock, &fds))
                 m_running = read();
-            RadLog("Read complete.");
         }
         else
             m_running = tick();
     }
 
-    RadLog("End running connection.");
-
     m_connected = false;
     m_stopped = true;
 }
 
+bool Conn::Impl::start()
+{
+    int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
+    if (res != 0)
+        return false;
+    return true;
+}
+
 int Conn::Impl::connect()
 {
     if (m_config.transport == "tcp")
@@ -584,25 +617,28 @@ void Conn::Impl::process(void* data)
 
 void Conn::Impl::processPing()
 {
-    RadLog("Got ping, sending pong.");
     sendPong();
 }
 
 void Conn::Impl::processPong()
 {
-    RadLog("Got pong.");
     m_lastActivity = time(NULL);
 }
 
 void Conn::Impl::processData()
 {
     RESULT data;
-    RadLog("Got data.");
-    for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
-        data.reply.push_back(std::make_pair(it->first, it->second));
-    for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
-        data.modify.push_back(std::make_pair(it->first, it->second));
-    m_callback(m_data, data, m_parser.result());
+    if (m_parser.result())
+    {
+        for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
+            data.reply.push_back(std::make_pair(it->first, it->second));
+        for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
+            data.modify.push_back(std::make_pair(it->first, it->second));
+        data.returnCode = STG_UPDATED;
+    }
+    else
+        data.returnCode = m_parser.returnCode();
+    m_callback(m_data, data);
 }
 
 bool Conn::Impl::sendPing()
@@ -625,9 +661,8 @@ bool Conn::Impl::sendPong()
 
 bool Conn::Impl::write(void* data, const char* buf, size_t size)
 {
-    RadLog("Sending JSON:");
     std::string json(buf, size);
-    RadLog("%s", json.c_str());
+    RadLog("Sending JSON: %s", json.c_str());
     Conn::Impl& impl = *static_cast<Conn::Impl*>(data);
     while (size > 0)
     {
@@ -638,7 +673,6 @@ bool Conn::Impl::write(void* data, const char* buf, size_t size)
             RadLog("Failed to write data: %s.", strerror(errno));
             return false;
         }
-        RadLog("Send %d bytes.", res);
         size -= res;
     }
     return true;