X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/70f8adff2c970496bdc45717cad49ddec0405ae7..354ede117d8495936d0ad5bbf945aa4386578acd:/projects/stargazer/plugins/other/radius/conn.cpp?ds=inline diff --git a/projects/stargazer/plugins/other/radius/conn.cpp b/projects/stargazer/plugins/other/radius/conn.cpp index 223eb2cf..392624a6 100644 --- a/projects/stargazer/plugins/other/radius/conn.cpp +++ b/projects/stargazer/plugins/other/radius/conn.cpp @@ -22,20 +22,211 @@ #include "config.h" +#include "stg/json_parser.h" +#include "stg/json_generator.h" #include "stg/users.h" #include "stg/user.h" #include "stg/logger.h" #include "stg/common.h" +#include + +#include #include #include +#include + using STG::Conn; +using STG::Config; +using STG::JSON::Parser; +using STG::JSON::PairsParser; +using STG::JSON::EnumParser; +using STG::JSON::NodeParser; +using STG::JSON::Gen; +using STG::JSON::MapGen; +using STG::JSON::StringGen; -Conn::Conn(USERS& users, PLUGIN_LOGGER & logger, const Config& config) - : m_users(users), - m_logger(logger), - m_config(config) +namespace +{ + +double CONN_TIMEOUT = 5; +double PING_TIMEOUT = 1; + +enum Packet +{ + PING, + PONG, + DATA +}; + +enum Stage +{ + AUTHORIZE, + AUTHENTICATE, + PREACCT, + ACCOUNTING, + POSTAUTH +}; + +std::map packetCodes; +std::map stageCodes; + +class PacketParser : public EnumParser +{ + public: + PacketParser(NodeParser* next, Packet& packet, std::string& packetStr) + : EnumParser(next, packet, packetStr, packetCodes) + { + if (!packetCodes.empty()) + return; + packetCodes["ping"] = PING; + packetCodes["pong"] = PONG; + packetCodes["data"] = DATA; + } +}; + +class StageParser : public EnumParser +{ + public: + StageParser(NodeParser* next, Stage& stage, std::string& stageStr) + : EnumParser(next, stage, stageStr, stageCodes) + { + if (!stageCodes.empty()) + return; + stageCodes["authorize"] = AUTHORIZE; + stageCodes["authenticate"] = AUTHENTICATE; + stageCodes["preacct"] = PREACCT; + stageCodes["accounting"] = ACCOUNTING; + stageCodes["postauth"] = POSTAUTH; + } +}; + +class TopParser : public NodeParser +{ + public: + TopParser() + : m_packetParser(this, m_packet, m_packetStr), + m_stageParser(this, m_stage, m_stageStr), + m_pairsParser(this, m_data) + {} + + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseMapKey(const std::string& value) + { + std::string key = ToLower(value); + + if (key == "packet") + return &m_packetParser; + else if (key == "stage") + return &m_stageParser; + else if (key == "pairs") + return &m_pairsParser; + + return this; + } + virtual NodeParser* parseEndMap() { return this; } + + const std::string& packetStr() const { return m_packetStr; } + Packet packet() const { return m_packet; } + const std::string& stageStr() const { return m_stageStr; } + Stage stage() const { return m_stage; } + const Config::Pairs& data() const { return m_data; } + + private: + std::string m_packetStr; + Packet m_packet; + std::string m_stageStr; + Stage m_stage; + Config::Pairs m_data; + + PacketParser m_packetParser; + StageParser m_stageParser; + PairsParser m_pairsParser; +}; + +class ProtoParser : public Parser +{ + public: + ProtoParser() : Parser( &m_topParser ) {} + + const std::string& packetStr() const { return m_topParser.packetStr(); } + Packet packet() const { return m_topParser.packet(); } + const std::string& stageStr() const { return m_topParser.stageStr(); } + Stage stage() const { return m_topParser.stage(); } + const Config::Pairs& data() const { return m_topParser.data(); } + + private: + TopParser m_topParser; +}; + +class PacketGen : public Gen +{ + public: + PacketGen(const std::string& type) + : m_type(type) + { + m_gen.add("packet", m_type); + } + void run(yajl_gen_t* handle) const + { + m_gen.run(handle); + } + PacketGen& add(const std::string& key, const std::string& value) + { + m_gen.add(key, new StringGen(value)); + return *this; + } + PacketGen& add(const std::string& key, MapGen* map) + { + m_gen.add(key, map); + return *this; + } + private: + MapGen m_gen; + StringGen m_type; +}; + +} + +class Conn::Impl +{ + public: + Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote); + ~Impl(); + + int sock() const { return m_sock; } + + bool read(); + bool tick(); + + bool isOk() const { return m_ok; } + + private: + USERS& m_users; + PLUGIN_LOGGER& m_logger; + const Config& m_config; + int m_sock; + std::string m_remote; + bool m_ok; + time_t m_lastPing; + time_t m_lastActivity; + ProtoParser m_parser; + + bool process(); + bool processPing(); + bool processPong(); + bool processData(); + bool answer(const USER& user); + bool answerNo(); + bool sendPing(); + bool sendPong(); + + static bool write(void* data, const char* buf, size_t size); +}; + +Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote) + : m_impl(new Impl(users, logger, config, fd, remote)) { } @@ -43,21 +234,208 @@ Conn::~Conn() { } +int Conn::sock() const +{ + return m_impl->sock(); +} + bool Conn::read() { - ssize_t res = read(m_sock, m_buffer, m_bufferSize); + return m_impl->read(); +} + +bool Conn::tick() +{ + return m_impl->tick(); +} + +bool Conn::isOk() const +{ + return m_impl->isOk(); +} + +Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote) + : m_users(users), + m_logger(logger), + m_config(config), + m_sock(fd), + m_remote(remote), + m_ok(true), + m_lastPing(time(NULL)), + m_lastActivity(m_lastPing) +{ +} + +Conn::Impl::~Impl() +{ + close(m_sock); +} + +bool Conn::Impl::read() +{ + static std::vector buffer(1024); + ssize_t res = ::read(m_sock, buffer.data(), buffer.size()); if (res < 0) { - m_state = ERROR; - Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Reason: '" + strerror(errno) + "'"); + m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); + m_ok = false; return false; } - if (res == 0 && m_state != DATA) // EOF is ok for data. + m_lastActivity = time(NULL); + if (res == 0) + { + if (!m_parser.done()) + { + m_ok = false; + m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); + return false; + } + return process(); + } + return m_parser.append(buffer.data(), res); +} + +bool Conn::Impl::tick() +{ + time_t now = time(NULL); + if (difftime(now, m_lastActivity) > CONN_TIMEOUT) { - m_state = ERROR; - Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Unexpected EOF."); + m_logger("Connection to " + m_remote + " timed out."); + m_ok = false; return false; } - m_bufferSize -= res; - return HandleBuffer(res); + if (difftime(now, m_lastPing) > PING_TIMEOUT) + sendPing(); + return true; +} + +bool Conn::Impl::process() +{ + switch (m_parser.packet()) + { + case PING: + return processPing(); + case PONG: + return processPong(); + case DATA: + return processData(); + } + m_logger("Received invalid packet type: " + m_parser.packetStr()); + return false; +} + +bool Conn::Impl::processPing() +{ + return sendPong(); +} + +bool Conn::Impl::processPong() +{ + m_lastActivity = time(NULL); + return true; +} + +bool Conn::Impl::processData() +{ + int handle = m_users.OpenSearch(); + + USER_PTR user = NULL; + bool match = true; + while (m_users.SearchNext(handle, &user)) + { + if (user == NULL) + continue; + + match = true; + for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it) + { + Config::Pairs::const_iterator pos = m_parser.data().find(it->first); + if (pos == m_parser.data().end()) + { + match = false; + break; + } + if (user->GetParamValue(it->second) != pos->second) + { + match = false; + break; + } + } + if (!match) + continue; + answer(*user); + break; + } + + if (!match) + answerNo(); + + m_users.CloseSearch(handle); + + return true; +} + +bool Conn::Impl::answer(const USER& user) +{ + boost::scoped_ptr reply(new MapGen); + for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it) + reply->add(it->first, new StringGen(user.GetParamValue(it->second))); + + boost::scoped_ptr modify(new MapGen); + for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it) + modify->add(it->first, new StringGen(user.GetParamValue(it->second))); + + PacketGen gen("data"); + gen.add("result", "ok") + .add("reply", reply.get()) + .add("modify", modify.get()); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::answerNo() +{ + PacketGen gen("data"); + gen.add("result", "ok"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPing() +{ + PacketGen gen("ping"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPong() +{ + PacketGen gen("pong"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::write(void* data, const char* buf, size_t size) +{ + Conn::Impl& conn = *static_cast(data); + while (size > 0) + { + ssize_t res = ::write(conn.m_sock, buf, size); + if (res < 0) + { + conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno)); + conn.m_ok = false; + return false; + } + size -= res; + } + return true; }