]> git.stg.codes - stg.git/commitdiff
Optional authorization.
authorMaxim Mamontov <faust.madf@gmail.com>
Wed, 9 Dec 2015 20:28:48 +0000 (22:28 +0200)
committerMaxim Mamontov <faust.madf@gmail.com>
Wed, 9 Dec 2015 20:28:48 +0000 (22:28 +0200)
projects/stargazer/plugins/other/radius/config.cpp
projects/stargazer/plugins/other/radius/config.h
projects/stargazer/plugins/other/radius/conn.cpp
projects/stargazer/plugins/other/radius/conn.h
projects/stargazer/plugins/other/radius/radius.cpp
projects/stargazer/plugins/other/radius/radius.h

index 68b8760f1347c29f240d08e97e5abfdd1342cfc1..79822852a8ec9a59d2102d84929bb30e7e2740f0 100644 (file)
@@ -20,6 +20,7 @@
 
 #include "config.h"
 
 
 #include "config.h"
 
+#include "stg/user.h"
 #include "stg/common.h"
 
 #include <vector>
 #include "stg/common.h"
 
 #include <vector>
@@ -34,6 +35,11 @@ namespace
 
 struct ParserError : public std::runtime_error
 {
 
 struct ParserError : public std::runtime_error
 {
+    ParserError(const std::string& message)
+        : runtime_error("Config is not valid. " + message),
+          position(0),
+          error(message)
+    {}
     ParserError(size_t pos, const std::string& message)
         : runtime_error("Parsing error at position " + x2str(pos) + ". " + message),
           position(pos),
     ParserError(size_t pos, const std::string& message)
         : runtime_error("Parsing error at position " + x2str(pos) + ". " + message),
           position(pos),
@@ -132,7 +138,7 @@ uid_t toUID(const std::vector<std::string>& values)
         return -1;
     uid_t res = str2uid(values[0]);
     if (res == static_cast<uid_t>(-1))
         return -1;
     uid_t res = str2uid(values[0]);
     if (res == static_cast<uid_t>(-1))
-        throw ParserError(0, "Invalid user name: '" + values[0] + "'");
+        throw ParserError("Invalid user name: '" + values[0] + "'");
     return res;
 }
 
     return res;
 }
 
@@ -142,7 +148,7 @@ gid_t toGID(const std::vector<std::string>& values)
         return -1;
     gid_t res = str2gid(values[0]);
     if (res == static_cast<gid_t>(-1))
         return -1;
     gid_t res = str2gid(values[0]);
     if (res == static_cast<gid_t>(-1))
-        throw ParserError(0, "Invalid group name: '" + values[0] + "'");
+        throw ParserError("Invalid group name: '" + values[0] + "'");
     return res;
 }
 
     return res;
 }
 
@@ -152,7 +158,7 @@ mode_t toMode(const std::vector<std::string>& values)
         return -1;
     mode_t res = str2mode(values[0]);
     if (res == static_cast<mode_t>(-1))
         return -1;
     mode_t res = str2mode(values[0]);
     if (res == static_cast<mode_t>(-1))
-        throw ParserError(0, "Invalid mode: '" + values[0] + "'");
+        throw ParserError("Invalid mode: '" + values[0] + "'");
     return res;
 }
 
     return res;
 }
 
@@ -174,7 +180,7 @@ uint16_t toPort(const std::string& value)
     uint16_t res = 0;
     if (str2x(value, res) == 0)
         return res;
     uint16_t res = 0;
     if (str2x(value, res) == 0)
         return res;
-    throw ParserError(0, "'" + value + "' is not a valid port number.");
+    throw ParserError("'" + value + "' is not a valid port number.");
 }
 
 typedef std::map<std::string, Config::ReturnCode> Codes;
 }
 
 typedef std::map<std::string, Config::ReturnCode> Codes;
@@ -215,6 +221,14 @@ Config::Pairs parseVector(const std::string& paramName, const std::vector<PARAM_
     return Config::Pairs();
 }
 
     return Config::Pairs();
 }
 
