]> git.stg.codes - stg.git/blobdiff - projects/rlm_stg/conn.cpp
Merge branch 'stg-2.409-radius'
[stg.git] / projects / rlm_stg / conn.cpp
index 625245189efe4f837b74ab39c97992d82473ff37..3589a9012d3c7d0a5f93a9c074b3ff5236a33538 100644 (file)
@@ -21,6 +21,7 @@
 #include "conn.h"
 
 #include "radlog.h"
+#include "stgpair.h"
 
 #include "stg/json_parser.h"
 #include "stg/json_generator.h"
@@ -55,10 +56,10 @@ double PING_TIMEOUT = 10;
 
 struct ChannelConfig {
     struct Error : std::runtime_error {
-        Error(const std::string& message) : runtime_error(message) {}
+        explicit Error(const std::string& message) : runtime_error(message) {}
     };
 
-    ChannelConfig(std::string address);
+    explicit ChannelConfig(std::string address);
 
     std::string transport;
     std::string key;
@@ -89,6 +90,7 @@ enum Packet
 
 std::map<std::string, Packet> packetCodes;
 std::map<std::string, bool> resultCodes;
+std::map<std::string, int> returnCodes;
 
 class PacketParser : public EnumParser<Packet>
 {
@@ -117,6 +119,26 @@ class ResultParser : public EnumParser<bool>
         }
 };
 
+class ReturnCodeParser : public EnumParser<int>
+{
+    public:
+        ReturnCodeParser(NodeParser* next, int& returnCode, std::string& returnCodeStr)
+            : EnumParser(next, returnCode, returnCodeStr, returnCodes)
+        {
+            if (!returnCodes.empty())
+                return;
+            returnCodes["reject"]   = STG_REJECT;
+            returnCodes["fail"]     = STG_FAIL;
+            returnCodes["ok"]       = STG_OK;
+            returnCodes["handled"]  = STG_HANDLED;
+            returnCodes["invalid"]  = STG_INVALID;
+            returnCodes["userlock"] = STG_USERLOCK;
+            returnCodes["notfound"] = STG_NOTFOUND;
+            returnCodes["noop"]     = STG_NOOP;
+            returnCodes["updated"]  = STG_UPDATED;
+        }
+};
+
 class TopParser : public NodeParser
 {
     public:
@@ -124,8 +146,10 @@ class TopParser : public NodeParser
         TopParser(Callback callback, void* data)
             : m_packet(PING),
               m_result(false),
+              m_returnCode(STG_REJECT),
               m_packetParser(this, m_packet, m_packetStr),
               m_resultParser(this, m_result, m_resultStr),
+              m_returnCodeParser(this, m_returnCode, m_returnCodeStr),
               m_replyParser(this, m_reply),
               m_modifyParser(this, m_modify),
               m_callback(callback), m_data(data)
@@ -144,6 +168,8 @@ class TopParser : public NodeParser
                 return &m_replyParser;
             else if (key == "modify")
                 return &m_modifyParser;
+            else if (key == "return_code")
+                return &m_returnCodeParser;
 
             return this;
         }
@@ -153,6 +179,8 @@ class TopParser : public NodeParser
         Packet packet() const { return m_packet; }
         const std::string& resultStr() const { return m_resultStr; }
         bool result() const { return m_result; }
+        const std::string& returnCodeStr() const { return m_returnCodeStr; }
+        int returnCode() const { return m_returnCode; }
         const PairsParser::Pairs& reply() const { return m_reply; }
         const PairsParser::Pairs& modify() const { return m_modify; }
 
@@ -161,11 +189,14 @@ class TopParser : public NodeParser
         Packet m_packet;
         std::string m_resultStr;
         bool m_result;
+        std::string m_returnCodeStr;
+        int m_returnCode;
         PairsParser::Pairs m_reply;
         PairsParser::Pairs m_modify;
 
         PacketParser m_packetParser;
         ResultParser m_resultParser;
+        ReturnCodeParser m_returnCodeParser;
         PairsParser m_replyParser;
         PairsParser m_modifyParser;
 
