]> git.stg.codes - stg.git/blobdiff - projects/rlm_stg/stg_client.cpp
Moved connection-related functions into a separate file.
[stg.git] / projects / rlm_stg / stg_client.cpp
index 399b971d8a7f76b6fc84103db6c4d5b815bc75f0..8c7ed9b9e43539b716a30453bdd7b2ba779bf057 100644 (file)
@@ -20,6 +20,8 @@
 
 #include "stg_client.h"
 
+#include "radlog.h"
+
 #include "stg/json_parser.h"
 #include "stg/json_generator.h"
 #include "stg/common.h"
@@ -46,8 +48,8 @@ using STG::JSON::StringGen;
 
 namespace {
 
-double CONN_TIMEOUT = 5;
-double PING_TIMEOUT = 1;
+double CONN_TIMEOUT = 60;
+double PING_TIMEOUT = 10;
 
 STG_CLIENT* stgClient = NULL;
 
@@ -104,11 +106,13 @@ class ResultParser : public EnumParser<bool>
 class TopParser : public NodeParser
 {
     public:
-        TopParser()
+        typedef void (*Callback) (void* /*data*/);
+        TopParser(Callback callback, void* data)
             : m_packetParser(this, m_packet, m_packetStr),
               m_resultParser(this, m_result, m_resultStr),
               m_replyParser(this, m_reply),
-              m_modifyParser(this, m_modify)
+              m_modifyParser(this, m_modify),
+              m_callback(callback), m_data(data)
         {}
 
         virtual NodeParser* parseStartMap() { return this; }
@@ -127,7 +131,7 @@ class TopParser : public NodeParser
 
             return this;
         }
-        virtual NodeParser* parseEndMap() { return this; }
+        virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
 
         const std::string& packetStr() const { return m_packetStr; }
         Packet packet() const { return m_packet; }
@@ -148,12 +152,18 @@ class TopParser : public NodeParser
         ResultParser m_resultParser;
         PairsParser m_replyParser;
         PairsParser m_modifyParser;
+
+        Callback m_callback;
+        void* m_data;
 };
 
 class ProtoParser : public Parser
 {
     public:
-        ProtoParser() : Parser( &m_topParser ) {}
+        ProtoParser(TopParser::Callback callback, void* data)
+            : Parser( &m_topParser ),
+              m_topParser(callback, data)
+        {}
 
         const std::string& packetStr() const { return m_topParser.packetStr(); }
         Packet packet() const { return m_topParser.packet(); }
@@ -183,7 +193,7 @@ class PacketGen : public Gen
             m_gen.add(key, new StringGen(value));
             return *this;
         }
