]> git.stg.codes - stg.git/commitdiff
Use received return code on no-match in rlm_stg.
authorMaxim Mamontov <faust.madf@gmail.com>
Wed, 7 Oct 2015 19:27:33 +0000 (22:27 +0300)
committerMaxim Mamontov <faust.madf@gmail.com>
Wed, 7 Oct 2015 19:27:33 +0000 (22:27 +0300)
projects/rlm_stg/conn.cpp
projects/rlm_stg/conn.h
projects/rlm_stg/iface.cpp
projects/rlm_stg/rlm_stg.c
projects/rlm_stg/stg_client.cpp
projects/rlm_stg/stgpair.h
projects/rlm_stg/types.h

index 13a9cea3294285022aaf1a93f25e797997215b41..3589a9012d3c7d0a5f93a9c074b3ff5236a33538 100644 (file)
@@ -21,6 +21,7 @@
 #include "conn.h"
 
 #include "radlog.h"
+#include "stgpair.h"
 
 #include "stg/json_parser.h"
 #include "stg/json_generator.h"
@@ -89,6 +90,7 @@ enum Packet
 
 std::map<std::string, Packet> packetCodes;
 std::map<std::string, bool> resultCodes;
+std::map<std::string, int> returnCodes;
 
 class PacketParser : public EnumParser<Packet>
 {
@@ -117,6 +119,26 @@ class ResultParser : public EnumParser<bool>
         }
 };
 
+class ReturnCodeParser : public EnumParser<int>
+{
+    public:
+        ReturnCodeParser(NodeParser* next, int& returnCode, std::string& returnCodeStr)
+            : EnumParser(next, returnCode, returnCodeStr, returnCodes)
+        {
+            if (!returnCodes.empty())
+                return;
+            returnCodes["reject"]   = STG_REJECT;
+            returnCodes["fail"]     = STG_FAIL;
+            returnCodes["ok"]       = STG_OK;
+            returnCodes["handled"]  = STG_HANDLED;
+            returnCodes["invalid"]  = STG_INVALID;
+            returnCodes["userlock"] = STG_USERLOCK;
+            returnCodes["notfound"] = STG_NOTFOUND;
+            returnCodes["noop"]     = STG_NOOP;
+            returnCodes["updated"]  = STG_UPDATED;
+        }
+};
+
 class TopParser : public NodeParser
 {
     public:
@@ -124,8 +146,10 @@ class TopParser : public NodeParser
         TopParser(Callback callback, void* data)
             : m_packet(PING),
               m_result(false),
+              m_returnCode(STG_REJECT),
               m_packetParser(this, m_packet, m_packetStr),
               m_resultParser(this, m_result, m_resultStr),
+              m_returnCodeParser(this, m_returnCode, m_returnCodeStr),
               m_replyParser(this, m_reply),
               m_modifyParser(this, m_modify),
               m_callback(callback), m_data(data)
@@ -144,6 +168,8 @@ class TopParser : public NodeParser
                 return &m_replyParser;
             else if (key == "modify")
                 return &m_modifyParser;
+            else if (key == "return_code")
+                return &m_returnCodeParser;
 
             return this;
         }
@@ -153,6 +179,8 @@ class TopParser : public NodeParser
         Packet packet() const { return m_packet; }
         const std::string& resultStr() const { return m_resultStr; }
         bool result() const { return m_result; }
+        const std::string& returnCodeStr() const { return m_returnCodeStr; }
+        int returnCode() const { return m_returnCode; }
         const PairsParser::Pairs& reply() const { return m_reply; }
         const PairsParser::Pairs& modify() const { return m_modify; }
 
@@ -161,11 +189,14 @@ class TopParser : public NodeParser
         Packet m_packet;
         std::string m_resultStr;
         bool m_result;
+        std::string m_returnCodeStr;
+        int m_returnCode;
         PairsParser::Pairs m_reply;
         PairsParser::Pairs m_modify;
 
         PacketParser m_packetParser;
         ResultParser m_resultParser;
+        ReturnCodeParser m_returnCodeParser;
         PairsParser m_replyParser;
         PairsParser m_modifyParser;
 
@@ -185,6 +216,8 @@ class ProtoParser : public Parser
         Packet packet() const { return m_topParser.packet(); }
         const std::string& resultStr() const { return m_topParser.resultStr(); }
         bool result() const { return m_topParser.result(); }
+        const std::string& returnCodeStr() const { return m_topParser.returnCodeStr(); }
+        int returnCode() const { return m_topParser.returnCode(); }
         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
 
@@ -595,11 +628,17 @@ void Conn::Impl::processPong()
 void Conn::Impl::processData()
 {
     RESULT data;
-    for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
-        data.reply.push_back(std::make_pair(it->first, it->second));
-    for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
-        data.modify.push_back(std::make_pair(it->first, it->second));
-    m_callback(m_data, data, m_parser.result());
+    if (m_parser.result())
+    {
+        for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
+            data.reply.push_back(std::make_pair(it->first, it->second));
+        for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
+            data.modify.push_back(std::make_pair(it->first, it->second));
+        data.returnCode = STG_UPDATED;
+    }
+    else
+        data.returnCode = m_parser.returnCode();
+    m_callback(m_data, data);
 }
 
 bool Conn::Impl::sendPing()
index cecc080270bcc0c4e5b2163332c45d718fb08e3b..6233b15414c9f2c425e453be310ada42430a2078 100644 (file)
@@ -42,7 +42,7 @@ class Conn
             explicit Error(const std::string& message) : runtime_error(message) {}
         };
 
