]> git.stg.codes - stg.git/commitdiff
Implemented rlm_stg.
authorMaxim Mamontov <faust.madf@gmail.com>
Mon, 24 Aug 2015 13:03:20 +0000 (16:03 +0300)
committerMaxim Mamontov <faust.madf@gmail.com>
Mon, 24 Aug 2015 13:03:20 +0000 (16:03 +0300)
projects/rlm_stg/iface.cpp
projects/rlm_stg/stg_client.cpp
projects/rlm_stg/stg_client.h

index 485e9ef7439d9b3d1d5854e0a38429b8c6df472c..caa44386cf3f3dd6b625883af11837da9a5b0563 100644 (file)
@@ -9,6 +9,25 @@
 namespace
 {
 
+struct Response
+{
+    bool done;
+    pthread_mutex_t mutex;
+    pthread_cond_t cond;
+    RESULT result;
+
+    static bool callback(void* data, const RESULT& res)
+    {
+        Response& resp = *static_cast<Response*>(data);
+        pthread_mutex_lock(&resp.mutex);
+        resp.result = res;
+        resp.done = true;
+        pthread_cond_signal(&resp.cond);
+        pthread_mutex_unlock(&resp.mutex);
+        return true;
+    }
+} response;
+
 STG_PAIR* toSTGPairs(const PAIRS& source)
 {
     STG_PAIR * pairs = new STG_PAIR[source.size() + 1];
@@ -68,7 +87,12 @@ STG_RESULT stgRequest(STG_CLIENT::TYPE type, const char* userName, const char* p
         return emptyResult();
     }
     try {
-        return toResult(client->request(type, toString(userName), toString(password), fromSTGPairs(pairs)));
+        client->request(type, toString(userName), toString(password), fromSTGPairs(pairs));
+        pthread_mutex_lock(&response.mutex);
+        while (!response.done)
+            pthread_cond_wait(&response.cond, &response.mutex);
+        pthread_mutex_unlock(&response.mutex);
+        return toResult(response.result);
     } catch (const STG_CLIENT::Error& ex) {
         // TODO: log error
         return emptyResult();
@@ -79,7 +103,11 @@ STG_RESULT stgRequest(STG_CLIENT::TYPE type, const char* userName, const char* p
 
 int stgInstantiateImpl(const char* address)
 {
-    if (STG_CLIENT::configure(toString(address)))
+    pthread_mutex_init(&response.mutex, NULL);
+    pthread_cond_init(&response.cond, NULL);
+    response.done = false;
+
+    if (STG_CLIENT::configure(toString(address), &Response::callback, &response))
         return 1;
 
     return 0;
index 6987976b7d44bf011ec4bb8d406fd538601dcb26..399b971d8a7f76b6fc84103db6c4d5b815bc75f0 100644 (file)
 
 #include "stg_client.h"
 
+#include "stg/json_parser.h"
+#include "stg/json_generator.h"
 #include "stg/common.h"
 
-#include <boost/bind.hpp>
+#include <map>
+#include <utility>
+#include <cerrno>
+#include <cstring>
 
-#include <stdexcept>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/un.h> // UNIX
+#include <netinet/in.h> // IP
+#include <netinet/tcp.h> // TCP
+#include <netdb.h>
+
+using STG::JSON::Parser;
+using STG::JSON::PairsParser;
+using STG::JSON::EnumParser;
+using STG::JSON::NodeParser;
+using STG::JSON::Gen;
+using STG::JSON::MapGen;
+using STG::JSON::StringGen;
 
 namespace {
 
+double CONN_TIMEOUT = 5;
+double PING_TIMEOUT = 1;
+
 STG_CLIENT* stgClient = NULL;
 
-unsigned fromType(STG_CLIENT::TYPE type)
+std::string toStage(STG_CLIENT::TYPE type)
 {
-    return static_cast<unsigned>(type);
+    switch (type)
+    {
+        case STG_CLIENT::AUTHORIZE: return "authorize";
+        case STG_CLIENT::AUTHENTICATE: return "authenticate";
+        case STG_CLIENT::POST_AUTH: return "postauth";
+        case STG_CLIENT::PRE_ACCT: return "preacct";
+        case STG_CLIENT::ACCOUNT: return "accounting";
+    }
+    return "";
 }
 
-STG::SGCP::TransportType toTransport(const std::string& value)
+enum Packet
 {
-    std::string type = ToLower(value);
-    if (type == "unix") return STG::SGCP::UNIX;
-    else if (type == "udp") return STG::SGCP::UDP;
-    else if (type == "tcp") return STG::SGCP::TCP;
-    throw ChannelConfig::Error("Invalid transport type. Should be 'unix', 'udp' or 'tcp'.");
-}
+    PING,
+    PONG,
+    DATA
+};
+
+std::map<std::string, Packet> packetCodes;
+std::map<std::string, bool> resultCodes;
+
+class PacketParser : public EnumParser<Packet>
+{
+    public:
+        PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
+            : EnumParser(next, packet, packetStr, packetCodes)
+        {
+            if (!packetCodes.empty())
+                return;
+            packetCodes["ping"] = PING;
+            packetCodes["pong"] = PONG;
+            packetCodes["data"] = DATA;
+        }
+};
+
+class ResultParser : public EnumParser<bool>
+{
+    public:
+        ResultParser(NodeParser* next, bool& result, std::string& resultStr)
+            : EnumParser(next, result, resultStr, resultCodes)
+        {
+            if (!resultCodes.empty())
+                return;
+            resultCodes["no"] = false;
+            resultCodes["ok"] = true;
+        }
+};
+
+class TopParser : public NodeParser
+{
+    public:
+        TopParser()
+            : m_packetParser(this, m_packet, m_packetStr),
+              m_resultParser(this, m_result, m_resultStr),
+              m_replyParser(this, m_reply),
+              m_modifyParser(this, m_modify)
+        {}
+
+        virtual NodeParser* parseStartMap() { return this; }
+        virtual NodeParser* parseMapKey(const std::string& value)
+        {
+            std::string key = ToLower(value);
+
+            if (key == "packet")
+                return &m_packetParser;
+            else if (key == "result")
+                return &m_resultParser;
+            else if (key == "reply")
+                return &m_replyParser;
+            else if (key == "modify")
+                return &m_modifyParser;
+
+            return this;
+        }
+        virtual NodeParser* parseEndMap() { return this; }
+
+        const std::string& packetStr() const { return m_packetStr; }
+        Packet packet() const { return m_packet; }
+        const std::string& resultStr() const { return m_resultStr; }
+        bool result() const { return m_result; }
+        const PairsParser::Pairs& reply() const { return m_reply; }
+        const PairsParser::Pairs& modify() const { return m_modify; }
+
+    private:
+        std::string m_packetStr;
+        Packet m_packet;
+        std::string m_resultStr;
+        bool m_result;
+        PairsParser::Pairs m_reply;
+        PairsParser::Pairs m_modify;
+
+        PacketParser m_packetParser;
+        ResultParser m_resultParser;
+        PairsParser m_replyParser;
+        PairsParser m_modifyParser;
+};
+
+class ProtoParser : public Parser
+{
+    public:
+        ProtoParser() : Parser( &m_topParser ) {}
+
+        const std::string& packetStr() const { return m_topParser.packetStr(); }
+        Packet packet() const { return m_topParser.packet(); }
+        const std::string& resultStr() const { return m_topParser.resultStr(); }
+        bool result() const { return m_topParser.result(); }
+        const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
+        const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
+
+    private:
+        TopParser m_topParser;
+};
+
+class PacketGen : public Gen
+{
+    public:
+        PacketGen(const std::string& type)
+            : m_type(type)
+        {
+            m_gen.add("packet", m_type);
+        }
+        void run(yajl_gen_t* handle) const
+        {
+            m_gen.run(handle);
+        }
+        PacketGen& add(const std::string& key, const std::string& value)
+        {
+            m_gen.add(key, new StringGen(value));
+            return *this;
+        }
+        PacketGen& add(const std::string& key, MapGen* map)
+        {
+            m_gen.add(key, map);
+            return *this;
+        }
+    private:
+        MapGen m_gen;
+        StringGen m_type;
+};
 
 }
 
+class STG_CLIENT::Impl
+{
+public:
+    Impl(const std::string& address, Callback callback, void* data);
+    ~Impl();
+
+    bool stop();
+
+    bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
+
+private:
+    ChannelConfig m_config;
+
+    int m_sock;
+
+    bool m_running;
+    bool m_stopped;
+
+    time_t m_lastPing;
+    time_t m_lastActivity;
+
+    pthread_t m_thread;
+    pthread_mutex_t m_mutex;
+
+    Callback m_callback;
+    void* m_data;
+
+    ProtoParser m_parser;
+
+    void m_writeHeader(TYPE type, const std::string& userName, const std::string& password);
+    void m_writePairBlock(const PAIRS& source);
+    PAIRS m_readPairBlock();
+
+    static void* run(void* );
+
+    void runImpl();
+
+    int connect();
+    int connectTCP();
+    int connectUNIX();
+
+    bool read();
+    bool tick();
+
+    bool process();
+    bool processPing();
+    bool processPong();
+    bool processData();
+    bool sendPing();
+    bool sendPong();
+
+    static bool write(void* data, const char* buf, size_t size);
+};
+
 ChannelConfig::ChannelConfig(std::string addr)
-    : transport(STG::SGCP::TCP)
 {
     // unix:pass@/var/run/stg.sock
     // tcp:secret@192.168.0.1:12345
@@ -56,7 +258,7 @@ ChannelConfig::ChannelConfig(std::string addr)
     size_t pos = addr.find_first_of(':');
     if (pos == std::string::npos)
         throw Error("Missing transport name.");
-    transport = toTransport(addr.substr(0, pos));
+    transport = ToLower(addr.substr(0, pos));
     addr = addr.substr(pos + 1);
     if (addr.empty())
         throw Error("Missing address to connect to.");
@@ -67,7 +269,7 @@ ChannelConfig::ChannelConfig(std::string addr)
         if (addr.empty())
             throw Error("Missing address to connect to.");
     }
-    if (transport == STG::SGCP::UNIX)
+    if (transport == "unix")
     {
         address = addr;
         return;
@@ -76,35 +278,28 @@ ChannelConfig::ChannelConfig(std::string addr)
     if (pos == std::string::npos)
         throw Error("Missing port.");
     address = addr.substr(0, pos);
-    if (str2x(addr.substr(pos + 1), port))
+    portStr = addr.substr(pos + 1);
+    if (str2x(portStr, port))
         throw Error("Invalid port value.");
 }
 
-STG_CLIENT::STG_CLIENT(const std::string& address)
-    : m_config(address),
-      m_proto(m_config.transport, m_config.key),
-      m_thread(boost::bind(&STG_CLIENT::m_run, this))
+STG_CLIENT::STG_CLIENT(const std::string& address, Callback callback, void* data)
+    : m_impl(new Impl(address, callback, data))
 {
 }
 
 STG_CLIENT::~STG_CLIENT()
 {
-    stop();
 }
 
 bool STG_CLIENT::stop()
 {
-    return m_proto.stop();
+    return m_impl->stop();
 }
 
-RESULT STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
+bool STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
-    m_writeHeader(type, userName, password);
-    m_writePairBlock(pairs);
-    RESULT result;
-    result.modify = m_readPairBlock();
-    result.reply = m_readPairBlock();
-    return result;
+    return m_impl->request(type, userName, password, pairs);
 }
 
 STG_CLIENT* STG_CLIENT::get()
@@ -112,12 +307,12 @@ STG_CLIENT* STG_CLIENT::get()
     return stgClient;
 }
 
-bool STG_CLIENT::configure(const std::string& address)
+bool STG_CLIENT::configure(const std::string& address, Callback callback, void* data)
 {
     if ( stgClient != NULL && stgClient->stop() )
         delete stgClient;
     try {
-        stgClient = new STG_CLIENT(address);
+        stgClient = new STG_CLIENT(address, callback, data);
         return true;
     } catch (const ChannelConfig::Error& ex) {
         // TODO: Log it
@@ -125,49 +320,281 @@ bool STG_CLIENT::configure(const std::string& address)
     return false;
 }
 
-void STG_CLIENT::m_writeHeader(TYPE type, const std::string& userName, const std::string& password)
+STG_CLIENT::Impl::Impl(const std::string& address, Callback callback, void* data)
+    : m_config(address),
+      m_sock(connect()),
+      m_running(false),
+      m_stopped(true),
+      m_lastPing(time(NULL)),
+      m_lastActivity(m_lastPing),
+      m_callback(callback),
+      m_data(data)
 {
-    try {
-        m_proto.writeAll<uint64_t>(fromType(type));
-        m_proto.writeAll(userName);
-        m_proto.writeAll(password);
-    } catch (const STG::SGCP::Proto::Error& ex) {
-        throw Error(ex.what());
+    int res = pthread_create(&m_thread, NULL, run, this);
+    if (res != 0)
+        throw Error("Failed to create thread: " + std::string(strerror(errno)));
+}
+
+STG_CLIENT::Impl::~Impl()
+{
+    stop();
+    shutdown(m_sock, SHUT_RDWR);
+    close(m_sock);
+}
+
+bool STG_CLIENT::Impl::stop()
+{
+    if (m_stopped)
+        return true;
+
+    m_running = false;
+
+    for (size_t i = 0; i < 25 && !m_stopped; i++) {
+        struct timespec ts = {0, 200000000};
+        nanosleep(&ts, NULL);
     }
+
+    if (m_stopped) {
+        pthread_join(m_thread, NULL);
+        return true;
+    }
+
+    return false;
 }
 
-void STG_CLIENT::m_writePairBlock(const PAIRS& pairs)
+bool STG_CLIENT::Impl::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
-    try {
-        m_proto.writeAll<uint64_t>(pairs.size());
-        for (size_t i = 0; i < pairs.size(); ++i) {
-            m_proto.writeAll(pairs[i].first);
-            m_proto.writeAll(pairs[i].second);
+    boost::scoped_ptr<MapGen> map(new MapGen);
+    for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
+        map->add(it->first, new StringGen(it->second));
+    map->add("Radius-Username", new StringGen(userName));
+    map->add("Radius-Userpass", new StringGen(password));
+
+    PacketGen gen("data");
+    gen.add("stage", toStage(type))
+       .add("pairs", map.get());
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &STG_CLIENT::Impl::write, this);
+}
+
+void STG_CLIENT::Impl::runImpl()
+{
+    m_running = true;
+
+    while (m_running) {
+        fd_set fds;
+
+        FD_ZERO(&fds);
+        FD_SET(m_sock, &fds);
+
+        struct timeval tv;
+        tv.tv_sec = 0;
+        tv.tv_usec = 500000;
+
+        int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
+        if (res < 0)
+        {
+            //m_error = std::string("'select' is failed: '") + strerror(errno) + "'.";
+            //m_logger(m_error);
+            break;
+        }
+
+        if (!m_running)
+            break;
+
+        if (res > 0)
+        {
+            if (FD_ISSET(m_sock, &fds))
+                m_running = read();
         }
-    } catch (const STG::SGCP::Proto::Error& ex) {
-        throw Error(ex.what());
+        else
+            m_running = tick();
     }
+
+    m_stopped = true;
 }
 
-PAIRS STG_CLIENT::m_readPairBlock()
+int STG_CLIENT::Impl::connect()
 {
-    try {
-        size_t count = m_proto.readAll<uint64_t>();
-        if (count == 0)
-            return PAIRS();
-        PAIRS res(count);
-        for (size_t i = 0; i < count; ++i) {
-            res[i].first = m_proto.readAll<std::string>();
-            res[i].second = m_proto.readAll<std::string>();
+    if (m_config.transport == "tcp")
+        return connectTCP();
+    else if (m_config.transport == "unix")
+        return connectUNIX();
+    throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
+}
+
+int STG_CLIENT::Impl::connectTCP()
+{
+    addrinfo hints;
+    memset(&hints, 0, sizeof(addrinfo));
+
+    hints.ai_family = AF_INET;       /* Allow IPv4 */
+    hints.ai_socktype = SOCK_STREAM; /* Stream socket */
+    hints.ai_flags = 0;     /* For wildcard IP address */
+    hints.ai_protocol = 0;           /* Any protocol */
+    hints.ai_canonname = NULL;
+    hints.ai_addr = NULL;
+    hints.ai_next = NULL;
+
+    addrinfo* ais = NULL;
+    int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
+    if (res != 0)
+        throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
+
+    for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
+    {
+        int fd = socket(AF_INET, SOCK_STREAM, 0);
+        if (fd == -1)
+        {
+            Error error(std::string("Error creating TCP socket: ") + strerror(errno));
+            freeaddrinfo(ais);
+            throw error;
+        }
+        if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
+        {
+            shutdown(fd, SHUT_RDWR);
+            close(fd);
+            // TODO: log it.
+            continue;
+        }
+        freeaddrinfo(ais);
+        return fd;
+    }
+
+    freeaddrinfo(ais);
+
+    throw Error("Failed to resolve '" + m_config.address);
+};
+
+int STG_CLIENT::Impl::connectUNIX()
+{
+    int fd = socket(AF_UNIX, SOCK_STREAM, 0);
+    if (fd == -1)
+        throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
+    struct sockaddr_un addr;
+    memset(&addr, 0, sizeof(addr));
+    addr.sun_family = AF_UNIX;
+    strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
+    if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
+    {
+        Error error(std::string("Error binding UNIX socket: ") + strerror(errno));
+        shutdown(fd, SHUT_RDWR);
+        close(fd);
+        throw error;
+    }
+    return fd;
+}
+
+bool STG_CLIENT::Impl::read()
+{
+    static std::vector<char> buffer(1024);
+    ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
+    if (res < 0)
+    {
+        //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
+        return false;
+    }
+    m_lastActivity = time(NULL);
+    if (res == 0)
+    {
+        if (!m_parser.done())
+        {
+            //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
+            return false;
+        }
+        return process();
+    }
+    return m_parser.append(buffer.data(), res);
+}
+
+bool STG_CLIENT::Impl::tick()
+{
+    time_t now = time(NULL);
+    if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
+    {
+        //m_logger("Connection to " + m_remote + " timed out.");
+        return false;
+    }
+    if (difftime(now, m_lastPing) > PING_TIMEOUT)
+        sendPing();
+    return true;
+}
+
+bool STG_CLIENT::Impl::process()
+{
+    switch (m_parser.packet())
+    {
+        case PING:
+            return processPing();
+        case PONG:
+            return processPong();
+        case DATA:
+            return processData();
+    }
+    //m_logger("Received invalid packet type: " + m_parser.packetStr());
+    return false;
+}
+
+bool STG_CLIENT::Impl::processPing()
+{
+    return sendPong();
+}
+
+bool STG_CLIENT::Impl::processPong()
+{
+    m_lastActivity = time(NULL);
+    return true;
+}
+
+bool STG_CLIENT::Impl::processData()
+{
+    RESULT result;
+    for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
+        result.reply.push_back(std::make_pair(it->first, it->second));
+    for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
+        result.modify.push_back(std::make_pair(it->first, it->second));
+    return m_callback(m_data, result);
+}
+
+bool STG_CLIENT::Impl::sendPing()
+{
+    PacketGen gen("ping");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &STG_CLIENT::Impl::write, this);
+}
+
+bool STG_CLIENT::Impl::sendPong()
+{
+    PacketGen gen("pong");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &STG_CLIENT::Impl::write, this);
+}
+
+bool STG_CLIENT::Impl::write(void* data, const char* buf, size_t size)
+{
+    STG_CLIENT::Impl& impl = *static_cast<STG_CLIENT::Impl*>(data);
+    while (size > 0)
+    {
+        ssize_t res = ::write(impl.m_sock, buf, size);
+        if (res < 0)
+        {
+            //conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
+            return false;
         }
-        return res;
-    } catch (const STG::SGCP::Proto::Error& ex) {
-        throw Error(ex.what());
+        size -= res;
     }
+    return true;
 }
 
-void STG_CLIENT::m_run()
+void* STG_CLIENT::Impl::run(void* data)
 {
-    m_proto.connect(m_config.address, m_config.port);
-    m_proto.run();
+    Impl& impl = *static_cast<Impl*>(data);
+    impl.runImpl();
+    return NULL;
 }
index d57af7dacfc7c2201d0404750e0d31d89d213a8c..82d2287ba04ffb080a75ce43abbd95d4ba70d359 100644 (file)
 #ifndef STG_CLIENT_H
 #define STG_CLIENT_H
 
-#include "stg/sgcp_proto.h" // Proto
-#include "stg/sgcp_types.h" // TransportType
 #include "stg/os_int.h"
 
-#include <boost/thread.hpp>
+#include <boost/scoped_ptr.hpp>
 
 #include <string>
 #include <vector>
-#include <utility>
+#include <stdexcept>
 
 typedef std::vector<std::pair<std::string, std::string> > PAIRS;
 
@@ -46,9 +44,10 @@ struct ChannelConfig {
 
     ChannelConfig(std::string address);
 
-    STG::SGCP::TransportType transport;
+    std::string transport;
     std::string key;
     std::string address;
+    std::string portStr;
     uint16_t port;
 };
 
@@ -66,26 +65,21 @@ public:
         Error(const std::string& message) : runtime_error(message) {}
     };
 
-    STG_CLIENT(const std::string& address);
+    typedef bool (*Callback)(void* data, const RESULT& result);
+
+    STG_CLIENT(const std::string& address, Callback callback, void* data);
     ~STG_CLIENT();
 
     bool stop();
 
     static STG_CLIENT* get();
-    static bool configure(const std::string& address);
+    static bool configure(const std::string& address, Callback callback, void* data);
 
-    RESULT request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
+    bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
 
 private:
-    ChannelConfig m_config;
-    STG::SGCP::Proto m_proto;
-    boost::thread m_thread;
-
-    void m_writeHeader(TYPE type, const std::string& userName, const std::string& password);
-    void m_writePairBlock(const PAIRS& source);
-    PAIRS m_readPairBlock();
-
-    void m_run();
+    class Impl;
+    boost::scoped_ptr<Impl> m_impl;
 };
 
 #endif