X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/fd082d2de4f5e845c09676b21b58cb5de7e05837..d78d51f3932ed3d3466e53eb7c7786d5d8216cc0:/projects/rlm_stg/stg_client.cpp diff --git a/projects/rlm_stg/stg_client.cpp b/projects/rlm_stg/stg_client.cpp index 75d443ad..8c7ed9b9 100644 --- a/projects/rlm_stg/stg_client.cpp +++ b/projects/rlm_stg/stg_client.cpp @@ -18,279 +18,667 @@ * Author : Maxim Mamontov */ -/* - * Realization of data access via Stargazer for RADIUS - * - * $Revision: 1.8 $ - * $Date: 2010/04/16 12:30:02 $ - * - */ +#include "stg_client.h" -#include -#include -#include // close +#include "radlog.h" + +#include "stg/json_parser.h" +#include "stg/json_generator.h" +#include "stg/common.h" +#include +#include #include #include -#include +#include +#include +#include // UNIX +#include // IP +#include // TCP +#include + +using STG::JSON::Parser; +using STG::JSON::PairsParser; +using STG::JSON::EnumParser; +using STG::JSON::NodeParser; +using STG::JSON::Gen; +using STG::JSON::MapGen; +using STG::JSON::StringGen; -#include "stg_client.h" +namespace { + +double CONN_TIMEOUT = 60; +double PING_TIMEOUT = 10; -using namespace std; +STG_CLIENT* stgClient = NULL; -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- -STG_CLIENT::STG_CLIENT(const std::string & host, uint16_t port, uint16_t lp, const std::string & pass) - : localPort(lp), - password(pass), - framedIP(0) +std::string toStage(STG_CLIENT::TYPE type) { -sock = socket(AF_INET, SOCK_DGRAM, 0); -if (sock == -1) + switch (type) { - std::string message = strerror(errno); - message = "Socket create error: '" + message + "'"; - throw std::runtime_error(message); + case STG_CLIENT::AUTHORIZE: return "authorize"; + case STG_CLIENT::AUTHENTICATE: return "authenticate"; + case STG_CLIENT::POST_AUTH: return "postauth"; + case STG_CLIENT::PRE_ACCT: return "preacct"; + case STG_CLIENT::ACCOUNT: return "accounting"; } + return ""; +} -struct hostent * he = NULL; -he = gethostbyname(host.c_str()); -if (he == NULL) - { - throw std::runtime_error("gethostbyname error"); - } +enum Packet +{ + PING, + PONG, + DATA +}; -outerAddr.sin_family = AF_INET; -outerAddr.sin_port = htons(port); -outerAddr.sin_addr.s_addr = *(uint32_t *)he->h_addr; +std::map packetCodes; +std::map resultCodes; -InitEncrypt(); -} -//----------------------------------------------------------------------------- -STG_CLIENT::~STG_CLIENT() +class PacketParser : public EnumParser { -close(sock); -} -//----------------------------------------------------------------------------- -uint32_t STG_CLIENT::GetFramedIP() const -{ -return framedIP; -} -//----------------------------------------------------------------------------- -inline -void STG_CLIENT::InitEncrypt() -{ -unsigned char keyL[RAD_PASSWORD_LEN]; -memset(keyL, 0, RAD_PASSWORD_LEN); -strncpy((char *)keyL, password.c_str(), RAD_PASSWORD_LEN); -Blowfish_Init(&ctx, keyL, RAD_PASSWORD_LEN); -} -//----------------------------------------------------------------------------- -int STG_CLIENT::PrepareNet() + public: + PacketParser(NodeParser* next, Packet& packet, std::string& packetStr) + : EnumParser(next, packet, packetStr, packetCodes) + { + if (!packetCodes.empty()) + return; + packetCodes["ping"] = PING; + packetCodes["pong"] = PONG; + packetCodes["data"] = DATA; + } +}; + +class ResultParser : public EnumParser { -if (localPort != 0) - { - struct sockaddr_in localAddr; - localAddr.sin_family = AF_INET; - localAddr.sin_port = htons(localPort); - localAddr.sin_addr.s_addr = inet_addr("0.0.0.0");; + public: + ResultParser(NodeParser* next, bool& result, std::string& resultStr) + : EnumParser(next, result, resultStr, resultCodes) + { + if (!resultCodes.empty()) + return; + resultCodes["no"] = false; + resultCodes["ok"] = true; + } +}; - if (bind(sock, (struct sockaddr *)&localAddr, sizeof(localAddr))) +class TopParser : public NodeParser +{ + public: + typedef void (*Callback) (void* /*data*/); + TopParser(Callback callback, void* data) + : m_packetParser(this, m_packet, m_packetStr), + m_resultParser(this, m_result, m_resultStr), + m_replyParser(this, m_reply), + m_modifyParser(this, m_modify), + m_callback(callback), m_data(data) + {} + + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseMapKey(const std::string& value) { - errorStr = "Bind failed"; - return -1; + std::string key = ToLower(value); + + if (key == "packet") + return &m_packetParser; + else if (key == "result") + return &m_resultParser; + else if (key == "reply") + return &m_replyParser; + else if (key == "modify") + return &m_modifyParser; + + return this; } - } -return 0; + virtual NodeParser* parseEndMap() { m_callback(m_data); return this; } + + const std::string& packetStr() const { return m_packetStr; } + Packet packet() const { return m_packet; } + const std::string& resultStr() const { return m_resultStr; } + bool result() const { return m_result; } + const PairsParser::Pairs& reply() const { return m_reply; } + const PairsParser::Pairs& modify() const { return m_modify; } + + private: + std::string m_packetStr; + Packet m_packet; + std::string m_resultStr; + bool m_result; + PairsParser::Pairs m_reply; + PairsParser::Pairs m_modify; + + PacketParser m_packetParser; + ResultParser m_resultParser; + PairsParser m_replyParser; + PairsParser m_modifyParser; + + Callback m_callback; + void* m_data; +}; + +class ProtoParser : public Parser +{ + public: + ProtoParser(TopParser::Callback callback, void* data) + : Parser( &m_topParser ), + m_topParser(callback, data) + {} + + const std::string& packetStr() const { return m_topParser.packetStr(); } + Packet packet() const { return m_topParser.packet(); } + const std::string& resultStr() const { return m_topParser.resultStr(); } + bool result() const { return m_topParser.result(); } + const PairsParser::Pairs& reply() const { return m_topParser.reply(); } + const PairsParser::Pairs& modify() const { return m_topParser.modify(); } + + private: + TopParser m_topParser; +}; + +class PacketGen : public Gen +{ + public: + PacketGen(const std::string& type) + : m_type(type) + { + m_gen.add("packet", m_type); + } + void run(yajl_gen_t* handle) const + { + m_gen.run(handle); + } + PacketGen& add(const std::string& key, const std::string& value) + { + m_gen.add(key, new StringGen(value)); + return *this; + } + PacketGen& add(const std::string& key, MapGen& map) + { + m_gen.add(key, map); + return *this; + } + private: + MapGen m_gen; + StringGen m_type; +}; + } -//----------------------------------------------------------------------------- -int STG_CLIENT::Start() + +class STG_CLIENT::Impl +{ +public: + Impl(const std::string& address, Callback callback, void* data); + Impl(const Impl& rhs); + ~Impl(); + + bool stop(); + bool connected() const { return m_connected; } + + bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs); + +private: + ChannelConfig m_config; + + int m_sock; + + bool m_running; + bool m_stopped; + + time_t m_lastPing; + time_t m_lastActivity; + + pthread_t m_thread; + pthread_mutex_t m_mutex; + + Callback m_callback; + void* m_data; + + ProtoParser m_parser; + + bool m_connected; + + void m_writeHeader(TYPE type, const std::string& userName, const std::string& password); + void m_writePairBlock(const PAIRS& source); + PAIRS m_readPairBlock(); + + static void* run(void* ); + + void runImpl(); + + int connect(); + int connectTCP(); + int connectUNIX(); + + bool read(); + bool tick(); + + static void process(void* data); + void processPing(); + void processPong(); + void processData(); + bool sendPing(); + bool sendPong(); + + static bool write(void* data, const char* buf, size_t size); +}; + +ChannelConfig::ChannelConfig(std::string addr) { -return PrepareNet(); + // unix:pass@/var/run/stg.sock + // tcp:secret@192.168.0.1:12345 + // udp:key@isp.com.ua:54321 + + size_t pos = addr.find_first_of(':'); + if (pos == std::string::npos) + throw Error("Missing transport name."); + transport = ToLower(addr.substr(0, pos)); + addr = addr.substr(pos + 1); + if (addr.empty()) + throw Error("Missing address to connect to."); + pos = addr.find_first_of('@'); + if (pos != std::string::npos) { + key = addr.substr(0, pos); + addr = addr.substr(pos + 1); + if (addr.empty()) + throw Error("Missing address to connect to."); + } + if (transport == "unix") + { + address = addr; + return; + } + pos = addr.find_first_of(':'); + if (pos == std::string::npos) + throw Error("Missing port."); + address = addr.substr(0, pos); + portStr = addr.substr(pos + 1); + if (str2x(portStr, port)) + throw Error("Invalid port value."); } -//----------------------------------------------------------------------------- -int STG_CLIENT::Stop() + +STG_CLIENT::STG_CLIENT(const std::string& address, Callback callback, void* data) + : m_impl(new Impl(address, callback, data)) { -return 0; } -//----------------------------------------------------------------------------- -string STG_CLIENT::GetUserPassword() const + +STG_CLIENT::STG_CLIENT(const STG_CLIENT& rhs) + : m_impl(new Impl(*rhs.m_impl)) { -return userPassword; } -//----------------------------------------------------------------------------- -int STG_CLIENT::Send(const RAD_PACKET & packet) + +STG_CLIENT::~STG_CLIENT() { -char buf[RAD_MAX_PACKET_LEN]; - -Encrypt(buf, (char *)&packet, sizeof(RAD_PACKET) / 8); +} -int res = sendto(sock, buf, sizeof(RAD_PACKET), 0, (struct sockaddr *)&outerAddr, sizeof(outerAddr)); +bool STG_CLIENT::stop() +{ + return m_impl->stop(); +} -if (res == -1) - errorStr = "Error sending data"; +bool STG_CLIENT::connected() const +{ + return m_impl->connected(); +} -return res; +bool STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) +{ + return m_impl->request(type, userName, password, pairs); } -//----------------------------------------------------------------------------- -int STG_CLIENT::RecvData(RAD_PACKET * packet) + +STG_CLIENT* STG_CLIENT::get() { -char buf[RAD_MAX_PACKET_LEN]; -int res; + return stgClient; +} -struct sockaddr_in addr; -socklen_t len = sizeof(struct sockaddr_in); +bool STG_CLIENT::configure(const std::string& address, Callback callback, void* data) +{ + if ( stgClient != NULL && stgClient->stop() ) + delete stgClient; + try { + stgClient = new STG_CLIENT(address, callback, data); + return true; + } catch (const std::exception& ex) { + // TODO: Log it + RadLog("Client configuration error: %s.", ex.what()); + } + return false; +} -res = recvfrom(sock, buf, RAD_MAX_PACKET_LEN, 0, reinterpret_cast(&addr), &len); -if (res == -1) +bool STG_CLIENT::reconnect() +{ + if (stgClient == NULL) + { + RadLog("Connection is not configured."); + return false; + } + if (!stgClient->stop()) { - errorStr = "Error receiving data"; - return -1; + RadLog("Failed to drop previous connection."); + return false; + } + try { + STG_CLIENT* old = stgClient; + stgClient = new STG_CLIENT(*old); + delete old; + return true; + } catch (const ChannelConfig::Error& ex) { + // TODO: Log it + RadLog("Client configuration error: %s.", ex.what()); } + return false; +} -Decrypt((char *)packet, buf, res / 8); +STG_CLIENT::Impl::Impl(const std::string& address, Callback callback, void* data) + : m_config(address), + m_sock(connect()), + m_running(false), + m_stopped(true), + m_lastPing(time(NULL)), + m_lastActivity(m_lastPing), + m_callback(callback), + m_data(data), + m_parser(&STG_CLIENT::Impl::process, this), + m_connected(true) +{ + int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this); + if (res != 0) + throw Error("Failed to create thread: " + std::string(strerror(errno))); +} -return 0; +STG_CLIENT::Impl::Impl(const Impl& rhs) + : m_config(rhs.m_config), + m_sock(connect()), + m_running(false), + m_stopped(true), + m_lastPing(time(NULL)), + m_lastActivity(m_lastPing), + m_callback(rhs.m_callback), + m_data(rhs.m_data), + m_parser(&STG_CLIENT::Impl::process, this), + m_connected(true) +{ + int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this); + if (res != 0) + throw Error("Failed to create thread: " + std::string(strerror(errno))); } -//----------------------------------------------------------------------------- -int STG_CLIENT::Request(RAD_PACKET * packet, const std::string & login, const std::string & svc, uint8_t packetType) + +STG_CLIENT::Impl::~Impl() { -int res; + stop(); + shutdown(m_sock, SHUT_RDWR); + close(m_sock); +} -memcpy((void *)&packet->magic, (void *)RAD_ID, RAD_MAGIC_LEN); -packet->protoVer[0] = '0'; -packet->protoVer[1] = '1'; -packet->packetType = packetType; -packet->ip = 0; -strncpy((char *)packet->login, login.c_str(), RAD_LOGIN_LEN); -strncpy((char *)packet->service, svc.c_str(), RAD_SERVICE_LEN); +bool STG_CLIENT::Impl::stop() +{ + m_connected = false; -res = Send(*packet); -if (res == -1) - return -1; + if (m_stopped) + return true; -res = RecvData(packet); -if (res == -1) - return -1; + m_running = false; -if (strncmp((char *)packet->magic, RAD_ID, RAD_MAGIC_LEN)) - { - errorStr = "Magic invalid. Wanted: '"; - errorStr += RAD_ID; - errorStr += "', got: '"; - errorStr += (char *)packet->magic; - errorStr += "'"; - return -1; + for (size_t i = 0; i < 25 && !m_stopped; i++) { + struct timespec ts = {0, 200000000}; + nanosleep(&ts, NULL); } -return 0; -} -//----------------------------------------------------------------------------- -int STG_CLIENT::Authorize(const string & login, const string & svc) -{ -RAD_PACKET packet; + if (m_stopped) { + pthread_join(m_thread, NULL); + return true; + } -userPassword = ""; + return false; +} -if (Request(&packet, login, svc, RAD_AUTZ_PACKET)) - return -1; +bool STG_CLIENT::Impl::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) +{ + MapGen map; + for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it) + map.add(it->first, new StringGen(it->second)); + map.add("Radius-Username", new StringGen(userName)); + map.add("Radius-Userpass", new StringGen(password)); -if (packet.packetType != RAD_ACCEPT_PACKET) - return -1; + PacketGen gen("data"); + gen.add("stage", toStage(type)) + .add("pairs", map); -userPassword = (char *)packet.password; + m_lastPing = time(NULL); -return 0; + return generate(gen, &STG_CLIENT::Impl::write, this); } -//----------------------------------------------------------------------------- -int STG_CLIENT::Authenticate(const string & login, const string & svc) -{ -RAD_PACKET packet; -userPassword = ""; +void STG_CLIENT::Impl::runImpl() +{ + m_running = true; -if (Request(&packet, login, svc, RAD_AUTH_PACKET)) - return -1; + while (m_running) { + fd_set fds; -if (packet.packetType != RAD_ACCEPT_PACKET) - return -1; + FD_ZERO(&fds); + FD_SET(m_sock, &fds); -return 0; -} -//----------------------------------------------------------------------------- -int STG_CLIENT::PostAuthenticate(const string & login, const string & svc) -{ -RAD_PACKET packet; + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 500000; -userPassword = ""; + int res = select(m_sock + 1, &fds, NULL, NULL, &tv); + if (res < 0) + { + if (errno == EINTR) + continue; + RadLog("'select' is failed: %s", strerror(errno)); + break; + } -if (Request(&packet, login, svc, RAD_POST_AUTH_PACKET)) - return -1; + if (!m_running) + break; -if (packet.packetType != RAD_ACCEPT_PACKET) - return -1; + if (res > 0) + { + if (FD_ISSET(m_sock, &fds)) + m_running = read(); + } + else + m_running = tick(); + } -if (svc == "Framed-User") - framedIP = packet.ip; -else - framedIP = 0; + m_connected = false; + m_stopped = true; +} -return 0; +int STG_CLIENT::Impl::connect() +{ + if (m_config.transport == "tcp") + return connectTCP(); + else if (m_config.transport == "unix") + return connectUNIX(); + throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'."); } -//----------------------------------------------------------------------------- -int STG_CLIENT::Account(const std::string & type, const string & login, const string & svc, const string & sessid) + +int STG_CLIENT::Impl::connectTCP() { -RAD_PACKET packet; + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + + hints.ai_family = AF_INET; /* Allow IPv4 */ + hints.ai_socktype = SOCK_STREAM; /* Stream socket */ + hints.ai_flags = 0; /* For wildcard IP address */ + hints.ai_protocol = 0; /* Any protocol */ + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + + addrinfo* ais = NULL; + int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais); + if (res != 0) + throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res)); + + for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next) + { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) + { + Error error(std::string("Error creating TCP socket: ") + strerror(errno)); + freeaddrinfo(ais); + throw error; + } + if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1) + { + shutdown(fd, SHUT_RDWR); + close(fd); + RadLog("'connect' is failed: %s", strerror(errno)); + continue; + } + freeaddrinfo(ais); + return fd; + } -userPassword = ""; -strncpy((char *)packet.sessid, sessid.c_str(), RAD_SESSID_LEN); + freeaddrinfo(ais); -if (type == "Start") + throw Error("Failed to resolve '" + m_config.address); +}; + +int STG_CLIENT::Impl::connectUNIX() +{ + int fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd == -1) + throw Error(std::string("Error creating UNIX socket: ") + strerror(errno)); + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length()); + if (::connect(fd, reinterpret_cast(&addr), sizeof(addr)) == -1) + { + Error error(std::string("Error connecting UNIX socket: ") + strerror(errno)); + shutdown(fd, SHUT_RDWR); + close(fd); + throw error; + } + return fd; +} + +bool STG_CLIENT::Impl::read() +{ + static std::vector buffer(1024); + ssize_t res = ::read(m_sock, buffer.data(), buffer.size()); + if (res < 0) { - if (Request(&packet, login, svc, RAD_ACCT_START_PACKET)) - return -1; + RadLog("Failed to read data: %s", strerror(errno)); + return false; } -else if (type == "Stop") + m_lastActivity = time(NULL); + RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str()); + if (res == 0) { - if (Request(&packet, login, svc, RAD_ACCT_STOP_PACKET)) - return -1; + m_parser.last(); + return false; } -else if (type == "Interim-Update") + return m_parser.append(buffer.data(), res); +} + +bool STG_CLIENT::Impl::tick() +{ + time_t now = time(NULL); + if (difftime(now, m_lastActivity) > CONN_TIMEOUT) { - if (Request(&packet, login, svc, RAD_ACCT_UPDATE_PACKET)) - return -1; + int delta = difftime(now, m_lastActivity); + RadLog("Connection timeout: %d sec.", delta); + //m_logger("Connection to " + m_remote + " timed out."); + return false; } -else + if (difftime(now, m_lastPing) > PING_TIMEOUT) { - if (Request(&packet, login, svc, RAD_ACCT_OTHER_PACKET)) - return -1; + int delta = difftime(now, m_lastPing); + RadLog("Ping timeout: %d sec. Sending ping...", delta); + sendPing(); } + return true; +} + +void STG_CLIENT::Impl::process(void* data) +{ + Impl& impl = *static_cast(data); + switch (impl.m_parser.packet()) + { + case PING: + impl.processPing(); + return; + case PONG: + impl.processPong(); + return; + case DATA: + impl.processData(); + return; + } + RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str()); +} + +void STG_CLIENT::Impl::processPing() +{ + RadLog("Got ping, sending pong."); + sendPong(); +} -if (packet.packetType != RAD_ACCEPT_PACKET) - return -1; +void STG_CLIENT::Impl::processPong() +{ + RadLog("Got pong."); + m_lastActivity = time(NULL); +} -return 0; +void STG_CLIENT::Impl::processData() +{ + RESULT data; + RadLog("Got data."); + for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it) + data.reply.push_back(std::make_pair(it->first, it->second)); + for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it) + data.modify.push_back(std::make_pair(it->first, it->second)); + m_callback(m_data, data, m_parser.result()); } -//----------------------------------------------------------------------------- -void STG_CLIENT::Encrypt(char * dst, const char * src, int len8) -{ -// len8 - длина в 8-ми байтовых блоках -if (dst != src) - memcpy(dst, src, len8 * 8); - -for (int i = 0; i < len8; i++) - Blowfish_Encrypt(&ctx, (uint32_t *)(dst + i*8), (uint32_t *)(dst + i*8 + 4)); + +bool STG_CLIENT::Impl::sendPing() +{ + PacketGen gen("ping"); + + m_lastPing = time(NULL); + + return generate(gen, &STG_CLIENT::Impl::write, this); } -//----------------------------------------------------------------------------- -void STG_CLIENT::Decrypt(char * dst, const char * src, int len8) + +bool STG_CLIENT::Impl::sendPong() { -// len8 - длина в 8-ми байтовых блоках -if (dst != src) - memcpy(dst, src, len8 * 8); + PacketGen gen("pong"); -for (int i = 0; i < len8; i++) - Blowfish_Decrypt(&ctx, (uint32_t *)(dst + i*8), (uint32_t *)(dst + i*8 + 4)); + m_lastPing = time(NULL); + + return generate(gen, &STG_CLIENT::Impl::write, this); +} + +bool STG_CLIENT::Impl::write(void* data, const char* buf, size_t size) +{ + RadLog("Sending JSON:"); + std::string json(buf, size); + RadLog("%s", json.c_str()); + STG_CLIENT::Impl& impl = *static_cast(data); + while (size > 0) + { + ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL); + if (res < 0) + { + impl.m_connected = false; + RadLog("Failed to write data: %s.", strerror(errno)); + return false; + } + size -= res; + } + return true; +} + +void* STG_CLIENT::Impl::run(void* data) +{ + Impl& impl = *static_cast(data); + impl.runImpl(); + return NULL; } -//-----------------------------------------------------------------------------