X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/9a3ec37da47b35901d0ad25a257398895c37bfb1..852b085dcef99353ae1bedefbaf654b5b72c9f64:/projects/stargazer/plugins/other/radius/conn.cpp diff --git a/projects/stargazer/plugins/other/radius/conn.cpp b/projects/stargazer/plugins/other/radius/conn.cpp index 392624a6..c0270e78 100644 --- a/projects/stargazer/plugins/other/radius/conn.cpp +++ b/projects/stargazer/plugins/other/radius/conn.cpp @@ -32,10 +32,13 @@ #include #include +#include #include #include #include +#include +#include using STG::Conn; using STG::Config; @@ -50,8 +53,8 @@ using STG::JSON::StringGen; namespace { -double CONN_TIMEOUT = 5; -double PING_TIMEOUT = 1; +double CONN_TIMEOUT = 60; +double PING_TIMEOUT = 10; enum Packet { @@ -105,10 +108,12 @@ class StageParser : public EnumParser class TopParser : public NodeParser { public: - TopParser() + typedef void (*Callback) (void* /*data*/); + TopParser(Callback callback, void* data) : m_packetParser(this, m_packet, m_packetStr), m_stageParser(this, m_stage, m_stageStr), - m_pairsParser(this, m_data) + m_pairsParser(this, m_data), + m_callback(callback), m_callbackData(data) {} virtual NodeParser* parseStartMap() { return this; } @@ -125,7 +130,7 @@ class TopParser : public NodeParser return this; } - virtual NodeParser* parseEndMap() { return this; } + virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; } const std::string& packetStr() const { return m_packetStr; } Packet packet() const { return m_packet; } @@ -143,12 +148,18 @@ class TopParser : public NodeParser PacketParser m_packetParser; StageParser m_stageParser; PairsParser m_pairsParser; + + Callback m_callback; + void* m_callbackData; }; class ProtoParser : public Parser { public: - ProtoParser() : Parser( &m_topParser ) {} + ProtoParser(TopParser::Callback callback, void* data) + : Parser( &m_topParser ), + m_topParser(callback, data) + {} const std::string& packetStr() const { return m_topParser.packetStr(); } Packet packet() const { return m_topParser.packet(); } @@ -182,6 +193,11 @@ class PacketGen : public Gen m_gen.add(key, map); return *this; } + PacketGen& add(const std::string& key, MapGen& map) + { + m_gen.add(key, map); + return *this; + } private: MapGen m_gen; StringGen m_type; @@ -213,10 +229,27 @@ class Conn::Impl time_t m_lastActivity; ProtoParser m_parser; - bool process(); - bool processPing(); - bool processPong(); - bool processData(); + const Config::Pairs& stagePairs(Config::Pairs Config::Section::* pairs) const + { + switch (m_parser.stage()) + { + case AUTHORIZE: return m_config.autz.*pairs; + case AUTHENTICATE: return m_config.auth.*pairs; + case POSTAUTH: return m_config.postauth.*pairs; + case PREACCT: return m_config.preacct.*pairs; + case ACCOUNTING: return m_config.acct.*pairs; + } + throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'."); + } + + const Config::Pairs& match() const { return stagePairs(&Config::Section::match); } + const Config::Pairs& modify() const { return stagePairs(&Config::Section::modify); } + const Config::Pairs& reply() const { return stagePairs(&Config::Section::reply); } + + static void process(void* data); + void processPing(); + void processPong(); + void processData(); bool answer(const USER& user); bool answerNo(); bool sendPing(); @@ -262,7 +295,8 @@ Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int m_remote(remote), m_ok(true), m_lastPing(time(NULL)), - m_lastActivity(m_lastPing) + m_lastActivity(m_lastPing), + m_parser(&Conn::Impl::process, this) { } @@ -281,16 +315,12 @@ bool Conn::Impl::read() m_ok = false; return false; } + printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str()); m_lastActivity = time(NULL); if (res == 0) { - if (!m_parser.done()) - { - m_ok = false; - m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); - return false; - } - return process(); + m_ok = false; + return true; } return m_parser.append(buffer.data(), res); } @@ -300,95 +330,114 @@ bool Conn::Impl::tick() time_t now = time(NULL); if (difftime(now, m_lastActivity) > CONN_TIMEOUT) { + int delta = difftime(now, m_lastActivity); + printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta); m_logger("Connection to " + m_remote + " timed out."); m_ok = false; return false; } if (difftime(now, m_lastPing) > PING_TIMEOUT) + { + int delta = difftime(now, m_lastPing); + printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta); sendPing(); + } return true; } -bool Conn::Impl::process() +void Conn::Impl::process(void* data) { - switch (m_parser.packet()) + Impl& impl = *static_cast(data); + try + { + switch (impl.m_parser.packet()) + { + case PING: + impl.processPing(); + return; + case PONG: + impl.processPong(); + return; + case DATA: + impl.processData(); + return; + } + } + catch (const std::exception& ex) { - case PING: - return processPing(); - case PONG: - return processPong(); - case DATA: - return processData(); + printfd(__FILE__, "Processing error. %s", ex.what()); + impl.m_logger("Processing error. %s", ex.what()); } - m_logger("Received invalid packet type: " + m_parser.packetStr()); - return false; + printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str()); + impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr()); } -bool Conn::Impl::processPing() +void Conn::Impl::processPing() { - return sendPong(); + printfd(__FILE__, "Got ping. Sending pong...\n"); + sendPong(); } -bool Conn::Impl::processPong() +void Conn::Impl::processPong() { + printfd(__FILE__, "Got pong.\n"); m_lastActivity = time(NULL); - return true; } -bool Conn::Impl::processData() +void Conn::Impl::processData() { + printfd(__FILE__, "Got data.\n"); int handle = m_users.OpenSearch(); USER_PTR user = NULL; - bool match = true; - while (m_users.SearchNext(handle, &user)) + bool matched = false; + while (m_users.SearchNext(handle, &user) == 0) { if (user == NULL) continue; - match = true; - for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it) + matched = true; + for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it) { Config::Pairs::const_iterator pos = m_parser.data().find(it->first); if (pos == m_parser.data().end()) { - match = false; + matched = false; break; } if (user->GetParamValue(it->second) != pos->second) { - match = false; + matched = false; break; } } - if (!match) + if (!matched) continue; answer(*user); break; } - if (!match) + if (!matched) answerNo(); m_users.CloseSearch(handle); - - return true; } bool Conn::Impl::answer(const USER& user) { - boost::scoped_ptr reply(new MapGen); - for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it) - reply->add(it->first, new StringGen(user.GetParamValue(it->second))); + printfd(__FILE__, "Got match. Sending answer...\n"); + MapGen replyData; + for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it) + replyData.add(it->first, new StringGen(user.GetParamValue(it->second))); - boost::scoped_ptr modify(new MapGen); - for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it) - modify->add(it->first, new StringGen(user.GetParamValue(it->second))); + MapGen modifyData; + for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it) + modifyData.add(it->first, new StringGen(user.GetParamValue(it->second))); PacketGen gen("data"); gen.add("result", "ok") - .add("reply", reply.get()) - .add("modify", modify.get()); + .add("reply", replyData) + .add("modify", modifyData); m_lastPing = time(NULL); @@ -397,8 +446,9 @@ bool Conn::Impl::answer(const USER& user) bool Conn::Impl::answerNo() { + printfd(__FILE__, "No match. Sending answer...\n"); PacketGen gen("data"); - gen.add("result", "ok"); + gen.add("result", "no"); m_lastPing = time(NULL); @@ -425,10 +475,12 @@ bool Conn::Impl::sendPong() bool Conn::Impl::write(void* data, const char* buf, size_t size) { + std::string json(buf, size); + printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str()); Conn::Impl& conn = *static_cast(data); while (size > 0) { - ssize_t res = ::write(conn.m_sock, buf, size); + ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL); if (res < 0) { conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));