]> git.stg.codes - stg.git/commitdiff
Use separate config sections for each session stage.
authorMaxim Mamontov <faust.madf@gmail.com>
Sat, 5 Sep 2015 12:24:52 +0000 (15:24 +0300)
committerMaxim Mamontov <faust.madf@gmail.com>
Sat, 5 Sep 2015 12:24:52 +0000 (15:24 +0300)
projects/stargazer/plugins/other/radius/config.cpp
projects/stargazer/plugins/other/radius/config.h
projects/stargazer/plugins/other/radius/conn.cpp
projects/stargazer/plugins/other/radius/conn.h

index 23339c0267e260951d265211916f8b442d36d01a..8e9f27a9fab9d99121235efb67627b040b69e5dc 100644 (file)
@@ -137,39 +137,30 @@ T toInt(const std::vector<std::string>& values)
     return 0;
 }
 
-Config::Pairs parseVector(const std::string& paramName, const MODULE_SETTINGS& params)
+Config::Pairs parseVector(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
 {
-    for (size_t i = 0; i < params.moduleParams.size(); ++i)
-        if (params.moduleParams[i].param == paramName)
-            return toPairs(params.moduleParams[i].value);
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toPairs(params[i].value);
     return Config::Pairs();
 }
 
-bool parseBool(const std::string& paramName, const MODULE_SETTINGS& params)
+bool parseBool(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
 {
-    for (size_t i = 0; i < params.moduleParams.size(); ++i)
-        if (params.moduleParams[i].param == paramName)
-            return toBool(params.moduleParams[i].value);
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toBool(params[i].value);
     return false;
 }
 
-std::string parseString(const std::string& paramName, const MODULE_SETTINGS& params)
+std::string parseString(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
 {
-    for (size_t i = 0; i < params.moduleParams.size(); ++i)
-        if (params.moduleParams[i].param == paramName)
-            return toString(params.moduleParams[i].value);
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toString(params[i].value);
     return "";
 }
 
-template <typename T>
-T parseInt(const std::string& paramName, const MODULE_SETTINGS& params)
-{
-    for (size_t i = 0; i < params.moduleParams.size(); ++i)
-        if (params.moduleParams[i].param == paramName)
-            return toInt<T>(params.moduleParams[i].value);
-    return 0;
-}
-
 std::string parseAddress(const std::string& address)
 {
     size_t pos = address.find_first_of(':');
@@ -191,18 +182,28 @@ Config::Type parseConnectionType(const std::string& address)
     throw ParserError(0, "Invalid connection type. Should be either 'unix' or 'tcp', got '" + type + "'");
 }
 
+Config::Section parseSection(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return Config::Section(parseVector("match", params[i].sections),
+                                   parseVector("modify", params[i].sections),
+                                   parseVector("reply", params[i].sections));
+    return Config::Section();
+}
+
 } // namespace anonymous
 
 Config::Config(const MODULE_SETTINGS& settings)
-    : match(parseVector("match", settings)),
-      modify(parseVector("modify", settings)),
-      reply(parseVector("reply", settings)),
-      verbose(parseBool("verbose", settings)),
-      address(parseString("bind_address", settings)),
+    : autz(parseSection("autz", settings.moduleParams)),
+      auth(parseSection("auth", settings.moduleParams)),
+      postauth(parseSection("postauth", settings.moduleParams)),
+      preacct(parseSection("preacct", settings.moduleParams)),
+      acct(parseSection("acct", settings.moduleParams)),
+      verbose(parseBool("verbose", settings.moduleParams)),
+      address(parseString("bind_address", settings.moduleParams)),
       bindAddress(parseAddress(address)),
       connectionType(parseConnectionType(address)),
-      portStr(parseString("port", settings)),
-      port(parseInt<uint16_t>("port", settings)),
-      key(parseString("key", settings))
+      key(parseString("key", settings.moduleParams))
 {
 }
index 28da964e930320269be5e507e7bbee01bfbe3391..45ee521fdc9c3e151d8c677a59b3438751df917a 100644 (file)
@@ -37,12 +37,24 @@ struct Config
     typedef std::pair<std::string, std::string> Pair;
     enum Type { UNIX, TCP };
 
+    struct Section
+    {
+        Section() {}
+        Section(const Pairs& ma, const Pairs& mo, const Pairs& re)
+            : match(ma), modify(mo), reply(re) {}
+        Pairs match;
+        Pairs modify;
+        Pairs reply;
+    };
+
     Config() {}
     Config(const MODULE_SETTINGS& settings);
 
-    Pairs match;
-    Pairs modify;
-    Pairs reply;
+    Section autz;
+    Section auth;
+    Section postauth;
+    Section preacct;
+    Section acct;
 
     bool verbose;
 
index 99aa83b6cd59cc4cf69780809e176e43a1185839..c0270e7834303e2a7146b3cee37543ab565463e1 100644 (file)
@@ -32,6 +32,7 @@
 #include <yajl/yajl_gen.h>
 
 #include <map>
+#include <stdexcept>
 #include <cstring>
 #include <cerrno>
 
@@ -228,6 +229,23 @@ class Conn::Impl
         time_t m_lastActivity;
         ProtoParser m_parser;
 
+        const Config::Pairs& stagePairs(Config::Pairs Config::Section::* pairs) const
+        {
+            switch (m_parser.stage())
+            {
+                case AUTHORIZE: return m_config.autz.*pairs;
+                case AUTHENTICATE: return m_config.auth.*pairs;
+                case POSTAUTH: return m_config.postauth.*pairs;
+                case PREACCT: return m_config.preacct.*pairs;
+                case ACCOUNTING: return m_config.acct.*pairs;
+            }
+            throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
+        }
+
+        const Config::Pairs& match() const { return stagePairs(&Config::Section::match); }
+        const Config::Pairs& modify() const { return stagePairs(&Config::Section::modify); }
+        const Config::Pairs& reply() const { return stagePairs(&Config::Section::reply); }
+
         static void process(void* data);
         void processPing();
         void processPong();
@@ -330,17 +348,25 @@ bool Conn::Impl::tick()
 void Conn::Impl::process(void* data)
 {
     Impl& impl = *static_cast<Impl*>(data);
-    switch (impl.m_parser.packet())
+    try
+    {
+        switch (impl.m_parser.packet())
+        {
+            case PING:
+                impl.processPing();
+                return;
+            case PONG:
+                impl.processPong();
+                return;
+            case DATA:
+                impl.processData();
+                return;
+        }
+    }
+    catch (const std::exception& ex)
     {
-        case PING:
-            impl.processPing();
-            return;
-        case PONG:
-            impl.processPong();
-            return;
-        case DATA:
-            impl.processData();
-            return;
+        printfd(__FILE__, "Processing error. %s", ex.what());
+        impl.m_logger("Processing error. %s", ex.what());
     }
     printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
     impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
@@ -364,34 +390,34 @@ void Conn::Impl::processData()
     int handle = m_users.OpenSearch();
 
     USER_PTR user = NULL;
-    bool match = false;
+    bool matched = false;
     while (m_users.SearchNext(handle, &user) == 0)
     {
         if (user == NULL)
             continue;
 
-        match = true;
-        for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it)
+        matched = true;
+        for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
         {
             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
             if (pos == m_parser.data().end())
             {
-                match = false;
+                matched = false;
                 break;
             }
             if (user->GetParamValue(it->second) != pos->second)
             {
-                match = false;
+                matched = false;
                 break;
             }
         }
-        if (!match)
+        if (!matched)
             continue;
         answer(*user);
         break;
     }
 
-    if (!match)
+    if (!matched)
         answerNo();
 
     m_users.CloseSearch(handle);
@@ -400,18 +426,18 @@ void Conn::Impl::processData()
 bool Conn::Impl::answer(const USER& user)
 {
     printfd(__FILE__, "Got match. Sending answer...\n");
-    MapGen reply;
-    for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it)
-        reply.add(it->first, new StringGen(user.GetParamValue(it->second)));
+    MapGen replyData;
+    for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
+        replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
 
-    MapGen modify;
-    for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it)
-        modify.add(it->first, new StringGen(user.GetParamValue(it->second)));
+    MapGen modifyData;
+    for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
+        modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
 
     PacketGen gen("data");
     gen.add("result", "ok")
-       .add("reply", reply)
-       .add("modify", modify);
+       .add("reply", replyData)
+       .add("modify", modifyData);
 
     m_lastPing = time(NULL);
 
index 38f2db2ff8aaff28f22c1849030c3942ec70965c..17ebc8cfdd8b07f36707e4881ad8d7a9681c17c6 100644 (file)
@@ -32,7 +32,7 @@ class PLUGIN_LOGGER;
 namespace STG
 {
 
-class Config;
+struct Config;
 
 class Conn
 {