+Config::Authorize parseAuthorize(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return Config::Authorize(toPairs(params[i].value));
+    return Config::Authorize();
+}
+
 Config::ReturnCode parseReturnCode(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
 {
     for (size_t i = 0; i < params.size(); ++i)
 Config::ReturnCode parseReturnCode(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
 {
     for (size_t i = 0; i < params.size(); ++i)
@@ -243,13 +257,13 @@ std::string parseAddress(Config::Type connectionType, const std::string& value)
 {
     size_t pos = value.find_first_of(':');
     if (pos == std::string::npos)
 {
     size_t pos = value.find_first_of(':');
     if (pos == std::string::npos)
-        throw ParserError(0, "Connection type is not specified. Should be either 'unix' or 'tcp'.");
+        throw ParserError("Connection type is not specified. Should be either 'unix' or 'tcp'.");
     if (connectionType == Config::UNIX)
         return value.substr(pos + 1);
     std::string address(value.substr(pos + 1));
     pos = address.find_first_of(':', pos + 1);
     if (pos == std::string::npos)
     if (connectionType == Config::UNIX)
         return value.substr(pos + 1);
     std::string address(value.substr(pos + 1));
     pos = address.find_first_of(':', pos + 1);
     if (pos == std::string::npos)
-        throw ParserError(0, "Port is not specified.");
+        throw ParserError("Port is not specified.");
     return address.substr(0, pos - 1);
 }
 
     return address.substr(0, pos - 1);
 }
 
@@ -257,13 +271,13 @@ std::string parsePort(Config::Type connectionType, const std::string& value)
 {
     size_t pos = value.find_first_of(':');
     if (pos == std::string::npos)
 {
     size_t pos = value.find_first_of(':');
     if (pos == std::string::npos)
-        throw ParserError(0, "Connection type is not specified. Should be either 'unix' or 'tcp'.");
+        throw ParserError("Connection type is not specified. Should be either 'unix' or 'tcp'.");
     if (connectionType == Config::UNIX)
         return "";
     std::string address(value.substr(pos + 1));
     pos = address.find_first_of(':', pos + 1);
     if (pos == std::string::npos)
     if (connectionType == Config::UNIX)
         return "";
     std::string address(value.substr(pos + 1));
     pos = address.find_first_of(':', pos + 1);
     if (pos == std::string::npos)
-        throw ParserError(0, "Port is not specified.");
+        throw ParserError("Port is not specified.");
     return address.substr(pos + 1);
 }
 
     return address.substr(pos + 1);
 }
 
@@ -271,13 +285,13 @@ Config::Type parseConnectionType(const std::string& address)
 {
     size_t pos = address.find_first_of(':');
     if (pos == std::string::npos)
 {
     size_t pos = address.find_first_of(':');
     if (pos == std::string::npos)
-        throw ParserError(0, "Connection type is not specified. Should be either 'unix' or 'tcp'.");
+        throw ParserError("Connection type is not specified. Should be either 'unix' or 'tcp'.");
     std::string type = ToLower(address.substr(0, pos));
     if (type == "unix")
         return Config::UNIX;
     else if (type == "tcp")
         return Config::TCP;
     std::string type = ToLower(address.substr(0, pos));
     if (type == "unix")
         return Config::UNIX;
     else if (type == "tcp")
         return Config::TCP;
-    throw ParserError(0, "Invalid connection type. Should be either 'unix' or 'tcp', got '" + type + "'");
+    throw ParserError("Invalid connection type. Should be either 'unix' or 'tcp', got '" + type + "'");
 }
 
 Config::Section parseSection(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
 }
 
 Config::Section parseSection(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
@@ -287,7 +301,8 @@ Config::Section parseSection(const std::string& paramName, const std::vector<PAR
             return Config::Section(parseVector("match", params[i].sections),
                                    parseVector("modify", params[i].sections),
                                    parseVector("reply", params[i].sections),
             return Config::Section(parseVector("match", params[i].sections),
                                    parseVector("modify", params[i].sections),
                                    parseVector("reply", params[i].sections),
-                                   parseReturnCode("no_match", params[i].sections));
+                                   parseReturnCode("no_match", params[i].sections),
+                                   parseAuthorize("authorize", params[i].sections));
     return Config::Section();
 }
 
     return Config::Section();
 }
 
@@ -317,6 +332,27 @@ mode_t parseMode(const std::string& paramName, const std::vector<PARAM_VALUE>& p
 
 } // namespace anonymous
 
 
 } // namespace anonymous
 
+bool Config::Authorize::check(const USER& user, const Config::Pairs& radiusData) const
+{
+    if (!m_auth)
+        return false; // No flag - no authorization.
+
+    if (m_cond.empty())
+        return true; // Empty parameter - always authorize.
+
+    Config::Pairs::const_iterator it = m_cond.begin();
+    for (; it != m_cond.end(); ++it)
+    {
+        const Config::Pairs::const_iterator pos = radiusData.find(it->first);
+        if (pos == radiusData.end())
+            return false; // No required Radius parameter.
+        if (user.GetParamValue(it->second) != pos->second)
+            return false; // No match with the user.
+    }
+
+    return true;
+}
+
 Config::Config(const MODULE_SETTINGS& settings)
     : autz(parseSection("autz", settings.moduleParams)),
       auth(parseSection("auth", settings.moduleParams)),
 Config::Config(const MODULE_SETTINGS& settings)
     : autz(parseSection("autz", settings.moduleParams)),
       auth(parseSection("auth", settings.moduleParams)),
@@ -334,4 +370,17 @@ Config::Config(const MODULE_SETTINGS& settings)
       sockGID(parseGID("sock_group", settings.moduleParams)),
       sockMode(parseMode("sock_mode", settings.moduleParams))
 {
       sockGID(parseGID("sock_group", settings.moduleParams)),
       sockMode(parseMode("sock_mode", settings.moduleParams))
 {
+    size_t count = 0;
+    if (autz.authorize.exists())
+        ++count;
+    if (auth.authorize.exists())
+        ++count;
+    if (postauth.authorize.exists())
+        ++count;
+    if (preacct.authorize.exists())
+        ++count;
+    if (acct.authorize.exists())
+        ++count;
+    if (count > 0)
+        throw ParserError("Authorization flag is specified in more than one section.");
 }
 }
index 2d4f638f652dc4263bbf5b86e627fe2055f7ade4..44d5ed856c38818f7c211072183260e1f86da7a9 100644 (file)
@@ -31,6 +31,8 @@
 #include <unistd.h> // uid_t, gid_t
 #include <sys/stat.h> // mode_t
 
 #include <unistd.h> // uid_t, gid_t
 #include <sys/stat.h> // mode_t
 
+class USER;
+
 namespace STG
 {
 
 namespace STG
 {
 
@@ -52,15 +54,29 @@ struct Config
         UPDATED   // Module sends some updates.
     };
 
         UPDATED   // Module sends some updates.
     };
 
+    class Authorize
+    {
+        public:
+            Authorize() : m_auth(false) {}
+            Authorize(const Pairs& cond) : m_auth(true), m_cond(cond) {}
+
+            bool check(const USER& user, const Pairs& radiusData) const;
+            bool exists() const { return m_auth; }
+        private:
+            bool m_auth;
+            Pairs m_cond;
+    };
+
     struct Section
     {
         Section() {}
     struct Section
     {
         Section() {}
-        Section(const Pairs& ma, const Pairs& mo, const Pairs& re, ReturnCode code)
-            : match(ma), modify(mo), reply(re), returnCode(code) {}
+        Section(const Pairs& ma, const Pairs& mo, const Pairs& re, ReturnCode code, const Authorize& auth)
+            : match(ma), modify(mo), reply(re), returnCode(code), authorize(auth) {}
         Pairs match;
         Pairs modify;
         Pairs reply;
         ReturnCode returnCode;
         Pairs match;
         Pairs modify;
         Pairs reply;
         ReturnCode returnCode;
+        Authorize authorize;
     };
 
     Config() {}
     };
 
     Config() {}
index 8a416ae012be81aa7539678fc80410752964a328..a209409ba95e36f1e490d03217c0211b70ea3733 100644 (file)
@@ -20,6 +20,7 @@
 
 #include "conn.h"
 
 
 #include "conn.h"
 
+#include "radius.h"
 #include "config.h"
 
 #include "stg/json_parser.h"
 #include "config.h"
 
 #include "stg/json_parser.h"
@@ -225,7 +226,7 @@ std::string toString(Config::ReturnCode code)
 class Conn::Impl
 {
     public:
 class Conn::Impl
 {
     public:
-        Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
+        Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
         ~Impl();
 
         int sock() const { return m_sock; }
         ~Impl();
 
         int sock() const { return m_sock; }
@@ -238,6 +239,7 @@ class Conn::Impl
     private:
         USERS& m_users;
         PLUGIN_LOGGER& m_logger;
     private:
         USERS& m_users;
         PLUGIN_LOGGER& m_logger;
+        RADIUS& m_plugin;
         const Config& m_config;
         int m_sock;
         std::string m_remote;
         const Config& m_config;
         int m_sock;
         std::string m_remote;
@@ -245,6 +247,7 @@ class Conn::Impl
         time_t m_lastPing;
         time_t m_lastActivity;
         ProtoParser m_parser;
         time_t m_lastPing;
         time_t m_lastActivity;
         ProtoParser m_parser;
+        std::set<std::string> m_authorized;
 
         template <typename T>
         const T& stageMember(T Config::Section::* member) const
 
         template <typename T>
         const T& stageMember(T Config::Section::* member) const
@@ -264,6 +267,7 @@ class Conn::Impl
         const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
         const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
         Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
         const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
         const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
         Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
+        const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); }
 
         static void process(void* data);
         void processPing();
 
         static void process(void* data);
         void processPing();
@@ -277,8 +281,8 @@ class Conn::Impl
         static bool write(void* data, const char* buf, size_t size);
 };
 
         static bool write(void* data, const char* buf, size_t size);
 };
 
-Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
-    : m_impl(new Impl(users, logger, config, fd, remote))
+Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
+    : m_impl(new Impl(users, logger, plugin, config, fd, remote))
 {
 }
 
 {
 }
 
@@ -306,9 +310,10 @@ bool Conn::isOk() const
     return m_impl->isOk();
 }
 
     return m_impl->isOk();
 }
 
-Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
+Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
     : m_users(users),
       m_logger(logger),
     : m_users(users),
       m_logger(logger),
+      m_plugin(plugin),
       m_config(config),
       m_sock(fd),
       m_remote(remote),
       m_config(config),
       m_sock(fd),
       m_remote(remote),
@@ -322,6 +327,10 @@ Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int
 Conn::Impl::~Impl()
 {
     close(m_sock);
 Conn::Impl::~Impl()
 {
     close(m_sock);
+
+    std::set<std::string>::const_iterator it = m_authorized.begin();
+    for (; it != m_authorized.end(); ++it)
+        m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + ".");
 }
 
 bool Conn::Impl::read()
 }
 
 bool Conn::Impl::read()
@@ -433,6 +442,11 @@ void Conn::Impl::processData()
         if (!matched)
             continue;
         answer(*user);
         if (!matched)
             continue;
         answer(*user);
+        if (authorize().check(*user, m_parser.data()))
+        {
+            m_plugin.authorize(*user);
+            m_authorized.insert(user->GetLogin());
+        }
         break;
     }
 
         break;
     }
 
index 17ebc8cfdd8b07f36707e4881ad8d7a9681c17c6..96e74300f2f7f8b413570f0b362f3a970415792e 100644 (file)
@@ -28,6 +28,7 @@
 class USER;
 class USERS;
 class PLUGIN_LOGGER;
 class USER;
 class USERS;
 class PLUGIN_LOGGER;