-        PacketGen& add(const std::string& key, MapGen* map)
+        PacketGen& add(const std::string& key, MapGen& map)
         {
             m_gen.add(key, map);
             return *this;
@@ -199,9 +209,11 @@ class STG_CLIENT::Impl
 {
 public:
     Impl(const std::string& address, Callback callback, void* data);
+    Impl(const Impl& rhs);
     ~Impl();
 
     bool stop();
+    bool connected() const { return m_connected; }
 
     bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
 
@@ -224,6 +236,8 @@ private:
 
     ProtoParser m_parser;
 
+    bool m_connected;
+
     void m_writeHeader(TYPE type, const std::string& userName, const std::string& password);
     void m_writePairBlock(const PAIRS& source);
     PAIRS m_readPairBlock();
@@ -239,10 +253,10 @@ private:
     bool read();
     bool tick();
 
-    bool process();
-    bool processPing();
-    bool processPong();
-    bool processData();
+    static void process(void* data);
+    void processPing();
+    void processPong();
+    void processData();
     bool sendPing();
     bool sendPong();
 
@@ -288,6 +302,11 @@ STG_CLIENT::STG_CLIENT(const std::string& address, Callback callback, void* data
 {
 }
 
+STG_CLIENT::STG_CLIENT(const STG_CLIENT& rhs)
+    : m_impl(new Impl(*rhs.m_impl))
+{
+}
+
 STG_CLIENT::~STG_CLIENT()
 {
 }
@@ -297,6 +316,11 @@ bool STG_CLIENT::stop()
     return m_impl->stop();
 }
 
+bool STG_CLIENT::connected() const
+{
+    return m_impl->connected();
+}
+
 bool STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
     return m_impl->request(type, userName, password, pairs);
@@ -314,8 +338,33 @@ bool STG_CLIENT::configure(const std::string& address, Callback callback, void*
     try {
         stgClient = new STG_CLIENT(address, callback, data);
         return true;
+    } catch (const std::exception& ex) {
+        // TODO: Log it
+        RadLog("Client configuration error: %s.", ex.what());
+    }
+    return false;
+}
+
+bool STG_CLIENT::reconnect()
+{
+    if (stgClient == NULL)
+    {
+        RadLog("Connection is not configured.");
+        return false;
+    }
+    if (!stgClient->stop())
+    {
+        RadLog("Failed to drop previous connection.");
+        return false;
+    }
+    try {
+        STG_CLIENT* old = stgClient;
+        stgClient = new STG_CLIENT(*old);
+        delete old;
+        return true;
     } catch (const ChannelConfig::Error& ex) {
         // TODO: Log it
+        RadLog("Client configuration error: %s.", ex.what());
     }
     return false;
 }
@@ -328,9 +377,28 @@ STG_CLIENT::Impl::Impl(const std::string& address, Callback callback, void* data
       m_lastPing(time(NULL)),
       m_lastActivity(m_lastPing),
       m_callback(callback),
-      m_data(data)
+      m_data(data),
+      m_parser(&STG_CLIENT::Impl::process, this),
+      m_connected(true)
+{
+    int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this);
+    if (res != 0)
+        throw Error("Failed to create thread: " + std::string(strerror(errno)));
+}
+
+STG_CLIENT::Impl::Impl(const Impl& rhs)
+    : m_config(rhs.m_config),
+      m_sock(connect()),
+      m_running(false),
+      m_stopped(true),
+      m_lastPing(time(NULL)),
+      m_lastActivity(m_lastPing),
+      m_callback(rhs.m_callback),
+      m_data(rhs.m_data),
+      m_parser(&STG_CLIENT::Impl::process, this),
+      m_connected(true)
 {
-    int res = pthread_create(&m_thread, NULL, run, this);
+    int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this);
     if (res != 0)
         throw Error("Failed to create thread: " + std::string(strerror(errno)));
 }
@@ -344,6 +412,8 @@ STG_CLIENT::Impl::~Impl()
 
 bool STG_CLIENT::Impl::stop()
 {
+    m_connected = false;
+
     if (m_stopped)
         return true;
 
@@ -364,15 +434,15 @@ bool STG_CLIENT::Impl::stop()
 
 bool STG_CLIENT::Impl::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
-    boost::scoped_ptr<MapGen> map(new MapGen);
+    MapGen map;
     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
-        map->add(it->first, new StringGen(it->second));
-    map->add("Radius-Username", new StringGen(userName));
-    map->add("Radius-Userpass", new StringGen(password));
+        map.add(it->first, new StringGen(it->second));
+    map.add("Radius-Username", new StringGen(userName));
+    map.add("Radius-Userpass", new StringGen(password));
 
     PacketGen gen("data");
     gen.add("stage", toStage(type))
-       .add("pairs", map.get());
+       .add("pairs", map);
 
     m_lastPing = time(NULL);
 
@@ -396,8 +466,9 @@ void STG_CLIENT::Impl::runImpl()
         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
         if (res < 0)
         {
-            //m_error = std::string("'select' is failed: '") + strerror(errno) + "'.";
-            //m_logger(m_error);
+            if (errno == EINTR)
+                continue;
+            RadLog("'select' is failed: %s", strerror(errno));
             break;
         }
 
@@ -413,6 +484,7 @@ void STG_CLIENT::Impl::runImpl()
             m_running = tick();
     }
 
+    m_connected = false;
     m_stopped = true;
 }
 
