X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/1538d6276533140505fddb71c99a0bafe6ca9182..e39f173d25ae650ee843e3d1c311abe9c1cb5ee9:/projects/stargazer/plugins/other/radius/conn.cpp?ds=sidebyside diff --git a/projects/stargazer/plugins/other/radius/conn.cpp b/projects/stargazer/plugins/other/radius/conn.cpp index 223eb2cf..a209409b 100644 --- a/projects/stargazer/plugins/other/radius/conn.cpp +++ b/projects/stargazer/plugins/other/radius/conn.cpp @@ -20,22 +20,269 @@ #include "conn.h" +#include "radius.h" #include "config.h" +#include "stg/json_parser.h" +#include "stg/json_generator.h" #include "stg/users.h" #include "stg/user.h" #include "stg/logger.h" #include "stg/common.h" +#include + +#include +#include #include #include +#include +#include +#include + using STG::Conn; +using STG::Config; +using STG::JSON::Parser; +using STG::JSON::PairsParser; +using STG::JSON::EnumParser; +using STG::JSON::NodeParser; +using STG::JSON::Gen; +using STG::JSON::MapGen; +using STG::JSON::StringGen; -Conn::Conn(USERS& users, PLUGIN_LOGGER & logger, const Config& config) - : m_users(users), - m_logger(logger), - m_config(config) +namespace +{ + +double CONN_TIMEOUT = 60; +double PING_TIMEOUT = 10; + +enum Packet +{ + PING, + PONG, + DATA +}; + +enum Stage +{ + AUTHORIZE, + AUTHENTICATE, + PREACCT, + ACCOUNTING, + POSTAUTH +}; + +std::map packetCodes; +std::map stageCodes; + +class PacketParser : public EnumParser +{ + public: + PacketParser(NodeParser* next, Packet& packet, std::string& packetStr) + : EnumParser(next, packet, packetStr, packetCodes) + { + if (!packetCodes.empty()) + return; + packetCodes["ping"] = PING; + packetCodes["pong"] = PONG; + packetCodes["data"] = DATA; + } +}; + +class StageParser : public EnumParser +{ + public: + StageParser(NodeParser* next, Stage& stage, std::string& stageStr) + : EnumParser(next, stage, stageStr, stageCodes) + { + if (!stageCodes.empty()) + return; + stageCodes["authorize"] = AUTHORIZE; + stageCodes["authenticate"] = AUTHENTICATE; + stageCodes["preacct"] = PREACCT; + stageCodes["accounting"] = ACCOUNTING; + stageCodes["postauth"] = POSTAUTH; + } +}; + +class TopParser : public NodeParser +{ + public: + typedef void (*Callback) (void* /*data*/); + TopParser(Callback callback, void* data) + : m_packetParser(this, m_packet, m_packetStr), + m_stageParser(this, m_stage, m_stageStr), + m_pairsParser(this, m_data), + m_callback(callback), m_callbackData(data) + {} + + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseMapKey(const std::string& value) + { + std::string key = ToLower(value); + + if (key == "packet") + return &m_packetParser; + else if (key == "stage") + return &m_stageParser; + else if (key == "pairs") + return &m_pairsParser; + + return this; + } + virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; } + + const std::string& packetStr() const { return m_packetStr; } + Packet packet() const { return m_packet; } + const std::string& stageStr() const { return m_stageStr; } + Stage stage() const { return m_stage; } + const Config::Pairs& data() const { return m_data; } + + private: + std::string m_packetStr; + Packet m_packet; + std::string m_stageStr; + Stage m_stage; + Config::Pairs m_data; + + PacketParser m_packetParser; + StageParser m_stageParser; + PairsParser m_pairsParser; + + Callback m_callback; + void* m_callbackData; +}; + +class ProtoParser : public Parser +{ + public: + ProtoParser(TopParser::Callback callback, void* data) + : Parser( &m_topParser ), + m_topParser(callback, data) + {} + + const std::string& packetStr() const { return m_topParser.packetStr(); } + Packet packet() const { return m_topParser.packet(); } + const std::string& stageStr() const { return m_topParser.stageStr(); } + Stage stage() const { return m_topParser.stage(); } + const Config::Pairs& data() const { return m_topParser.data(); } + + private: + TopParser m_topParser; +}; + +class PacketGen : public Gen +{ + public: + PacketGen(const std::string& type) + : m_type(type) + { + m_gen.add("packet", m_type); + } + void run(yajl_gen_t* handle) const + { + m_gen.run(handle); + } + PacketGen& add(const std::string& key, const std::string& value) + { + m_gen.add(key, new StringGen(value)); + return *this; + } + PacketGen& add(const std::string& key, MapGen* map) + { + m_gen.add(key, map); + return *this; + } + PacketGen& add(const std::string& key, MapGen& map) + { + m_gen.add(key, map); + return *this; + } + private: + MapGen m_gen; + StringGen m_type; +}; + +std::string toString(Config::ReturnCode code) +{ + switch (code) + { + case Config::REJECT: return "reject"; + case Config::FAIL: return "fail"; + case Config::OK: return "ok"; + case Config::HANDLED: return "handled"; + case Config::INVALID: return "invalid"; + case Config::USERLOCK: return "userlock"; + case Config::NOTFOUND: return "notfound"; + case Config::NOOP: return "noop"; + case Config::UPDATED: return "noop"; + } + return "reject"; +} + +} + +class Conn::Impl +{ + public: + Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote); + ~Impl(); + + int sock() const { return m_sock; } + + bool read(); + bool tick(); + + bool isOk() const { return m_ok; } + + private: + USERS& m_users; + PLUGIN_LOGGER& m_logger; + RADIUS& m_plugin; + const Config& m_config; + int m_sock; + std::string m_remote; + bool m_ok; + time_t m_lastPing; + time_t m_lastActivity; + ProtoParser m_parser; + std::set m_authorized; + + template + const T& stageMember(T Config::Section::* member) const + { + switch (m_parser.stage()) + { + case AUTHORIZE: return m_config.autz.*member; + case AUTHENTICATE: return m_config.auth.*member; + case POSTAUTH: return m_config.postauth.*member; + case PREACCT: return m_config.preacct.*member; + case ACCOUNTING: return m_config.acct.*member; + } + throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'."); + } + + const Config::Pairs& match() const { return stageMember(&Config::Section::match); } + const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); } + const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); } + Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); } + const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); } + + static void process(void* data); + void processPing(); + void processPong(); + void processData(); + bool answer(const USER& user); + bool answerNo(); + bool sendPing(); + bool sendPong(); + + static bool write(void* data, const char* buf, size_t size); +}; + +Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote) + : m_impl(new Impl(users, logger, plugin, config, fd, remote)) { } @@ -43,21 +290,238 @@ Conn::~Conn() { } +int Conn::sock() const +{ + return m_impl->sock(); +} + bool Conn::read() { - ssize_t res = read(m_sock, m_buffer, m_bufferSize); + return m_impl->read(); +} + +bool Conn::tick() +{ + return m_impl->tick(); +} + +bool Conn::isOk() const +{ + return m_impl->isOk(); +} + +Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote) + : m_users(users), + m_logger(logger), + m_plugin(plugin), + m_config(config), + m_sock(fd), + m_remote(remote), + m_ok(true), + m_lastPing(time(NULL)), + m_lastActivity(m_lastPing), + m_parser(&Conn::Impl::process, this) +{ +} + +Conn::Impl::~Impl() +{ + close(m_sock); + + std::set::const_iterator it = m_authorized.begin(); + for (; it != m_authorized.end(); ++it) + m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + "."); +} + +bool Conn::Impl::read() +{ + static std::vector buffer(1024); + ssize_t res = ::read(m_sock, buffer.data(), buffer.size()); if (res < 0) { - m_state = ERROR; - Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Reason: '" + strerror(errno) + "'"); + m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); + m_ok = false; return false; } - if (res == 0 && m_state != DATA) // EOF is ok for data. + printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str()); + m_lastActivity = time(NULL); + if (res == 0) + { + m_ok = false; + return true; + } + return m_parser.append(buffer.data(), res); +} + +bool Conn::Impl::tick() +{ + time_t now = time(NULL); + if (difftime(now, m_lastActivity) > CONN_TIMEOUT) { - m_state = ERROR; - Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Unexpected EOF."); + int delta = difftime(now, m_lastActivity); + printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta); + m_logger("Connection to " + m_remote + " timed out."); + m_ok = false; return false; } - m_bufferSize -= res; - return HandleBuffer(res); + if (difftime(now, m_lastPing) > PING_TIMEOUT) + { + int delta = difftime(now, m_lastPing); + printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta); + sendPing(); + } + return true; +} + +void Conn::Impl::process(void* data) +{ + Impl& impl = *static_cast(data); + try + { + switch (impl.m_parser.packet()) + { + case PING: + impl.processPing(); + return; + case PONG: + impl.processPong(); + return; + case DATA: + impl.processData(); + return; + } + } + catch (const std::exception& ex) + { + printfd(__FILE__, "Processing error. %s", ex.what()); + impl.m_logger("Processing error. %s", ex.what()); + } + printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str()); + impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr()); +} + +void Conn::Impl::processPing() +{ + printfd(__FILE__, "Got ping. Sending pong...\n"); + sendPong(); +} + +void Conn::Impl::processPong() +{ + printfd(__FILE__, "Got pong.\n"); + m_lastActivity = time(NULL); +} + +void Conn::Impl::processData() +{ + printfd(__FILE__, "Got data.\n"); + int handle = m_users.OpenSearch(); + + USER_PTR user = NULL; + bool matched = false; + while (m_users.SearchNext(handle, &user) == 0) + { + if (user == NULL) + continue; + + matched = true; + for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it) + { + Config::Pairs::const_iterator pos = m_parser.data().find(it->first); + if (pos == m_parser.data().end()) + { + matched = false; + break; + } + if (user->GetParamValue(it->second) != pos->second) + { + matched = false; + break; + } + } + if (!matched) + continue; + answer(*user); + if (authorize().check(*user, m_parser.data())) + { + m_plugin.authorize(*user); + m_authorized.insert(user->GetLogin()); + } + break; + } + + if (!matched) + answerNo(); + + m_users.CloseSearch(handle); +} + +bool Conn::Impl::answer(const USER& user) +{ + printfd(__FILE__, "Got match. Sending answer...\n"); + MapGen replyData; + for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it) + replyData.add(it->first, new StringGen(user.GetParamValue(it->second))); + + MapGen modifyData; + for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it) + modifyData.add(it->first, new StringGen(user.GetParamValue(it->second))); + + PacketGen gen("data"); + gen.add("result", "ok") + .add("reply", replyData) + .add("modify", modifyData); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::answerNo() +{ + printfd(__FILE__, "No match. Sending answer...\n"); + PacketGen gen("data"); + gen.add("result", "no"); + gen.add("return_code", toString(returnCode())); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPing() +{ + PacketGen gen("ping"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPong() +{ + PacketGen gen("pong"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::write(void* data, const char* buf, size_t size) +{ + std::string json(buf, size); + printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str()); + Conn::Impl& conn = *static_cast(data); + while (size > 0) + { + ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL); + if (res < 0) + { + conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno)); + conn.m_ok = false; + return false; + } + size -= res; + } + return true; }