+class RADIUS;
 
 namespace STG
 {
 
 namespace STG
 {
@@ -37,7 +38,7 @@ struct Config;
 class Conn
 {
     public:
 class Conn
 {
     public:
-        Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
+        Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
         ~Conn();
 
         int sock() const;
         ~Conn();
 
         int sock() const;
index 376e427860fd18e7eb5d07d501fbba57e498c935..45a9d0e19446f80054df4c9bdf972b0ba79ef73f 100644 (file)
@@ -91,6 +91,11 @@ int RADIUS::Start()
 
 int RADIUS::Stop()
 {
 
 int RADIUS::Stop()
 {
+    std::set<std::string>::const_iterator it = m_logins.begin();
+    for (; it != m_logins.end(); ++it)
+        m_users->Unauthorize(*it, this, "Stopping RADIUS plugin.");
+    m_logins.clear();
+
     if (m_stopped)
         return 0;
 
     if (m_stopped)
         return 0;
 
@@ -342,7 +347,7 @@ void RADIUS::acceptUNIX()
         return;
     }
     printfd(__FILE__, "New UNIX connection: '%s'\n", addr.sun_path);
         return;
     }
     printfd(__FILE__, "New UNIX connection: '%s'\n", addr.sun_path);
-    m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, addr.sun_path));
+    m_conns.push_back(new Conn(*m_users, m_logger, *this, m_config, res, addr.sun_path));
 }
 
 void RADIUS::acceptTCP()
 }
 
 void RADIUS::acceptTCP()
