X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/9a3ec37da47b35901d0ad25a257398895c37bfb1..3a9bc658f505e423b3be181948f1870a09915ea9:/projects/rlm_stg/stg_client.cpp diff --git a/projects/rlm_stg/stg_client.cpp b/projects/rlm_stg/stg_client.cpp index 6987976b..399b971d 100644 --- a/projects/rlm_stg/stg_client.cpp +++ b/projects/rlm_stg/stg_client.cpp @@ -20,34 +20,236 @@ #include "stg_client.h" +#include "stg/json_parser.h" +#include "stg/json_generator.h" #include "stg/common.h" -#include +#include +#include +#include +#include -#include +#include +#include +#include // UNIX +#include // IP +#include // TCP +#include + +using STG::JSON::Parser; +using STG::JSON::PairsParser; +using STG::JSON::EnumParser; +using STG::JSON::NodeParser; +using STG::JSON::Gen; +using STG::JSON::MapGen; +using STG::JSON::StringGen; namespace { +double CONN_TIMEOUT = 5; +double PING_TIMEOUT = 1; + STG_CLIENT* stgClient = NULL; -unsigned fromType(STG_CLIENT::TYPE type) +std::string toStage(STG_CLIENT::TYPE type) { - return static_cast(type); + switch (type) + { + case STG_CLIENT::AUTHORIZE: return "authorize"; + case STG_CLIENT::AUTHENTICATE: return "authenticate"; + case STG_CLIENT::POST_AUTH: return "postauth"; + case STG_CLIENT::PRE_ACCT: return "preacct"; + case STG_CLIENT::ACCOUNT: return "accounting"; + } + return ""; } -STG::SGCP::TransportType toTransport(const std::string& value) +enum Packet { - std::string type = ToLower(value); - if (type == "unix") return STG::SGCP::UNIX; - else if (type == "udp") return STG::SGCP::UDP; - else if (type == "tcp") return STG::SGCP::TCP; - throw ChannelConfig::Error("Invalid transport type. Should be 'unix', 'udp' or 'tcp'."); -} + PING, + PONG, + DATA +}; + +std::map packetCodes; +std::map resultCodes; + +class PacketParser : public EnumParser +{ + public: + PacketParser(NodeParser* next, Packet& packet, std::string& packetStr) + : EnumParser(next, packet, packetStr, packetCodes) + { + if (!packetCodes.empty()) + return; + packetCodes["ping"] = PING; + packetCodes["pong"] = PONG; + packetCodes["data"] = DATA; + } +}; + +class ResultParser : public EnumParser +{ + public: + ResultParser(NodeParser* next, bool& result, std::string& resultStr) + : EnumParser(next, result, resultStr, resultCodes) + { + if (!resultCodes.empty()) + return; + resultCodes["no"] = false; + resultCodes["ok"] = true; + } +}; + +class TopParser : public NodeParser +{ + public: + TopParser() + : m_packetParser(this, m_packet, m_packetStr), + m_resultParser(this, m_result, m_resultStr), + m_replyParser(this, m_reply), + m_modifyParser(this, m_modify) + {} + + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseMapKey(const std::string& value) + { + std::string key = ToLower(value); + + if (key == "packet") + return &m_packetParser; + else if (key == "result") + return &m_resultParser; + else if (key == "reply") + return &m_replyParser; + else if (key == "modify") + return &m_modifyParser; + + return this; + } + virtual NodeParser* parseEndMap() { return this; } + + const std::string& packetStr() const { return m_packetStr; } + Packet packet() const { return m_packet; } + const std::string& resultStr() const { return m_resultStr; } + bool result() const { return m_result; } + const PairsParser::Pairs& reply() const { return m_reply; } + const PairsParser::Pairs& modify() const { return m_modify; } + + private: + std::string m_packetStr; + Packet m_packet; + std::string m_resultStr; + bool m_result; + PairsParser::Pairs m_reply; + PairsParser::Pairs m_modify; + + PacketParser m_packetParser; + ResultParser m_resultParser; + PairsParser m_replyParser; + PairsParser m_modifyParser; +}; + +class ProtoParser : public Parser +{ + public: + ProtoParser() : Parser( &m_topParser ) {} + + const std::string& packetStr() const { return m_topParser.packetStr(); } + Packet packet() const { return m_topParser.packet(); } + const std::string& resultStr() const { return m_topParser.resultStr(); } + bool result() const { return m_topParser.result(); } + const PairsParser::Pairs& reply() const { return m_topParser.reply(); } + const PairsParser::Pairs& modify() const { return m_topParser.modify(); } + + private: + TopParser m_topParser; +}; + +class PacketGen : public Gen +{ + public: + PacketGen(const std::string& type) + : m_type(type) + { + m_gen.add("packet", m_type); + } + void run(yajl_gen_t* handle) const + { + m_gen.run(handle); + } + PacketGen& add(const std::string& key, const std::string& value) + { + m_gen.add(key, new StringGen(value)); + return *this; + } + PacketGen& add(const std::string& key, MapGen* map) + { + m_gen.add(key, map); + return *this; + } + private: + MapGen m_gen; + StringGen m_type; +}; } +class STG_CLIENT::Impl +{ +public: + Impl(const std::string& address, Callback callback, void* data); + ~Impl(); + + bool stop(); + + bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs); + +private: + ChannelConfig m_config; + + int m_sock; + + bool m_running; + bool m_stopped; + + time_t m_lastPing; + time_t m_lastActivity; + + pthread_t m_thread; + pthread_mutex_t m_mutex; + + Callback m_callback; + void* m_data; + + ProtoParser m_parser; + + void m_writeHeader(TYPE type, const std::string& userName, const std::string& password); + void m_writePairBlock(const PAIRS& source); + PAIRS m_readPairBlock(); + + static void* run(void* ); + + void runImpl(); + + int connect(); + int connectTCP(); + int connectUNIX(); + + bool read(); + bool tick(); + + bool process(); + bool processPing(); + bool processPong(); + bool processData(); + bool sendPing(); + bool sendPong(); + + static bool write(void* data, const char* buf, size_t size); +}; + ChannelConfig::ChannelConfig(std::string addr) - : transport(STG::SGCP::TCP) { // unix:pass@/var/run/stg.sock // tcp:secret@192.168.0.1:12345 @@ -56,7 +258,7 @@ ChannelConfig::ChannelConfig(std::string addr) size_t pos = addr.find_first_of(':'); if (pos == std::string::npos) throw Error("Missing transport name."); - transport = toTransport(addr.substr(0, pos)); + transport = ToLower(addr.substr(0, pos)); addr = addr.substr(pos + 1); if (addr.empty()) throw Error("Missing address to connect to."); @@ -67,7 +269,7 @@ ChannelConfig::ChannelConfig(std::string addr) if (addr.empty()) throw Error("Missing address to connect to."); } - if (transport == STG::SGCP::UNIX) + if (transport == "unix") { address = addr; return; @@ -76,35 +278,28 @@ ChannelConfig::ChannelConfig(std::string addr) if (pos == std::string::npos) throw Error("Missing port."); address = addr.substr(0, pos); - if (str2x(addr.substr(pos + 1), port)) + portStr = addr.substr(pos + 1); + if (str2x(portStr, port)) throw Error("Invalid port value."); } -STG_CLIENT::STG_CLIENT(const std::string& address) - : m_config(address), - m_proto(m_config.transport, m_config.key), - m_thread(boost::bind(&STG_CLIENT::m_run, this)) +STG_CLIENT::STG_CLIENT(const std::string& address, Callback callback, void* data) + : m_impl(new Impl(address, callback, data)) { } STG_CLIENT::~STG_CLIENT() { - stop(); } bool STG_CLIENT::stop() { - return m_proto.stop(); + return m_impl->stop(); } -RESULT STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) +bool STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) { - m_writeHeader(type, userName, password); - m_writePairBlock(pairs); - RESULT result; - result.modify = m_readPairBlock(); - result.reply = m_readPairBlock(); - return result; + return m_impl->request(type, userName, password, pairs); } STG_CLIENT* STG_CLIENT::get() @@ -112,12 +307,12 @@ STG_CLIENT* STG_CLIENT::get() return stgClient; } -bool STG_CLIENT::configure(const std::string& address) +bool STG_CLIENT::configure(const std::string& address, Callback callback, void* data) { if ( stgClient != NULL && stgClient->stop() ) delete stgClient; try { - stgClient = new STG_CLIENT(address); + stgClient = new STG_CLIENT(address, callback, data); return true; } catch (const ChannelConfig::Error& ex) { // TODO: Log it @@ -125,49 +320,281 @@ bool STG_CLIENT::configure(const std::string& address) return false; } -void STG_CLIENT::m_writeHeader(TYPE type, const std::string& userName, const std::string& password) +STG_CLIENT::Impl::Impl(const std::string& address, Callback callback, void* data) + : m_config(address), + m_sock(connect()), + m_running(false), + m_stopped(true), + m_lastPing(time(NULL)), + m_lastActivity(m_lastPing), + m_callback(callback), + m_data(data) { - try { - m_proto.writeAll(fromType(type)); - m_proto.writeAll(userName); - m_proto.writeAll(password); - } catch (const STG::SGCP::Proto::Error& ex) { - throw Error(ex.what()); + int res = pthread_create(&m_thread, NULL, run, this); + if (res != 0) + throw Error("Failed to create thread: " + std::string(strerror(errno))); +} + +STG_CLIENT::Impl::~Impl() +{ + stop(); + shutdown(m_sock, SHUT_RDWR); + close(m_sock); +} + +bool STG_CLIENT::Impl::stop() +{ + if (m_stopped) + return true; + + m_running = false; + + for (size_t i = 0; i < 25 && !m_stopped; i++) { + struct timespec ts = {0, 200000000}; + nanosleep(&ts, NULL); } + + if (m_stopped) { + pthread_join(m_thread, NULL); + return true; + } + + return false; } -void STG_CLIENT::m_writePairBlock(const PAIRS& pairs) +bool STG_CLIENT::Impl::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) { - try { - m_proto.writeAll(pairs.size()); - for (size_t i = 0; i < pairs.size(); ++i) { - m_proto.writeAll(pairs[i].first); - m_proto.writeAll(pairs[i].second); + boost::scoped_ptr map(new MapGen); + for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it) + map->add(it->first, new StringGen(it->second)); + map->add("Radius-Username", new StringGen(userName)); + map->add("Radius-Userpass", new StringGen(password)); + + PacketGen gen("data"); + gen.add("stage", toStage(type)) + .add("pairs", map.get()); + + m_lastPing = time(NULL); + + return generate(gen, &STG_CLIENT::Impl::write, this); +} + +void STG_CLIENT::Impl::runImpl() +{ + m_running = true; + + while (m_running) { + fd_set fds; + + FD_ZERO(&fds); + FD_SET(m_sock, &fds); + + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 500000; + + int res = select(m_sock + 1, &fds, NULL, NULL, &tv); + if (res < 0) + { + //m_error = std::string("'select' is failed: '") + strerror(errno) + "'."; + //m_logger(m_error); + break; + } + + if (!m_running) + break; + + if (res > 0) + { + if (FD_ISSET(m_sock, &fds)) + m_running = read(); } - } catch (const STG::SGCP::Proto::Error& ex) { - throw Error(ex.what()); + else + m_running = tick(); } + + m_stopped = true; } -PAIRS STG_CLIENT::m_readPairBlock() +int STG_CLIENT::Impl::connect() { - try { - size_t count = m_proto.readAll(); - if (count == 0) - return PAIRS(); - PAIRS res(count); - for (size_t i = 0; i < count; ++i) { - res[i].first = m_proto.readAll(); - res[i].second = m_proto.readAll(); + if (m_config.transport == "tcp") + return connectTCP(); + else if (m_config.transport == "unix") + return connectUNIX(); + throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'."); +} + +int STG_CLIENT::Impl::connectTCP() +{ + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + + hints.ai_family = AF_INET; /* Allow IPv4 */ + hints.ai_socktype = SOCK_STREAM; /* Stream socket */ + hints.ai_flags = 0; /* For wildcard IP address */ + hints.ai_protocol = 0; /* Any protocol */ + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + + addrinfo* ais = NULL; + int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais); + if (res != 0) + throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res)); + + for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next) + { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) + { + Error error(std::string("Error creating TCP socket: ") + strerror(errno)); + freeaddrinfo(ais); + throw error; + } + if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1) + { + shutdown(fd, SHUT_RDWR); + close(fd); + // TODO: log it. + continue; + } + freeaddrinfo(ais); + return fd; + } + + freeaddrinfo(ais); + + throw Error("Failed to resolve '" + m_config.address); +}; + +int STG_CLIENT::Impl::connectUNIX() +{ + int fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd == -1) + throw Error(std::string("Error creating UNIX socket: ") + strerror(errno)); + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length()); + if (::connect(fd, reinterpret_cast(&addr), sizeof(addr)) == -1) + { + Error error(std::string("Error binding UNIX socket: ") + strerror(errno)); + shutdown(fd, SHUT_RDWR); + close(fd); + throw error; + } + return fd; +} + +bool STG_CLIENT::Impl::read() +{ + static std::vector buffer(1024); + ssize_t res = ::read(m_sock, buffer.data(), buffer.size()); + if (res < 0) + { + //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); + return false; + } + m_lastActivity = time(NULL); + if (res == 0) + { + if (!m_parser.done()) + { + //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); + return false; + } + return process(); + } + return m_parser.append(buffer.data(), res); +} + +bool STG_CLIENT::Impl::tick() +{ + time_t now = time(NULL); + if (difftime(now, m_lastActivity) > CONN_TIMEOUT) + { + //m_logger("Connection to " + m_remote + " timed out."); + return false; + } + if (difftime(now, m_lastPing) > PING_TIMEOUT) + sendPing(); + return true; +} + +bool STG_CLIENT::Impl::process() +{ + switch (m_parser.packet()) + { + case PING: + return processPing(); + case PONG: + return processPong(); + case DATA: + return processData(); + } + //m_logger("Received invalid packet type: " + m_parser.packetStr()); + return false; +} + +bool STG_CLIENT::Impl::processPing() +{ + return sendPong(); +} + +bool STG_CLIENT::Impl::processPong() +{ + m_lastActivity = time(NULL); + return true; +} + +bool STG_CLIENT::Impl::processData() +{ + RESULT result; + for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it) + result.reply.push_back(std::make_pair(it->first, it->second)); + for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it) + result.modify.push_back(std::make_pair(it->first, it->second)); + return m_callback(m_data, result); +} + +bool STG_CLIENT::Impl::sendPing() +{ + PacketGen gen("ping"); + + m_lastPing = time(NULL); + + return generate(gen, &STG_CLIENT::Impl::write, this); +} + +bool STG_CLIENT::Impl::sendPong() +{ + PacketGen gen("pong"); + + m_lastPing = time(NULL); + + return generate(gen, &STG_CLIENT::Impl::write, this); +} + +bool STG_CLIENT::Impl::write(void* data, const char* buf, size_t size) +{ + STG_CLIENT::Impl& impl = *static_cast(data); + while (size > 0) + { + ssize_t res = ::write(impl.m_sock, buf, size); + if (res < 0) + { + //conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno)); + return false; } - return res; - } catch (const STG::SGCP::Proto::Error& ex) { - throw Error(ex.what()); + size -= res; } + return true; } -void STG_CLIENT::m_run() +void* STG_CLIENT::Impl::run(void* data) { - m_proto.connect(m_config.address, m_config.port); - m_proto.run(); + Impl& impl = *static_cast(data); + impl.runImpl(); + return NULL; }