From: Maxim Mamontov Date: Mon, 24 Aug 2015 13:02:51 +0000 (+0300) Subject: Implemented backend for rlm_stg. X-Git-Url: https://git.stg.codes/stg.git/commitdiff_plain/9a3ec37da47b35901d0ad25a257398895c37bfb1 Implemented backend for rlm_stg. --- diff --git a/projects/stargazer/plugins/other/radius/config.cpp b/projects/stargazer/plugins/other/radius/config.cpp index 8a90567d..64ff3707 100644 --- a/projects/stargazer/plugins/other/radius/config.cpp +++ b/projects/stargazer/plugins/other/radius/config.cpp @@ -39,6 +39,7 @@ struct ParserError : public std::runtime_error position(pos), error(message) {} + virtual ~ParserError() throw() {} size_t position; std::string error; @@ -51,12 +52,12 @@ size_t skipSpaces(const std::string& value, size_t start) return start; } -size_t checkChar(const std:string& value, size_t start, char ch) +size_t checkChar(const std::string& value, size_t start, char ch) { if (start >= value.length()) - throw ParserError(start, "Unexpected end of string. Expected '" + std::string(ch) + "'."); + throw ParserError(start, "Unexpected end of string. Expected '" + std::string(1, ch) + "'."); if (value[start] != ch) - throw ParserError(start, "Expected '" + std::string(ch) + "', got '" + std::string(value[start]) + "'."); + throw ParserError(start, "Expected '" + std::string(1, ch) + "', got '" + std::string(1, value[start]) + "'."); return start + 1; } @@ -71,7 +72,7 @@ std::pair readString(const std::string& value, size_t start else throw ParserError(start, "Unexpected whitespace. Expected string."); } - return dest; + return std::make_pair(start, dest); } Config::Pairs toPairs(const std::vector& values) @@ -90,11 +91,11 @@ Config::Pairs toPairs(const std::vector& values) start = key.first; pair.first = key.second; start = skipSpaces(value, start); - start = checkChar(value, start, ',') + start = checkChar(value, start, ','); start = skipSpaces(value, start); - std::pair value = readString(value, start); + std::pair val = readString(value, start); start = key.first; - pair.second = value.second; + pair.second = val.second; start = skipSpaces(value, start); start = checkChar(value, start, ')'); if (res.find(pair.first) != res.end()) @@ -125,7 +126,7 @@ T toInt(const std::vector& values) if (values.empty()) return 0; T res = 0; - if (srt2x(values[0], res) == 0) + if (str2x(values[0], res) == 0) return res; return 0; } @@ -133,24 +134,24 @@ T toInt(const std::vector& values) Config::Pairs parseVector(const std::string& paramName, const MODULE_SETTINGS& params) { for (size_t i = 0; i < params.moduleParams.size(); ++i) - if (params.moduleParams[i].first == paramName) - return toPairs(params.moduleParams[i].second); + if (params.moduleParams[i].param == paramName) + return toPairs(params.moduleParams[i].value); return Config::Pairs(); } bool parseBool(const std::string& paramName, const MODULE_SETTINGS& params) { for (size_t i = 0; i < params.moduleParams.size(); ++i) - if (params.moduleParams[i].first == paramName) - return toBool(params.moduleParams[i].second); + if (params.moduleParams[i].param == paramName) + return toBool(params.moduleParams[i].value); return false; } std::string parseString(const std::string& paramName, const MODULE_SETTINGS& params) { for (size_t i = 0; i < params.moduleParams.size(); ++i) - if (params.moduleParams[i].first == paramName) - return toString(params.moduleParams[i].second); + if (params.moduleParams[i].param == paramName) + return toString(params.moduleParams[i].value); return ""; } @@ -158,8 +159,8 @@ template T parseInt(const std::string& paramName, const MODULE_SETTINGS& params) { for (size_t i = 0; i < params.moduleParams.size(); ++i) - if (params.moduleParams[i].first == paramName) - return toInt(params.moduleParams[i].second); + if (params.moduleParams[i].param == paramName) + return toInt(params.moduleParams[i].value); return 0; } @@ -171,6 +172,7 @@ Config::Config(const MODULE_SETTINGS& settings) reply(parseVector("reply", settings)), verbose(parseBool("verbose", settings)), bindAddress(parseString("bind_address", settings)), + portStr(parseString("port", settings)), port(parseInt("port", settings)), key(parseString("key", settings)) { diff --git a/projects/stargazer/plugins/other/radius/config.h b/projects/stargazer/plugins/other/radius/config.h index 8e5055d8..d6553e14 100644 --- a/projects/stargazer/plugins/other/radius/config.h +++ b/projects/stargazer/plugins/other/radius/config.h @@ -34,7 +34,8 @@ namespace STG struct Config { typedef std::map Pairs; - typedef Pairs::value_type Pair; + typedef std::pair Pair; + enum Type { UNIX, TCP }; Config(const MODULE_SETTINGS& settings); @@ -44,7 +45,9 @@ struct Config bool verbose; + Type connectionType; std::string bindAddress; + std::string portStr; uint16_t port; std::string key; }; diff --git a/projects/stargazer/plugins/other/radius/conn.cpp b/projects/stargazer/plugins/other/radius/conn.cpp index 223eb2cf..392624a6 100644 --- a/projects/stargazer/plugins/other/radius/conn.cpp +++ b/projects/stargazer/plugins/other/radius/conn.cpp @@ -22,20 +22,211 @@ #include "config.h" +#include "stg/json_parser.h" +#include "stg/json_generator.h" #include "stg/users.h" #include "stg/user.h" #include "stg/logger.h" #include "stg/common.h" +#include + +#include #include #include +#include + using STG::Conn; +using STG::Config; +using STG::JSON::Parser; +using STG::JSON::PairsParser; +using STG::JSON::EnumParser; +using STG::JSON::NodeParser; +using STG::JSON::Gen; +using STG::JSON::MapGen; +using STG::JSON::StringGen; -Conn::Conn(USERS& users, PLUGIN_LOGGER & logger, const Config& config) - : m_users(users), - m_logger(logger), - m_config(config) +namespace +{ + +double CONN_TIMEOUT = 5; +double PING_TIMEOUT = 1; + +enum Packet +{ + PING, + PONG, + DATA +}; + +enum Stage +{ + AUTHORIZE, + AUTHENTICATE, + PREACCT, + ACCOUNTING, + POSTAUTH +}; + +std::map packetCodes; +std::map stageCodes; + +class PacketParser : public EnumParser +{ + public: + PacketParser(NodeParser* next, Packet& packet, std::string& packetStr) + : EnumParser(next, packet, packetStr, packetCodes) + { + if (!packetCodes.empty()) + return; + packetCodes["ping"] = PING; + packetCodes["pong"] = PONG; + packetCodes["data"] = DATA; + } +}; + +class StageParser : public EnumParser +{ + public: + StageParser(NodeParser* next, Stage& stage, std::string& stageStr) + : EnumParser(next, stage, stageStr, stageCodes) + { + if (!stageCodes.empty()) + return; + stageCodes["authorize"] = AUTHORIZE; + stageCodes["authenticate"] = AUTHENTICATE; + stageCodes["preacct"] = PREACCT; + stageCodes["accounting"] = ACCOUNTING; + stageCodes["postauth"] = POSTAUTH; + } +}; + +class TopParser : public NodeParser +{ + public: + TopParser() + : m_packetParser(this, m_packet, m_packetStr), + m_stageParser(this, m_stage, m_stageStr), + m_pairsParser(this, m_data) + {} + + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseMapKey(const std::string& value) + { + std::string key = ToLower(value); + + if (key == "packet") + return &m_packetParser; + else if (key == "stage") + return &m_stageParser; + else if (key == "pairs") + return &m_pairsParser; + + return this; + } + virtual NodeParser* parseEndMap() { return this; } + + const std::string& packetStr() const { return m_packetStr; } + Packet packet() const { return m_packet; } + const std::string& stageStr() const { return m_stageStr; } + Stage stage() const { return m_stage; } + const Config::Pairs& data() const { return m_data; } + + private: + std::string m_packetStr; + Packet m_packet; + std::string m_stageStr; + Stage m_stage; + Config::Pairs m_data; + + PacketParser m_packetParser; + StageParser m_stageParser; + PairsParser m_pairsParser; +}; + +class ProtoParser : public Parser +{ + public: + ProtoParser() : Parser( &m_topParser ) {} + + const std::string& packetStr() const { return m_topParser.packetStr(); } + Packet packet() const { return m_topParser.packet(); } + const std::string& stageStr() const { return m_topParser.stageStr(); } + Stage stage() const { return m_topParser.stage(); } + const Config::Pairs& data() const { return m_topParser.data(); } + + private: + TopParser m_topParser; +}; + +class PacketGen : public Gen +{ + public: + PacketGen(const std::string& type) + : m_type(type) + { + m_gen.add("packet", m_type); + } + void run(yajl_gen_t* handle) const + { + m_gen.run(handle); + } + PacketGen& add(const std::string& key, const std::string& value) + { + m_gen.add(key, new StringGen(value)); + return *this; + } + PacketGen& add(const std::string& key, MapGen* map) + { + m_gen.add(key, map); + return *this; + } + private: + MapGen m_gen; + StringGen m_type; +}; + +} + +class Conn::Impl +{ + public: + Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote); + ~Impl(); + + int sock() const { return m_sock; } + + bool read(); + bool tick(); + + bool isOk() const { return m_ok; } + + private: + USERS& m_users; + PLUGIN_LOGGER& m_logger; + const Config& m_config; + int m_sock; + std::string m_remote; + bool m_ok; + time_t m_lastPing; + time_t m_lastActivity; + ProtoParser m_parser; + + bool process(); + bool processPing(); + bool processPong(); + bool processData(); + bool answer(const USER& user); + bool answerNo(); + bool sendPing(); + bool sendPong(); + + static bool write(void* data, const char* buf, size_t size); +}; + +Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote) + : m_impl(new Impl(users, logger, config, fd, remote)) { } @@ -43,21 +234,208 @@ Conn::~Conn() { } +int Conn::sock() const +{ + return m_impl->sock(); +} + bool Conn::read() { - ssize_t res = read(m_sock, m_buffer, m_bufferSize); + return m_impl->read(); +} + +bool Conn::tick() +{ + return m_impl->tick(); +} + +bool Conn::isOk() const +{ + return m_impl->isOk(); +} + +Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote) + : m_users(users), + m_logger(logger), + m_config(config), + m_sock(fd), + m_remote(remote), + m_ok(true), + m_lastPing(time(NULL)), + m_lastActivity(m_lastPing) +{ +} + +Conn::Impl::~Impl() +{ + close(m_sock); +} + +bool Conn::Impl::read() +{ + static std::vector buffer(1024); + ssize_t res = ::read(m_sock, buffer.data(), buffer.size()); if (res < 0) { - m_state = ERROR; - Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Reason: '" + strerror(errno) + "'"); + m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); + m_ok = false; return false; } - if (res == 0 && m_state != DATA) // EOF is ok for data. + m_lastActivity = time(NULL); + if (res == 0) + { + if (!m_parser.done()) + { + m_ok = false; + m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); + return false; + } + return process(); + } + return m_parser.append(buffer.data(), res); +} + +bool Conn::Impl::tick() +{ + time_t now = time(NULL); + if (difftime(now, m_lastActivity) > CONN_TIMEOUT) { - m_state = ERROR; - Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Unexpected EOF."); + m_logger("Connection to " + m_remote + " timed out."); + m_ok = false; return false; } - m_bufferSize -= res; - return HandleBuffer(res); + if (difftime(now, m_lastPing) > PING_TIMEOUT) + sendPing(); + return true; +} + +bool Conn::Impl::process() +{ + switch (m_parser.packet()) + { + case PING: + return processPing(); + case PONG: + return processPong(); + case DATA: + return processData(); + } + m_logger("Received invalid packet type: " + m_parser.packetStr()); + return false; +} + +bool Conn::Impl::processPing() +{ + return sendPong(); +} + +bool Conn::Impl::processPong() +{ + m_lastActivity = time(NULL); + return true; +} + +bool Conn::Impl::processData() +{ + int handle = m_users.OpenSearch(); + + USER_PTR user = NULL; + bool match = true; + while (m_users.SearchNext(handle, &user)) + { + if (user == NULL) + continue; + + match = true; + for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it) + { + Config::Pairs::const_iterator pos = m_parser.data().find(it->first); + if (pos == m_parser.data().end()) + { + match = false; + break; + } + if (user->GetParamValue(it->second) != pos->second) + { + match = false; + break; + } + } + if (!match) + continue; + answer(*user); + break; + } + + if (!match) + answerNo(); + + m_users.CloseSearch(handle); + + return true; +} + +bool Conn::Impl::answer(const USER& user) +{ + boost::scoped_ptr reply(new MapGen); + for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it) + reply->add(it->first, new StringGen(user.GetParamValue(it->second))); + + boost::scoped_ptr modify(new MapGen); + for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it) + modify->add(it->first, new StringGen(user.GetParamValue(it->second))); + + PacketGen gen("data"); + gen.add("result", "ok") + .add("reply", reply.get()) + .add("modify", modify.get()); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::answerNo() +{ + PacketGen gen("data"); + gen.add("result", "ok"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPing() +{ + PacketGen gen("ping"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPong() +{ + PacketGen gen("pong"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::write(void* data, const char* buf, size_t size) +{ + Conn::Impl& conn = *static_cast(data); + while (size > 0) + { + ssize_t res = ::write(conn.m_sock, buf, size); + if (res < 0) + { + conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno)); + conn.m_ok = false; + return false; + } + size -= res; + } + return true; } diff --git a/projects/stargazer/plugins/other/radius/conn.h b/projects/stargazer/plugins/other/radius/conn.h index 31ebe1dd..38f2db2f 100644 --- a/projects/stargazer/plugins/other/radius/conn.h +++ b/projects/stargazer/plugins/other/radius/conn.h @@ -21,35 +21,35 @@ #ifndef __STG_SGCONFIG_CONN_H__ #define __STG_SGCONFIG_CONN_H__ -#include "stg/os_int.h" +#include -#include #include +class USER; class USERS; +class PLUGIN_LOGGER; namespace STG { +class Config; + class Conn { public: - struct Error : public std::runtime_error - { - Error(const std::string& message) : runtime_error(message.c_str()) {} - }; - - Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config); + Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote); ~Conn(); - int sock() const { return m_sock; } + int sock() const; bool read(); + bool tick(); + + bool isOk() const; private: - USERS& m_users; - PLUGIN_LOGGER& m_logger; - const Config& m_config; + class Impl; + boost::scoped_ptr m_impl; }; } diff --git a/projects/stargazer/plugins/other/radius/radius.cpp b/projects/stargazer/plugins/other/radius/radius.cpp index fe989b28..daa8634c 100644 --- a/projects/stargazer/plugins/other/radius/radius.cpp +++ b/projects/stargazer/plugins/other/radius/radius.cpp @@ -24,10 +24,22 @@ #include "stg/users.h" #include "stg/plugin_creator.h" +#include +#include #include -#include +#include #include +#include +#include +#include // UNIX +#include // IP +#include // TCP +#include + +using STG::Config; +using STG::Conn; + namespace { @@ -41,10 +53,12 @@ extern "C" PLUGIN * GetPlugin() } RADIUS::RADIUS() - : m_running(false), + : m_config(m_settings), + m_running(false), m_stopped(true), m_users(NULL), m_store(NULL), + m_listenSocket(0), m_logger(GetPluginLogger(GetStgLogger(), "radius")) { } @@ -53,7 +67,7 @@ int RADIUS::ParseSettings() { try { m_config = STG::Config(m_settings); - return 0; + return reconnect() ? 0 : -1; } catch (const std::runtime_error& ex) { m_logger("Failed to parse settings. %s", ex.what()); return -1; @@ -107,6 +121,112 @@ void* RADIUS::run(void* d) return NULL; } +bool RADIUS::reconnect() +{ + if (!m_conns.empty()) + { + std::deque::const_iterator it; + for (it = m_conns.begin(); it != m_conns.end(); ++it) + delete(*it); + m_conns.clear(); + } + if (m_listenSocket != 0) + { + shutdown(m_listenSocket, SHUT_RDWR); + close(m_listenSocket); + if (m_config.connectionType == Config::UNIX) + unlink(m_config.bindAddress.c_str()); + } + if (m_config.connectionType == Config::UNIX) + m_listenSocket = createUNIX(); + else + m_listenSocket = createTCP(); + if (m_listenSocket == 0) + return false; + if (listen(m_listenSocket, 100) == -1) + { + m_error = std::string("Error starting to listen socket: ") + strerror(errno); + m_logger(m_error); + return false; + } + return true; +} + +int RADIUS::createUNIX() const +{ + int fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd == -1) + { + m_error = std::string("Error creating UNIX socket: ") + strerror(errno); + m_logger(m_error); + return 0; + } + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, m_config.bindAddress.c_str(), m_config.bindAddress.length()); + if (bind(fd, reinterpret_cast(&addr), sizeof(addr)) == -1) + { + shutdown(fd, SHUT_RDWR); + close(fd); + m_error = std::string("Error binding UNIX socket: ") + strerror(errno); + m_logger(m_error); + return 0; + } + return fd; +} + +int RADIUS::createTCP() const +{ + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + + hints.ai_family = AF_INET; /* Allow IPv4 */ + hints.ai_socktype = SOCK_STREAM; /* Stream socket */ + hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */ + hints.ai_protocol = 0; /* Any protocol */ + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + + addrinfo* ais = NULL; + int res = getaddrinfo(m_config.bindAddress.c_str(), m_config.portStr.c_str(), &hints, &ais); + if (res != 0) + { + m_error = "Error resolvin address '" + m_config.bindAddress + "': " + gai_strerror(res); + m_logger(m_error); + return 0; + } + + for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next) + { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) + { + m_error = std::string("Error creating TCP socket: ") + strerror(errno); + m_logger(m_error); + freeaddrinfo(ais); + return 0; + } + if (bind(fd, ai->ai_addr, ai->ai_addrlen) == -1) + { + shutdown(fd, SHUT_RDWR); + close(fd); + m_error = std::string("Error binding TCP socket: ") + strerror(errno); + m_logger(m_error); + continue; + } + freeaddrinfo(ais); + return fd; + } + + m_error = "Failed to resolve '" + m_config.bindAddress; + m_logger(m_error); + + freeaddrinfo(ais); + return 0; +} + void RADIUS::runImpl() { m_running = true; @@ -133,6 +253,11 @@ void RADIUS::runImpl() if (res > 0) handleEvents(fds); + else + { + for (std::deque::iterator it = m_conns.begin(); it != m_conns.end(); ++it) + (*it)->tick(); + } cleanupConns(); } @@ -163,7 +288,7 @@ void RADIUS::cleanupConns() { std::deque::iterator pos; for (pos = m_conns.begin(); pos != m_conns.end(); ++pos) - if (((*pos)->isDone() && !(*pos)->isKeepAlive()) || !(*pos)->isOk()) { + if (!(*pos)->isOk()) { delete *pos; *pos = NULL; } @@ -182,5 +307,46 @@ void RADIUS::handleEvents(const fd_set & fds) for (it = m_conns.begin(); it != m_conns.end(); ++it) if (FD_ISSET((*it)->sock(), &fds)) (*it)->read(); + else + (*it)->tick(); + } +} + +void RADIUS::acceptConnection() +{ + if (m_config.connectionType == Config::UNIX) + acceptUNIX(); + else + acceptTCP(); +} + +void RADIUS::acceptUNIX() +{ + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + socklen_t size = sizeof(addr); + int res = accept(m_listenSocket, reinterpret_cast(&addr), &size); + if (res == -1) + { + m_error = std::string("Failed to accept UNIX connection: ") + strerror(errno); + m_logger(m_error); + return; + } + m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, addr.sun_path)); +} + +void RADIUS::acceptTCP() +{ + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + socklen_t size = sizeof(addr); + int res = accept(m_listenSocket, reinterpret_cast(&addr), &size); + if (res == -1) + { + m_error = std::string("Failed to accept TCP connection: ") + strerror(errno); + m_logger(m_error); + return; } + std::string remote = inet_ntostring(addr.sin_addr.s_addr) + ":" + x2str(ntohs(addr.sin_port)); + m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, remote)); } diff --git a/projects/stargazer/plugins/other/radius/radius.h b/projects/stargazer/plugins/other/radius/radius.h index ff8ecf5e..2573cf9e 100644 --- a/projects/stargazer/plugins/other/radius/radius.h +++ b/projects/stargazer/plugins/other/radius/radius.h @@ -26,7 +26,11 @@ #include "stg/module_settings.h" #include "stg/logger.h" +#include "config.h" +#include "conn.h" + #include +#include #include #include @@ -43,10 +47,10 @@ public: RADIUS(); virtual ~RADIUS() {} - void SetUsers(USERS* u) { users = u; } - void SetStore(STORE* s) { store = s; } + void SetUsers(USERS* u) { m_users = u; } + void SetStore(STORE* s) { m_store = s; } void SetStgSettings(const SETTINGS*) {} - void SetSettings(const MODULE_SETTINGS& s) { settings = s; } + void SetSettings(const MODULE_SETTINGS& s) { m_settings = s; } int ParseSettings(); int Start(); @@ -67,12 +71,17 @@ private: static void* run(void*); - void rumImpl(); + bool reconnect(); + int createUNIX() const; + int createTCP() const; + void runImpl(); int maxFD() const; void buildFDSet(fd_set & fds) const; void cleanupConns(); void handleEvents(const fd_set & fds); void acceptConnection(); + void acceptUNIX(); + void acceptTCP(); mutable std::string m_error; STG::Config m_config; @@ -85,8 +94,10 @@ private: USERS* m_users; const STORE* m_store; + int m_listenSocket; + std::deque m_conns; + pthread_t m_thread; - pthread_mutex_t m_mutex; PLUGIN_LOGGER m_logger; };