@@ -359,5 +364,27 @@ void RADIUS::acceptTCP()
     }
     std::string remote = inet_ntostring(addr.sin_addr.s_addr) + ":" + x2str(ntohs(addr.sin_port));
     printfd(__FILE__, "New TCP connection: '%s'\n", remote.c_str());
     }
     std::string remote = inet_ntostring(addr.sin_addr.s_addr) + ":" + x2str(ntohs(addr.sin_port));
     printfd(__FILE__, "New TCP connection: '%s'\n", remote.c_str());
-    m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, remote));
+    m_conns.push_back(new Conn(*m_users, m_logger, *this, m_config, res, remote));
+}
+
+void RADIUS::authorize(const USER& user)
+{
+    uint32_t ip = 0;
+    const std::string& login(user.GetLogin());
+    if (!m_users->Authorize(login, ip, 0xffFFffFF, this))
+    {
+        m_error = "Unable to authorize user '" + login + "' with ip " + inet_ntostring(ip) + ".";
+        m_logger(m_error);
+    }
+    else
+        m_logins.insert(login);
+}
+
+void RADIUS::unauthorize(const std::string& login, const std::string& reason)
+{
+    const std::set<std::string>::const_iterator it = m_logins.find(login);
+    if (it == m_logins.end())
+        return;
+    m_logins.erase(it);
+    m_users->Unauthorize(login, this, reason);
 }
 }
index 2573cf9ec5809e87b833b5ef75daf70dd88eef48..742923f0f4fa86840be797a04a16b55a5e3cbac2 100644 (file)
@@ -31,6 +31,7 @@
 
 #include <string>
 #include <deque>
 
 #include <string>
 #include <deque>
+#include <set>
 
 #include <pthread.h>
 #include <unistd.h>
 
 #include <pthread.h>
 #include <unistd.h>
@@ -59,12 +60,15 @@ public:
     bool IsRunning() { return m_running; }
 
     const std::string& GetStrError() const { return m_error; }
     bool IsRunning() { return m_running; }
 
     const std::string& GetStrError() const { return m_error; }
-    std::string GetVersion() const { return "RADIUS data access plugin v 1.0"; }
+    std::string GetVersion() const { return "RADIUS data access plugin v. 2.0"; }
     uint16_t GetStartPosition() const { return 30; }
     uint16_t GetStopPosition() const { return 30; }
 
     int SendMessage(const STG_MSG&, uint32_t) const { return 0; }
 
     uint16_t GetStartPosition() const { return 30; }
     uint16_t GetStopPosition() const { return 30; }
 
     int SendMessage(const STG_MSG&, uint32_t) const { return 0; }
 
+    void authorize(const USER& user);
+    void unauthorize(const std::string& login, const std::string& reason);
+
 private:
     RADIUS(const RADIUS & rvalue);
     RADIUS & operator=(const RADIUS & rvalue);
 private:
     RADIUS(const RADIUS & rvalue);
     RADIUS & operator=(const RADIUS & rvalue);
@@ -96,6 +100,7 @@ private:
 
     int m_listenSocket;
     std::deque<STG::Conn*> m_conns;
 
     int m_listenSocket;
     std::deque<STG::Conn*> m_conns;
+    std::set<std::string> m_logins;
 
     pthread_t m_thread;
 
 
     pthread_t m_thread;