]> git.stg.codes - stg.git/blobdiff - projects/stargazer/plugins/other/radius/conn.cpp
Merge branch 'stg-2.409-radius'
[stg.git] / projects / stargazer / plugins / other / radius / conn.cpp
index 99aa83b6cd59cc4cf69780809e176e43a1185839..a209409ba95e36f1e490d03217c0211b70ea3733 100644 (file)
@@ -20,6 +20,7 @@
 
 #include "conn.h"
 
+#include "radius.h"
 #include "config.h"
 
 #include "stg/json_parser.h"
@@ -32,6 +33,7 @@
 #include <yajl/yajl_gen.h>
 
 #include <map>
+#include <stdexcept>
 #include <cstring>
 #include <cerrno>
 
@@ -202,12 +204,29 @@ class PacketGen : public Gen
         StringGen m_type;
 };
 
+std::string toString(Config::ReturnCode code)
+{
+    switch (code)
+    {
+        case Config::REJECT:   return "reject";
+        case Config::FAIL:     return "fail";
+        case Config::OK:       return "ok";
+        case Config::HANDLED:  return "handled";
+        case Config::INVALID:  return "invalid";
+        case Config::USERLOCK: return "userlock";
+        case Config::NOTFOUND: return "notfound";
+        case Config::NOOP:     return "noop";
+        case Config::UPDATED:  return "noop";
+    }
+    return "reject";
+}
+
 }
 
 class Conn::Impl
 {
     public:
-        Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
+        Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
         ~Impl();
 
         int sock() const { return m_sock; }
@@ -220,6 +239,7 @@ class Conn::Impl
     private:
         USERS& m_users;
         PLUGIN_LOGGER& m_logger;
+        RADIUS& m_plugin;
         const Config& m_config;
         int m_sock;
         std::string m_remote;
@@ -227,6 +247,27 @@ class Conn::Impl
         time_t m_lastPing;
         time_t m_lastActivity;
         ProtoParser m_parser;
+        std::set<std::string> m_authorized;
+
+        template <typename T>
+        const T& stageMember(T Config::Section::* member) const
+        {
+            switch (m_parser.stage())
+            {
+                case AUTHORIZE: return m_config.autz.*member;
+                case AUTHENTICATE: return m_config.auth.*member;
+                case POSTAUTH: return m_config.postauth.*member;
+                case PREACCT: return m_config.preacct.*member;
+                case ACCOUNTING: return m_config.acct.*member;
+            }
+            throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
+        }
+
+        const Config::Pairs& match() const { return stageMember(&Config::Section::match); }
+        const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
+        const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
+        Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
+        const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); }
 
         static void process(void* data);
         void processPing();
@@ -240,8 +281,8 @@ class Conn::Impl
         static bool write(void* data, const char* buf, size_t size);
 };
 
-Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
-    : m_impl(new Impl(users, logger, config, fd, remote))
+Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
+    : m_impl(new Impl(users, logger, plugin, config, fd, remote))
 {
 }
 
@@ -269,9 +310,10 @@ bool Conn::isOk() const
     return m_impl->isOk();
 }
 
-Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
+Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
     : m_users(users),
       m_logger(logger),
+      m_plugin(plugin),
       m_config(config),
       m_sock(fd),
       m_remote(remote),
@@ -285,6 +327,10 @@ Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int
 Conn::Impl::~Impl()
 {
     close(m_sock);
+
+    std::set<std::string>::const_iterator it = m_authorized.begin();
+    for (; it != m_authorized.end(); ++it)
+        m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + ".");
 }
 
 bool Conn::Impl::read()
@@ -330,17 +376,25 @@ bool Conn::Impl::tick()
 void Conn::Impl::process(void* data)
 {
     Impl& impl = *static_cast<Impl*>(data);
-    switch (impl.m_parser.packet())
+    try
+    {
+        switch (impl.m_parser.packet())
+        {
+            case PING:
+                impl.processPing();
+                return;
+            case PONG:
+                impl.processPong();
+                return;
+            case DATA:
+                impl.processData();
+                return;
+        }
+    }
+    catch (const std::exception& ex)
     {
-        case PING:
-            impl.processPing();
-            return;
-        case PONG:
-            impl.processPong();
-            return;
-        case DATA:
-            impl.processData();
-            return;
+        printfd(__FILE__, "Processing error. %s", ex.what());
+        impl.m_logger("Processing error. %s", ex.what());
     }
     printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
     impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
@@ -364,34 +418,39 @@ void Conn::Impl::processData()
     int handle = m_users.OpenSearch();
 
     USER_PTR user = NULL;
-    bool match = false;
+    bool matched = false;
     while (m_users.SearchNext(handle, &user) == 0)
     {
         if (user == NULL)
             continue;
 
-        match = true;
-        for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it)
+        matched = true;
+        for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
         {
             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
             if (pos == m_parser.data().end())
             {
-                match = false;
+                matched = false;
                 break;
             }
             if (user->GetParamValue(it->second) != pos->second)
             {
-                match = false;
+                matched = false;
                 break;
             }
         }
-        if (!match)
+        if (!matched)
             continue;
         answer(*user);
+        if (authorize().check(*user, m_parser.data()))
+        {
+            m_plugin.authorize(*user);
+            m_authorized.insert(user->GetLogin());
+        }
         break;
     }
 
-    if (!match)
+    if (!matched)
         answerNo();
 
     m_users.CloseSearch(handle);
@@ -400,18 +459,18 @@ void Conn::Impl::processData()
 bool Conn::Impl::answer(const USER& user)
 {
     printfd(__FILE__, "Got match. Sending answer...\n");
-    MapGen reply;
-    for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it)
-        reply.add(it->first, new StringGen(user.GetParamValue(it->second)));
+    MapGen replyData;
+    for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
+        replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
 
-    MapGen modify;
-    for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it)
-        modify.add(it->first, new StringGen(user.GetParamValue(it->second)));
+    MapGen modifyData;
+    for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
+        modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
 
     PacketGen gen("data");
     gen.add("result", "ok")
-       .add("reply", reply)
-       .add("modify", modify);
+       .add("reply", replyData)
+       .add("modify", modifyData);
 
     m_lastPing = time(NULL);
 
@@ -423,6 +482,7 @@ bool Conn::Impl::answerNo()
     printfd(__FILE__, "No match. Sending answer...\n");
     PacketGen gen("data");
     gen.add("result", "no");
+    gen.add("return_code", toString(returnCode()));
 
     m_lastPing = time(NULL);