@@ -185,6 +216,8 @@ class ProtoParser : public Parser
         Packet packet() const { return m_topParser.packet(); }
         const std::string& resultStr() const { return m_topParser.resultStr(); }
         bool result() const { return m_topParser.result(); }
+        const std::string& returnCodeStr() const { return m_topParser.returnCodeStr(); }
+        int returnCode() const { return m_topParser.returnCode(); }
         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
 
@@ -195,7 +228,7 @@ class ProtoParser : public Parser
 class PacketGen : public Gen
 {
     public:
-        PacketGen(const std::string& type)
+        explicit PacketGen(const std::string& type)
             : m_type(type)
         {
             m_gen.add("packet", m_type);
@@ -261,6 +294,8 @@ private:
 
     void runImpl();
 
+    bool start();
+
     int connect();
     int connectTCP();
     int connectUNIX();
@@ -312,6 +347,30 @@ ChannelConfig::ChannelConfig(std::string addr)
         throw Error("Invalid port value.");
 }
 
+Conn::Conn(const std::string& address, Callback callback, void* data)
+    : m_impl(new Impl(address, callback, data))
+{
+}
+
+Conn::~Conn()
+{
+}
+
+bool Conn::stop()
+{
+    return m_impl->stop();
+}
+
+bool Conn::connected() const
+{
+    return m_impl->connected();
+}
+
+bool Conn::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
+{
+    return m_impl->request(type, userName, password, pairs);
+}
+
 Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
     : m_config(address),
       m_sock(connect()),
@@ -325,9 +384,6 @@ Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
       m_connected(true)
 {
     pthread_mutex_init(&m_mutex, NULL);
-    int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
-    if (res != 0)
-        throw Error("Failed to create thread: " + std::string(strerror(errno)));
 }
 
 Conn::Impl::~Impl()
@@ -362,6 +418,9 @@ bool Conn::Impl::stop()
 
 bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
+    if (!m_running)
+        if (!start())
+            return false;
     MapGen map;
     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
         map.add(it->first, new StringGen(it->second));
@@ -402,6 +461,7 @@ void Conn::Impl::runImpl()
             break;
         }
 
+
         if (!m_running)
             break;
 
@@ -420,6 +480,14 @@ void Conn::Impl::runImpl()
     m_stopped = true;
 }
 
+bool Conn::Impl::start()
+{
+    int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
+    if (res != 0)
+        return false;
+    return true;
+}
+
 int Conn::Impl::connect()
 {
     if (m_config.transport == "tcp")
@@ -549,25 +617,28 @@ void Conn::Impl::process(void* data)
 
 void Conn::Impl::processPing()
 {
-    RadLog("Got ping, sending pong.");
     sendPong();
 }
 
 void Conn::Impl::processPong()
 {
-    RadLog("Got pong.");
     m_lastActivity = time(NULL);
 }
 
 void Conn::Impl::processData()
 {
     RESULT data;
-    RadLog("Got data.");
-    for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
-        data.reply.push_back(std::make_pair(it->first, it->second));
-    for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
-        data.modify.push_back(std::make_pair(it->first, it->second));
-    m_callback(m_data, data, m_parser.result());
+    if (m_parser.result())
+    {
+        for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
+            data.reply.push_back(std::make_pair(it->first, it->second));
+        for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
+            data.modify.push_back(std::make_pair(it->first, it->second));
+        data.returnCode = STG_UPDATED;
+    }
+    else
+        data.returnCode = m_parser.returnCode();
+    m_callback(m_data, data);
 }
 
 bool Conn::Impl::sendPing()
@@ -590,9 +661,8 @@ bool Conn::Impl::sendPong()
 
 bool Conn::Impl::write(void* data, const char* buf, size_t size)
 {
-    RadLog("Sending JSON:");
     std::string json(buf, size);
-    RadLog("%s", json.c_str());
+    RadLog("Sending JSON: %s", json.c_str());
     Conn::Impl& impl = *static_cast<Conn::Impl*>(data);
     while (size > 0)
     {