@@ -456,7 +528,7 @@ int STG_CLIENT::Impl::connectTCP()
         {
             shutdown(fd, SHUT_RDWR);
             close(fd);
-            // TODO: log it.
+            RadLog("'connect' is failed: %s", strerror(errno));
             continue;
         }
         freeaddrinfo(ais);
@@ -479,7 +551,7 @@ int STG_CLIENT::Impl::connectUNIX()
     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
     {
-        Error error(std::string("Error binding UNIX socket: ") + strerror(errno));
+        Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
         shutdown(fd, SHUT_RDWR);
         close(fd);
         throw error;
@@ -493,18 +565,15 @@ bool STG_CLIENT::Impl::read()
     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
     if (res < 0)
     {
-        //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
+        RadLog("Failed to read data: %s", strerror(errno));
         return false;
     }
     m_lastActivity = time(NULL);
+    RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
     if (res == 0)
     {
-        if (!m_parser.done())
-        {
-            //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
-            return false;
-        }
-        return process();
+        m_parser.last();
+        return false;
     }
     return m_parser.append(buffer.data(), res);
 }
@@ -514,48 +583,59 @@ bool STG_CLIENT::Impl::tick()
     time_t now = time(NULL);
     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
     {
+        int delta = difftime(now, m_lastActivity);
+        RadLog("Connection timeout: %d sec.", delta);
         //m_logger("Connection to " + m_remote + " timed out.");
         return false;
     }
     if (difftime(now, m_lastPing) > PING_TIMEOUT)
+    {
+        int delta = difftime(now, m_lastPing);
+        RadLog("Ping timeout: %d sec. Sending ping...", delta);
         sendPing();
+    }
     return true;
 }
 
-bool STG_CLIENT::Impl::process()
+void STG_CLIENT::Impl::process(void* data)
 {
-    switch (m_parser.packet())
+    Impl& impl = *static_cast<Impl*>(data);
+    switch (impl.m_parser.packet())
     {
         case PING:
-            return processPing();
+            impl.processPing();
+            return;
         case PONG:
-            return processPong();
+            impl.processPong();
+            return;
         case DATA:
-            return processData();
+            impl.processData();
+            return;
     }
-    //m_logger("Received invalid packet type: " + m_parser.packetStr());
-    return false;
+    RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
 }
 
-bool STG_CLIENT::Impl::processPing()
+void STG_CLIENT::Impl::processPing()
 {
-    return sendPong();
+    RadLog("Got ping, sending pong.");
+    sendPong();
 }
 
-bool STG_CLIENT::Impl::processPong()
+void STG_CLIENT::Impl::processPong()
 {
+    RadLog("Got pong.");
     m_lastActivity = time(NULL);
-    return true;
 }
 
-bool STG_CLIENT::Impl::processData()
+void STG_CLIENT::Impl::processData()
 {
-    RESULT result;
+    RESULT data;
+    RadLog("Got data.");
     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
-        result.reply.push_back(std::make_pair(it->first, it->second));
+        data.reply.push_back(std::make_pair(it->first, it->second));
     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
-        result.modify.push_back(std::make_pair(it->first, it->second));
-    return m_callback(m_data, result);
+        data.modify.push_back(std::make_pair(it->first, it->second));
+    m_callback(m_data, data, m_parser.result());
 }
 
 bool STG_CLIENT::Impl::sendPing()
@@ -578,13 +658,17 @@ bool STG_CLIENT::Impl::sendPong()
 
 bool STG_CLIENT::Impl::write(void* data, const char* buf, size_t size)
 {
+    RadLog("Sending JSON:");
+    std::string json(buf, size);
+    RadLog("%s", json.c_str());
     STG_CLIENT::Impl& impl = *static_cast<STG_CLIENT::Impl*>(data);
     while (size > 0)
     {
-        ssize_t res = ::write(impl.m_sock, buf, size);
+        ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
         if (res < 0)
         {
-            //conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
+            impl.m_connected = false;
+            RadLog("Failed to write data: %s.", strerror(errno));
             return false;
         }
         size -= res;