]> git.stg.codes - stg.git/blobdiff - projects/rlm_stg/stg_client.cpp
Added no_match settings to the radius plugin.
[stg.git] / projects / rlm_stg / stg_client.cpp
index 6987976b7d44bf011ec4bb8d406fd538601dcb26..239770deffdb068a77e83ba36c2dc2ddb193cc98 100644 (file)
 
 #include "stg_client.h"
 
+#include "conn.h"
+#include "radlog.h"
+
+#include "stg/locker.h"
 #include "stg/common.h"
 
-#include <boost/bind.hpp>
+#include <map>
+#include <utility>
 
-#include <stdexcept>
+using STG::RLM::Client;
+using STG::RLM::Conn;
+using STG::RLM::RESULT;
 
 namespace {
 
-STG_CLIENT* stgClient = NULL;
+Client* stgClient = NULL;
 
-unsigned fromType(STG_CLIENT::TYPE type)
-{
-    return static_cast<unsigned>(type);
 }
 
-STG::SGCP::TransportType toTransport(const std::string& value)
+class Client::Impl
 {
-    std::string type = ToLower(value);
-    if (type == "unix") return STG::SGCP::UNIX;
-    else if (type == "udp") return STG::SGCP::UDP;
-    else if (type == "tcp") return STG::SGCP::TCP;
-    throw ChannelConfig::Error("Invalid transport type. Should be 'unix', 'udp' or 'tcp'.");
-}
-
-}
+    public:
+        explicit Impl(const std::string& address);
+        ~Impl();
+
+        bool stop() { return m_conn ? m_conn->stop() : true; }
+
+        RESULT request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
+
+    private:
+        std::string m_address;
+        boost::scoped_ptr<Conn> m_conn;
+
+        pthread_mutex_t m_mutex;
+        pthread_cond_t m_cond;
+        bool m_done;
+        RESULT m_result;
+        bool m_status;
+
+        static bool callback(void* data, const RESULT& result, bool status)
+        {
+            Impl& impl = *static_cast<Impl*>(data);
+            STG_LOCKER lock(impl.m_mutex);
+            impl.m_result = result;
+            impl.m_status = status;
+            impl.m_done = true;
+            pthread_cond_signal(&impl.m_cond);
+            return true;
+        }
+};
 
-ChannelConfig::ChannelConfig(std::string addr)
-    : transport(STG::SGCP::TCP)
+Client::Impl::Impl(const std::string& address)
+    : m_address(address)
 {
-    // unix:pass@/var/run/stg.sock
-    // tcp:secret@192.168.0.1:12345
-    // udp:key@isp.com.ua:54321
-
-    size_t pos = addr.find_first_of(':');
-    if (pos == std::string::npos)
-        throw Error("Missing transport name.");
-    transport = toTransport(addr.substr(0, pos));
-    addr = addr.substr(pos + 1);
-    if (addr.empty())
-        throw Error("Missing address to connect to.");
-    pos = addr.find_first_of('@');
-    if (pos != std::string::npos) {
-        key = addr.substr(0, pos);
-        addr = addr.substr(pos + 1);
-        if (addr.empty())
-            throw Error("Missing address to connect to.");
+    try
+    {
+        m_conn.reset(new Conn(m_address, &Impl::callback, this));
     }
-    if (transport == STG::SGCP::UNIX)
+    catch (const std::runtime_error& ex)
     {
-        address = addr;
-        return;
+        RadLog("Connection error: %s.", ex.what());
     }
-    pos = addr.find_first_of(':');
-    if (pos == std::string::npos)
-        throw Error("Missing port.");
-    address = addr.substr(0, pos);
-    if (str2x(addr.substr(pos + 1), port))
-        throw Error("Invalid port value.");
+    pthread_mutex_init(&m_mutex, NULL);
+    pthread_cond_init(&m_cond, NULL);
+    m_done = false;
 }
 
-STG_CLIENT::STG_CLIENT(const std::string& address)
-    : m_config(address),
-      m_proto(m_config.transport, m_config.key),
-      m_thread(boost::bind(&STG_CLIENT::m_run, this))
+Client::Impl::~Impl()
 {
+    pthread_cond_destroy(&m_cond);
+    pthread_mutex_destroy(&m_mutex);
 }
 
-STG_CLIENT::~STG_CLIENT()
+RESULT Client::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
-    stop();
+    STG_LOCKER lock(m_mutex);
+    if (!m_conn || !m_conn->connected())
+        m_conn.reset(new Conn(m_address, &Impl::callback, this));
+    if (!m_conn->connected())
+        throw Conn::Error("Failed to create connection to '" + m_address + "'.");
+
+    m_done = false;
+    m_conn->request(type, userName, password, pairs);
+    timespec ts;
+    clock_gettime(CLOCK_REALTIME, &ts);
+    ts.tv_sec += 5;
+    int res = 0;
+    while (!m_done && res == 0)
+        res = pthread_cond_timedwait(&m_cond, &m_mutex, &ts);
+    if (res != 0 || !m_status)
+        throw Conn::Error("Request failed.");
+    return m_result;
 }
 
