X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/8e28d5793334af32fced307d23554f91f0847a5e..a637d472ffa2023dd0748557f9ef343d7138c2f0:/projects/stargazer/plugins/other/radius/conn.cpp diff --git a/projects/stargazer/plugins/other/radius/conn.cpp b/projects/stargazer/plugins/other/radius/conn.cpp index 99aa83b6..a209409b 100644 --- a/projects/stargazer/plugins/other/radius/conn.cpp +++ b/projects/stargazer/plugins/other/radius/conn.cpp @@ -20,6 +20,7 @@ #include "conn.h" +#include "radius.h" #include "config.h" #include "stg/json_parser.h" @@ -32,6 +33,7 @@ #include #include +#include #include #include @@ -202,12 +204,29 @@ class PacketGen : public Gen StringGen m_type; }; +std::string toString(Config::ReturnCode code) +{ + switch (code) + { + case Config::REJECT: return "reject"; + case Config::FAIL: return "fail"; + case Config::OK: return "ok"; + case Config::HANDLED: return "handled"; + case Config::INVALID: return "invalid"; + case Config::USERLOCK: return "userlock"; + case Config::NOTFOUND: return "notfound"; + case Config::NOOP: return "noop"; + case Config::UPDATED: return "updated"; + } + return "reject"; +} + } class Conn::Impl { public: - Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote); + Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote); ~Impl(); int sock() const { return m_sock; } @@ -220,6 +239,7 @@ class Conn::Impl private: USERS& m_users; PLUGIN_LOGGER& m_logger; + RADIUS& m_plugin; const Config& m_config; int m_sock; std::string m_remote; @@ -227,6 +247,27 @@ class Conn::Impl time_t m_lastPing; time_t m_lastActivity; ProtoParser m_parser; + std::set m_authorized; + + template + const T& stageMember(T Config::Section::* member) const + { + switch (m_parser.stage()) + { + case AUTHORIZE: return m_config.autz.*member; + case AUTHENTICATE: return m_config.auth.*member; + case POSTAUTH: return m_config.postauth.*member; + case PREACCT: 
return m_config.preacct.*member; + case ACCOUNTING: return m_config.acct.*member; + } + throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'."); + } + + const Config::Pairs& match() const { return stageMember(&Config::Section::match); } + const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); } + const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); } + Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); } + const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); } static void process(void* data); void processPing(); @@ -240,8 +281,8 @@ class Conn::Impl static bool write(void* data, const char* buf, size_t size); }; -Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote) - : m_impl(new Impl(users, logger, config, fd, remote)) +Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote) + : m_impl(new Impl(users, logger, plugin, config, fd, remote)) { } @@ -269,9 +310,10 @@ bool Conn::isOk() const return m_impl->isOk(); } -Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote) +Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote) : m_users(users), m_logger(logger), + m_plugin(plugin), m_config(config), m_sock(fd), m_remote(remote), @@ -285,6 +327,10 @@ Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int Conn::Impl::~Impl() { close(m_sock); + + std::set::const_iterator it = m_authorized.begin(); + for (; it != m_authorized.end(); ++it) + m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + "."); } bool Conn::Impl::read() @@ -330,17 +376,27 @@ bool Conn::Impl::tick() void Conn::Impl::process(void* data) { Impl& impl = 
*static_cast(data); - switch (impl.m_parser.packet()) + try + { + switch (impl.m_parser.packet()) + { + case PING: + impl.processPing(); + return; + case PONG: + impl.processPong(); + return; + case DATA: + impl.processData(); + return; + } + } + catch (const std::exception& ex) { - case PING: - impl.processPing(); - return; - case PONG: - impl.processPong(); - return; - case DATA: - impl.processData(); - return; + printfd(__FILE__, "Processing error. %s\n", ex.what()); + impl.m_logger("Processing error. %s", ex.what()); + // Already reported above; don't also log it as an invalid packet type. + return; } printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str()); impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr()); @@ -364,34 +420,41 @@ void Conn::Impl::processData() + // stageMember() throws on an unknown stage; resolve the pairs before OpenSearch() so such a throw cannot skip CloseSearch(). + const Config::Pairs& pairs = match(); int handle = m_users.OpenSearch(); USER_PTR user = NULL; - bool match = false; + bool matched = false; while (m_users.SearchNext(handle, &user) == 0) { if (user == NULL) continue; - match = true; - for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it) + matched = true; + for (Config::Pairs::const_iterator it = pairs.begin(); it != pairs.end(); ++it) { Config::Pairs::const_iterator pos = m_parser.data().find(it->first); if (pos == m_parser.data().end()) { - match = false; + matched = false; break; } if (user->GetParamValue(it->second) != pos->second) { - match = false; + matched = false; break; } } - if (!match) + if (!matched) continue; answer(*user); + if (authorize().check(*user, m_parser.data())) + { + m_plugin.authorize(*user); + m_authorized.insert(user->GetLogin()); + } break; } - if (!match) + if (!matched) answerNo(); m_users.CloseSearch(handle); @@ -400,18 +463,18 @@ void Conn::Impl::processData() bool Conn::Impl::answer(const USER& user) { printfd(__FILE__, "Got match. 
Sending answer...\n"); - MapGen reply; - for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it) - reply.add(it->first, new StringGen(user.GetParamValue(it->second))); + MapGen replyData; + for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it) + replyData.add(it->first, new StringGen(user.GetParamValue(it->second))); - MapGen modify; - for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it) - modify.add(it->first, new StringGen(user.GetParamValue(it->second))); + MapGen modifyData; + for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it) + modifyData.add(it->first, new StringGen(user.GetParamValue(it->second))); PacketGen gen("data"); gen.add("result", "ok") - .add("reply", reply) - .add("modify", modify); + .add("reply", replyData) + .add("modify", modifyData); m_lastPing = time(NULL); @@ -423,6 +486,7 @@ bool Conn::Impl::answerNo() printfd(__FILE__, "No match. Sending answer...\n"); PacketGen gen("data"); gen.add("result", "no"); + gen.add("return_code", toString(returnCode())); m_lastPing = time(NULL);