-        typedef bool (*Callback)(void* /*data*/, const RESULT& /*result*/, bool /*status*/);
+        typedef bool (*Callback)(void* /*data*/, const RESULT& /*result*/);
 
         Conn(const std::string& address, Callback callback, void* data);
         ~Conn();
index 2f8f3e3794072fc72ee523e2146247cab5debde1..f97593f476a41ab0829523639f723ce9e3a55831 100644 (file)
@@ -52,12 +52,13 @@ STG_RESULT toResult(const RESULT& source)
     STG_RESULT result;
     result.modify = toSTGPairs(source.modify);
     result.reply = toSTGPairs(source.reply);
+    result.returnCode = source.returnCode;
     return result;
 }
 
 STG_RESULT emptyResult()
 {
-    STG_RESULT result = {NULL, NULL};
+    STG_RESULT result = {NULL, NULL, STG_REJECT};
     return result;
 }
 
index 20d46a59ee2e765b35e9ce8ff5accf5480cb6f45..506545934646fce9f9bf3df0b1531dd91d0456c1 100644 (file)
@@ -107,6 +107,23 @@ static STG_PAIR* fromVPS(const VALUE_PAIR* pairs)
     return res;
 }
 
+static int toRLMCode(int code)
+{
+    switch (code)
+    {
+        case STG_REJECT:   return RLM_MODULE_REJECT;
+        case STG_FAIL:     return RLM_MODULE_FAIL;
+        case STG_OK:       return RLM_MODULE_OK;
+        case STG_HANDLED:  return RLM_MODULE_HANDLED;
+        case STG_INVALID:  return RLM_MODULE_INVALID;
+        case STG_USERLOCK: return RLM_MODULE_USERLOCK;
+        case STG_NOTFOUND: return RLM_MODULE_NOTFOUND;
+        case STG_NOOP:     return RLM_MODULE_NOOP;
+        case STG_UPDATED:  return RLM_MODULE_UPDATED;
+    }
+    return RLM_MODULE_REJECT;
+}
+
 /*
  *    Do any per-module initialization that is separate to each
  *    configured instance of the module.  e.g. set up connections
@@ -190,7 +207,7 @@ static int stg_authorize(void* instance, REQUEST* request)
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_NOOP;
+    return toRLMCode(result.returnCode);
 }
 
 /*
@@ -231,7 +248,7 @@ static int stg_authenticate(void* instance, REQUEST* request)
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_NOOP;
+    return toRLMCode(result.returnCode);
 }
 
 /*
@@ -272,7 +289,7 @@ static int stg_preacct(void* instance, REQUEST* request)
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_NOOP;
+    return toRLMCode(result.returnCode);
 }
 
 /*
@@ -313,7 +330,7 @@ static int stg_accounting(void* instance, REQUEST* request)
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_OK;
+    return toRLMCode(result.returnCode);
 }
 
 /*
@@ -372,7 +389,7 @@ static int stg_postauth(void* instance, REQUEST* request)
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_NOOP;
+    return toRLMCode(result.returnCode);
 }
 
 static int stg_detach(void* instance)
index 239770deffdb068a77e83ba36c2dc2ddb193cc98..e34c50cdbaf30137125e29d50c55b1d338c47336 100644 (file)
@@ -57,14 +57,12 @@ class Client::Impl
         pthread_cond_t m_cond;
         bool m_done;
         RESULT m_result;
-        bool m_status;
 
-        static bool callback(void* data, const RESULT& result, bool status)
+        static bool callback(void* data, const RESULT& result)
         {
             Impl& impl = *static_cast<Impl*>(data);
             STG_LOCKER lock(impl.m_mutex);
             impl.m_result = result;
-            impl.m_status = status;
             impl.m_done = true;
             pthread_cond_signal(&impl.m_cond);
             return true;
@@ -109,7 +107,7 @@ RESULT Client::Impl::request(REQUEST_TYPE type, const std::string& userName, con
     int res = 0;
     while (!m_done && res == 0)
         res = pthread_cond_timedwait(&m_cond, &m_mutex, &ts);
-    if (res != 0 || !m_status)
+    if (res != 0)
         throw Conn::Error("Request failed.");
     return m_result;
 }
index e82c667234ef09194d6efa57087b22dc9db70e5f..ef7ab4b778897fdd3e738e36346b70a7ebe7633a 100644 (file)
@@ -18,6 +18,7 @@ typedef struct STG_PAIR {
 typedef struct STG_RESULT {
     STG_PAIR* modify;
     STG_PAIR* reply;
+    int returnCode;
 } STG_RESULT;
 
 inline
@@ -26,6 +27,19 @@ int emptyPair(const STG_PAIR* pair)
     return pair == NULL || pair->key[0] == '\0' || pair->value[0] == '\0';
 }
 
+enum
+{
+    STG_REJECT,
+    STG_FAIL,
+    STG_OK,
+    STG_HANDLED,
+    STG_INVALID,
+    STG_USERLOCK,
+    STG_NOTFOUND,
+    STG_NOOP,
+    STG_UPDATED
+};
+
 #ifdef __cplusplus
 }
 #endif
index d98ddd46853b418872d7ebefcb59c3e462f8c411..2bc721fc2272bf0c37378cc89ed1f1f9c2f96508 100644 (file)
@@ -35,6 +35,7 @@ struct RESULT
 {
     PAIRS modify;
     PAIRS reply;
+    int returnCode;
 };
 
 enum REQUEST_TYPE {