-bool STG_CLIENT::stop()
+Client::Client(const std::string& address)
+    : m_impl(new Impl(address))
 {
-    return m_proto.stop();
 }
 
-RESULT STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
+Client::~Client()
 {
-    m_writeHeader(type, userName, password);
-    m_writePairBlock(pairs);
-    RESULT result;
-    result.modify = m_readPairBlock();
-    result.reply = m_readPairBlock();
-    return result;
 }
 
-STG_CLIENT* STG_CLIENT::get()
+bool Client::stop()
 {
-    return stgClient;
+    return m_impl->stop();
 }
 
-bool STG_CLIENT::configure(const std::string& address)
+RESULT Client::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
-    if ( stgClient != NULL && stgClient->stop() )
-        delete stgClient;
-    try {
-        stgClient = new STG_CLIENT(address);
-        return true;
-    } catch (const ChannelConfig::Error& ex) {
-        // TODO: Log it
-    }
-    return false;
+    return m_impl->request(type, userName, password, pairs);
 }
 
-void STG_CLIENT::m_writeHeader(TYPE type, const std::string& userName, const std::string& password)
+Client* Client::get()
 {
-    try {
-        m_proto.writeAll<uint64_t>(fromType(type));
-        m_proto.writeAll(userName);
-        m_proto.writeAll(password);
-    } catch (const STG::SGCP::Proto::Error& ex) {
-        throw Error(ex.what());
-    }
-}
-
-void STG_CLIENT::m_writePairBlock(const PAIRS& pairs)
-{
-    try {
-        m_proto.writeAll<uint64_t>(pairs.size());
-        for (size_t i = 0; i < pairs.size(); ++i) {
-            m_proto.writeAll(pairs[i].first);
-            m_proto.writeAll(pairs[i].second);
-        }
-    } catch (const STG::SGCP::Proto::Error& ex) {
-        throw Error(ex.what());
-    }
+    return stgClient;
 }
 
-PAIRS STG_CLIENT::m_readPairBlock()
+bool Client::configure(const std::string& address)
 {
+    if ( stgClient != NULL )
+        return stgClient->configure(address);
     try {
-        size_t count = m_proto.readAll<uint64_t>();
-        if (count == 0)
-            return PAIRS();
-        PAIRS res(count);
-        for (size_t i = 0; i < count; ++i) {
-            res[i].first = m_proto.readAll<std::string>();
-            res[i].second = m_proto.readAll<std::string>();
-        }
-        return res;
-    } catch (const STG::SGCP::Proto::Error& ex) {
-        throw Error(ex.what());
+        stgClient = new Client(address);
+        return true;
+    } catch (const std::exception& ex) {
+        RadLog("Client configuration error: %s.", ex.what());
     }
-}
-
-void STG_CLIENT::m_run()
-{
-    m_proto.connect(m_config.address, m_config.port);
-    m_proto.run();
+    return false;
 }