]> git.stg.codes - stg.git/commitdiff
Implemented backend for rlm_stg.
authorMaxim Mamontov <faust.madf@gmail.com>
Mon, 24 Aug 2015 13:02:51 +0000 (16:02 +0300)
committerMaxim Mamontov <faust.madf@gmail.com>
Mon, 24 Aug 2015 13:02:51 +0000 (16:02 +0300)
projects/stargazer/plugins/other/radius/config.cpp
projects/stargazer/plugins/other/radius/config.h
projects/stargazer/plugins/other/radius/conn.cpp
projects/stargazer/plugins/other/radius/conn.h
projects/stargazer/plugins/other/radius/radius.cpp
projects/stargazer/plugins/other/radius/radius.h

index 8a90567ddec8c49bcc72ce3c016425f03e60e1b1..64ff3707d93e89f90e92c480b0fb5e4b5188c834 100644 (file)
@@ -39,6 +39,7 @@ struct ParserError : public std::runtime_error
           position(pos),
           error(message)
     {}
           position(pos),
           error(message)
     {}
+    virtual ~ParserError() throw() {}
 
     size_t position;
     std::string error;
 
     size_t position;
     std::string error;
@@ -51,12 +52,12 @@ size_t skipSpaces(const std::string& value, size_t start)
     return start;
 }
 
     return start;
 }
 
-size_t checkChar(const std:string& value, size_t start, char ch)
+size_t checkChar(const std::string& value, size_t start, char ch)
 {
     if (start >= value.length())
 {
     if (start >= value.length())
-        throw ParserError(start, "Unexpected end of string. Expected '" + std::string(ch) + "'.");
+        throw ParserError(start, "Unexpected end of string. Expected '" + std::string(1, ch) + "'.");
     if (value[start] != ch)
     if (value[start] != ch)
-        throw ParserError(start, "Expected '" + std::string(ch) + "', got '" + std::string(value[start]) + "'.");
+        throw ParserError(start, "Expected '" + std::string(1, ch) + "', got '" + std::string(1, value[start]) + "'.");
     return start + 1;
 }
 
     return start + 1;
 }
 
@@ -71,7 +72,7 @@ std::pair<size_t, std::string> readString(const std::string& value, size_t start
         else
             throw ParserError(start, "Unexpected whitespace. Expected string.");
     }
         else
             throw ParserError(start, "Unexpected whitespace. Expected string.");
     }
-    return dest;
+    return std::make_pair(start, dest);
 }
 
 Config::Pairs toPairs(const std::vector<std::string>& values)
 }
 
 Config::Pairs toPairs(const std::vector<std::string>& values)
@@ -90,11 +91,11 @@ Config::Pairs toPairs(const std::vector<std::string>& values)
         start = key.first;
         pair.first = key.second;
         start = skipSpaces(value, start);
         start = key.first;
         pair.first = key.second;
         start = skipSpaces(value, start);
-        start = checkChar(value, start, ',')
+        start = checkChar(value, start, ',');
         start = skipSpaces(value, start);
         start = skipSpaces(value, start);
-        std::pair<size_t, std::string> value = readString(value, start);
+        std::pair<size_t, std::string> val = readString(value, start);
         start = key.first;
         start = key.first;
-        pair.second = value.second;
+        pair.second = val.second;
         start = skipSpaces(value, start);
         start = checkChar(value, start, ')');
         if (res.find(pair.first) != res.end())
         start = skipSpaces(value, start);
         start = checkChar(value, start, ')');
         if (res.find(pair.first) != res.end())
@@ -125,7 +126,7 @@ T toInt(const std::vector<std::string>& values)
     if (values.empty())
         return 0;
     T res = 0;
     if (values.empty())
         return 0;
     T res = 0;
-    if (srt2x(values[0], res) == 0)
+    if (str2x(values[0], res) == 0)
         return res;
     return 0;
 }
         return res;
     return 0;
 }
@@ -133,24 +134,24 @@ T toInt(const std::vector<std::string>& values)
 Config::Pairs parseVector(const std::string& paramName, const MODULE_SETTINGS& params)
 {
     for (size_t i = 0; i < params.moduleParams.size(); ++i)
 Config::Pairs parseVector(const std::string& paramName, const MODULE_SETTINGS& params)
 {
     for (size_t i = 0; i < params.moduleParams.size(); ++i)
-        if (params.moduleParams[i].first == paramName)
-            return toPairs(params.moduleParams[i].second);
+        if (params.moduleParams[i].param == paramName)
+            return toPairs(params.moduleParams[i].value);
     return Config::Pairs();
 }
 
 bool parseBool(const std::string& paramName, const MODULE_SETTINGS& params)
 {
     for (size_t i = 0; i < params.moduleParams.size(); ++i)
     return Config::Pairs();
 }
 
 bool parseBool(const std::string& paramName, const MODULE_SETTINGS& params)
 {
     for (size_t i = 0; i < params.moduleParams.size(); ++i)
-        if (params.moduleParams[i].first == paramName)
-            return toBool(params.moduleParams[i].second);
+        if (params.moduleParams[i].param == paramName)
+            return toBool(params.moduleParams[i].value);
     return false;
 }
 
 std::string parseString(const std::string& paramName, const MODULE_SETTINGS& params)
 {
     for (size_t i = 0; i < params.moduleParams.size(); ++i)
     return false;
 }
 
 std::string parseString(const std::string& paramName, const MODULE_SETTINGS& params)
 {
     for (size_t i = 0; i < params.moduleParams.size(); ++i)
-        if (params.moduleParams[i].first == paramName)
-            return toString(params.moduleParams[i].second);
+        if (params.moduleParams[i].param == paramName)
+            return toString(params.moduleParams[i].value);
     return "";
 }
 
     return "";
 }
 
@@ -158,8 +159,8 @@ template <typename T>
 T parseInt(const std::string& paramName, const MODULE_SETTINGS& params)
 {
     for (size_t i = 0; i < params.moduleParams.size(); ++i)
 T parseInt(const std::string& paramName, const MODULE_SETTINGS& params)
 {
     for (size_t i = 0; i < params.moduleParams.size(); ++i)
-        if (params.moduleParams[i].first == paramName)
-            return toInt<T>(params.moduleParams[i].second);
+        if (params.moduleParams[i].param == paramName)
+            return toInt<T>(params.moduleParams[i].value);
     return 0;
 }
 
     return 0;
 }
 
@@ -171,6 +172,7 @@ Config::Config(const MODULE_SETTINGS& settings)
       reply(parseVector("reply", settings)),
       verbose(parseBool("verbose", settings)),
       bindAddress(parseString("bind_address", settings)),
       reply(parseVector("reply", settings)),
       verbose(parseBool("verbose", settings)),
       bindAddress(parseString("bind_address", settings)),
+      portStr(parseString("port", settings)),
       port(parseInt<uint16_t>("port", settings)),
       key(parseString("key", settings))
 {
       port(parseInt<uint16_t>("port", settings)),
       key(parseString("key", settings))
 {
index 8e5055d8c11dd5589efd9575c9808254c729d76d..d6553e1474a045a5501b6f75bb199dda305da1ec 100644 (file)
@@ -34,7 +34,8 @@ namespace STG
 struct Config
 {
     typedef std::map<std::string, std::string> Pairs;
 struct Config
 {
     typedef std::map<std::string, std::string> Pairs;
-    typedef Pairs::value_type Pair;
+    typedef std::pair<std::string, std::string> Pair;
+    enum Type { UNIX, TCP };
 
     Config(const MODULE_SETTINGS& settings);
 
 
     Config(const MODULE_SETTINGS& settings);
 
@@ -44,7 +45,9 @@ struct Config
 
     bool verbose;
 
 
     bool verbose;
 
+    Type connectionType;
     std::string bindAddress;
     std::string bindAddress;
+    std::string portStr;
     uint16_t port;
     std::string key;
 };
     uint16_t port;
     std::string key;
 };
index 223eb2cf5b9ef1967c5b68681bd119d161138ddb..392624a6ef4846dd8d4c424baca068638e4ee4a2 100644 (file)
 
 #include "config.h"
 
 
 #include "config.h"
 
+#include "stg/json_parser.h"
+#include "stg/json_generator.h"
 #include "stg/users.h"
 #include "stg/user.h"
 #include "stg/logger.h"
 #include "stg/common.h"
 
 #include "stg/users.h"
 #include "stg/user.h"
 #include "stg/logger.h"
 #include "stg/common.h"
 
+#include <yajl/yajl_gen.h>
+
+#include <map>
 #include <cstring>
 #include <cerrno>
 
 #include <cstring>
 #include <cerrno>
 
+#include <unistd.h>
+
 using STG::Conn;
 using STG::Conn;
+using STG::Config;
+using STG::JSON::Parser;
+using STG::JSON::PairsParser;
+using STG::JSON::EnumParser;
+using STG::JSON::NodeParser;
+using STG::JSON::Gen;
+using STG::JSON::MapGen;
+using STG::JSON::StringGen;
 
 
-Conn::Conn(USERS& users, PLUGIN_LOGGER & logger, const Config& config)
-    : m_users(users),
-      m_logger(logger),
-      m_config(config)
+namespace
+{
+
+double CONN_TIMEOUT = 5;
+double PING_TIMEOUT = 1;
+
+enum Packet
+{
+    PING,
+    PONG,
+    DATA
+};
+
+enum Stage
+{
+    AUTHORIZE,
+    AUTHENTICATE,
+    PREACCT,
+    ACCOUNTING,
+    POSTAUTH
+};
+
+std::map<std::string, Packet> packetCodes;
+std::map<std::string, Stage> stageCodes;
+
+class PacketParser : public EnumParser<Packet>
+{
+    public:
+        PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
+            : EnumParser(next, packet, packetStr, packetCodes)
+        {
+            if (!packetCodes.empty())
+                return;
+            packetCodes["ping"] = PING;
+            packetCodes["pong"] = PONG;
+            packetCodes["data"] = DATA;
+        }
+};
+
+class StageParser : public EnumParser<Stage>
+{
+    public:
+        StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
+            : EnumParser(next, stage, stageStr, stageCodes)
+        {
+            if (!stageCodes.empty())
+                return;
+            stageCodes["authorize"] = AUTHORIZE;
+            stageCodes["authenticate"] = AUTHENTICATE;
+            stageCodes["preacct"] = PREACCT;
+            stageCodes["accounting"] = ACCOUNTING;
+            stageCodes["postauth"] = POSTAUTH;
+        }
+};
+
+class TopParser : public NodeParser
+{
+    public:
+        TopParser()
+            : m_packetParser(this, m_packet, m_packetStr),
+              m_stageParser(this, m_stage, m_stageStr),
+              m_pairsParser(this, m_data)
+        {}
+
+        virtual NodeParser* parseStartMap() { return this; }
+        virtual NodeParser* parseMapKey(const std::string& value)
+        {
+            std::string key = ToLower(value);
+
+            if (key == "packet")
+                return &m_packetParser;
+            else if (key == "stage")
+                return &m_stageParser;
+            else if (key == "pairs")
+                return &m_pairsParser;
+
+            return this;
+        }
+        virtual NodeParser* parseEndMap() { return this; }
+
+        const std::string& packetStr() const { return m_packetStr; }
+        Packet packet() const { return m_packet; }
+        const std::string& stageStr() const { return m_stageStr; }
+        Stage stage() const { return m_stage; }
+        const Config::Pairs& data() const { return m_data; }
+
+    private:
+        std::string m_packetStr;
+        Packet m_packet;
+        std::string m_stageStr;
+        Stage m_stage;
+        Config::Pairs m_data;
+
+        PacketParser m_packetParser;
+        StageParser m_stageParser;
+        PairsParser m_pairsParser;
+};
+
+class ProtoParser : public Parser
+{
+    public:
+        ProtoParser() : Parser( &m_topParser ) {}
+
+        const std::string& packetStr() const { return m_topParser.packetStr(); }
+        Packet packet() const { return m_topParser.packet(); }
+        const std::string& stageStr() const { return m_topParser.stageStr(); }
+        Stage stage() const { return m_topParser.stage(); }
+        const Config::Pairs& data() const { return m_topParser.data(); }
+
+    private:
+        TopParser m_topParser;
+};
+
+class PacketGen : public Gen
+{
+    public:
+        PacketGen(const std::string& type)
+            : m_type(type)
+        {
+            m_gen.add("packet", m_type);
+        }
+        void run(yajl_gen_t* handle) const
+        {
+            m_gen.run(handle);
+        }
+        PacketGen& add(const std::string& key, const std::string& value)
+        {
+            m_gen.add(key, new StringGen(value));
+            return *this;
+        }
+        PacketGen& add(const std::string& key, MapGen* map)
+        {
+            m_gen.add(key, map);
+            return *this;
+        }
+    private:
+        MapGen m_gen;
+        StringGen m_type;
+};
+
+}
+
+class Conn::Impl
+{
+    public:
+        Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
+        ~Impl();
+
+        int sock() const { return m_sock; }
+
+        bool read();
+        bool tick();
+
+        bool isOk() const { return m_ok; }
+
+    private:
+        USERS& m_users;
+        PLUGIN_LOGGER& m_logger;
+        const Config& m_config;
+        int m_sock;
+        std::string m_remote;
+        bool m_ok;
+        time_t m_lastPing;
+        time_t m_lastActivity;
+        ProtoParser m_parser;
+
+        bool process();
+        bool processPing();
+        bool processPong();
+        bool processData();
+        bool answer(const USER& user);
+        bool answerNo();
+        bool sendPing();
+        bool sendPong();
+
+        static bool write(void* data, const char* buf, size_t size);
+};
+
+Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
+    : m_impl(new Impl(users, logger, config, fd, remote))
 {
 }
 
 {
 }
 
@@ -43,21 +234,208 @@ Conn::~Conn()
 {
 }
 
 {
 }
 
+int Conn::sock() const
+{
+    return m_impl->sock();
+}
+
 bool Conn::read()
 {
 bool Conn::read()
 {
-    ssize_t res = read(m_sock, m_buffer, m_bufferSize);
+    return m_impl->read();
+}
+
+bool Conn::tick()
+{
+    return m_impl->tick();
+}
+
+bool Conn::isOk() const
+{
+    return m_impl->isOk();
+}
+
+Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
+    : m_users(users),
+      m_logger(logger),
+      m_config(config),
+      m_sock(fd),
+      m_remote(remote),
+      m_ok(true),
+      m_lastPing(time(NULL)),
+      m_lastActivity(m_lastPing)
+{
+}
+
+Conn::Impl::~Impl()
+{
+    close(m_sock);
+}
+
+bool Conn::Impl::read()
+{
+    static std::vector<char> buffer(1024);
+    ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
     if (res < 0)
     {
     if (res < 0)
     {
-        m_state = ERROR;
-        Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Reason: '" + strerror(errno) + "'");
+        m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
+        m_ok = false;
         return false;
     }
         return false;
     }
-    if (res == 0 && m_state != DATA) // EOF is ok for data.
+    m_lastActivity = time(NULL);
+    if (res == 0)
+    {
+        if (!m_parser.done())
+        {
+            m_ok = false;
+            m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
+            return false;
+        }
+        return process();
+    }
+    return m_parser.append(buffer.data(), res);
+}
+
+bool Conn::Impl::tick()
+{
+    time_t now = time(NULL);
+    if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
     {
     {
-        m_state = ERROR;
-        Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Unexpected EOF.");
+        m_logger("Connection to " + m_remote + " timed out.");
+        m_ok = false;
         return false;
     }
         return false;
     }
-    m_bufferSize -= res;
-    return HandleBuffer(res);
+    if (difftime(now, m_lastPing) > PING_TIMEOUT)
+        sendPing();
+    return true;
+}
+
+bool Conn::Impl::process()
+{
+    switch (m_parser.packet())
+    {
+        case PING:
+            return processPing();
+        case PONG:
+            return processPong();
+        case DATA:
+            return processData();
+    }
+    m_logger("Received invalid packet type: " + m_parser.packetStr());
+    return false;
+}
+
+bool Conn::Impl::processPing()
+{
+    return sendPong();
+}
+
+bool Conn::Impl::processPong()
+{
+    m_lastActivity = time(NULL);
+    return true;
+}
+
+bool Conn::Impl::processData()
+{
+    int handle = m_users.OpenSearch();
+
+    USER_PTR user = NULL;
+    bool match = true;
+    while (m_users.SearchNext(handle, &user))
+    {
+        if (user == NULL)
+            continue;
+
+        match = true;
+        for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it)
+        {
+            Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
+            if (pos == m_parser.data().end())
+            {
+                match = false;
+                break;
+            }
+            if (user->GetParamValue(it->second) != pos->second)
+            {
+                match = false;
+                break;
+            }
+        }
+        if (!match)
+            continue;
+        answer(*user);
+        break;
+    }
+
+    if (!match)
+        answerNo();
+
+    m_users.CloseSearch(handle);
+
+    return true;
+}
+
+bool Conn::Impl::answer(const USER& user)
+{
+    boost::scoped_ptr<MapGen> reply(new MapGen);
+    for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it)
+        reply->add(it->first, new StringGen(user.GetParamValue(it->second)));
+
+    boost::scoped_ptr<MapGen> modify(new MapGen);
+    for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it)
+        modify->add(it->first, new StringGen(user.GetParamValue(it->second)));
+
+    PacketGen gen("data");
+    gen.add("result", "ok")
+       .add("reply", reply.get())
+       .add("modify", modify.get());
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::answerNo()
+{
+    PacketGen gen("data");
+    gen.add("result", "ok");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::sendPing()
+{
+    PacketGen gen("ping");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::sendPong()
+{
+    PacketGen gen("pong");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::write(void* data, const char* buf, size_t size)
+{
+    Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
+    while (size > 0)
+    {
+        ssize_t res = ::write(conn.m_sock, buf, size);
+        if (res < 0)
+        {
+            conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
+            conn.m_ok = false;
+            return false;
+        }
+        size -= res;
+    }
+    return true;
 }
 }
index 31ebe1dd06322d1a1bf0041fa77eb0df3b077378..38f2db2ff8aaff28f22c1849030c3942ec70965c 100644 (file)
 #ifndef __STG_SGCONFIG_CONN_H__
 #define __STG_SGCONFIG_CONN_H__
 
 #ifndef __STG_SGCONFIG_CONN_H__
 #define __STG_SGCONFIG_CONN_H__
 
-#include "stg/os_int.h"
+#include <boost/scoped_ptr.hpp>
 
 
-#include <stdexcept>
 #include <string>
 
 #include <string>
 
+class USER;
 class USERS;
 class USERS;
+class PLUGIN_LOGGER;
 
 namespace STG
 {
 
 
 namespace STG
 {
 
+class Config;
+
 class Conn
 {
     public:
 class Conn
 {
     public:
-        struct Error : public std::runtime_error
-        {
-            Error(const std::string& message) : runtime_error(message.c_str()) {}
-        };
-
-        Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config);
+        Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
         ~Conn();
 
         ~Conn();
 
-        int sock() const { return m_sock; }
+        int sock() const;
 
         bool read();
 
         bool read();
+        bool tick();
+
+        bool isOk() const;
 
     private:
 
     private:
-        USERS& m_users;
-        PLUGIN_LOGGER& m_logger;
-        const Config& m_config;
+        class Impl;
+        boost::scoped_ptr<Impl> m_impl;
 };
 
 }
 };
 
 }
index fe989b288d42bc3c954a075fa1183724805be367..daa8634caa58ba8f3af9f6b85a8a724ee0c65b8c 100644 (file)
 #include "stg/users.h"
 #include "stg/plugin_creator.h"
 
 #include "stg/users.h"
 #include "stg/plugin_creator.h"
 
+#include <algorithm>
+#include <stdexcept>
 #include <csignal>
 #include <csignal>
-#include <cerror>
+#include <cerrno>
 #include <cstring>
 
 #include <cstring>
 
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/un.h> // UNIX
+#include <netinet/in.h> // IP
+#include <netinet/tcp.h> // TCP
+#include <netdb.h>
+
+using STG::Config;
+using STG::Conn;
+
 namespace
 {
 
 namespace
 {
 
@@ -41,10 +53,12 @@ extern "C" PLUGIN * GetPlugin()
 }
 
 RADIUS::RADIUS()
 }
 
 RADIUS::RADIUS()
-    : m_running(false),
+    : m_config(m_settings),
+      m_running(false),
       m_stopped(true),
       m_users(NULL),
       m_store(NULL),
       m_stopped(true),
       m_users(NULL),
       m_store(NULL),
+      m_listenSocket(0),
       m_logger(GetPluginLogger(GetStgLogger(), "radius"))
 {
 }
       m_logger(GetPluginLogger(GetStgLogger(), "radius"))
 {
 }
@@ -53,7 +67,7 @@ int RADIUS::ParseSettings()
 {
     try {
         m_config = STG::Config(m_settings);
 {
     try {
         m_config = STG::Config(m_settings);
-        return 0;
+        return reconnect() ? 0 : -1;
     } catch (const std::runtime_error& ex) {
         m_logger("Failed to parse settings. %s", ex.what());
         return -1;
     } catch (const std::runtime_error& ex) {
         m_logger("Failed to parse settings. %s", ex.what());
         return -1;
@@ -107,6 +121,112 @@ void* RADIUS::run(void* d)
     return NULL;
 }
 
     return NULL;
 }
 
+bool RADIUS::reconnect()
+{
+    if (!m_conns.empty())
+    {
+        std::deque<STG::Conn *>::const_iterator it;
+        for (it = m_conns.begin(); it != m_conns.end(); ++it)
+            delete(*it);
+        m_conns.clear();
+    }
+    if (m_listenSocket != 0)
+    {
+        shutdown(m_listenSocket, SHUT_RDWR);
+        close(m_listenSocket);
+        if (m_config.connectionType == Config::UNIX)
+            unlink(m_config.bindAddress.c_str());
+    }
+    if (m_config.connectionType == Config::UNIX)
+        m_listenSocket = createUNIX();
+    else
+        m_listenSocket = createTCP();
+    if (m_listenSocket == 0)
+        return false;
+    if (listen(m_listenSocket, 100) == -1)
+    {
+        m_error = std::string("Error starting to listen socket: ") + strerror(errno);
+        m_logger(m_error);
+        return false;
+    }
+    return true;
+}
+
+int RADIUS::createUNIX() const
+{
+    int fd = socket(AF_UNIX, SOCK_STREAM, 0);
+    if (fd == -1)
+    {
+        m_error = std::string("Error creating UNIX socket: ") + strerror(errno);
+        m_logger(m_error);
+        return 0;
+    }
+    struct sockaddr_un addr;
+    memset(&addr, 0, sizeof(addr));
+    addr.sun_family = AF_UNIX;
+    strncpy(addr.sun_path, m_config.bindAddress.c_str(), m_config.bindAddress.length());
+    if (bind(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
+    {
+        shutdown(fd, SHUT_RDWR);
+        close(fd);
+        m_error = std::string("Error binding UNIX socket: ") + strerror(errno);
+        m_logger(m_error);
+        return 0;
+    }
+    return fd;
+}
+
+int RADIUS::createTCP() const
+{
+    addrinfo hints;
+    memset(&hints, 0, sizeof(addrinfo));
+
+    hints.ai_family = AF_INET;       /* Allow IPv4 */
+    hints.ai_socktype = SOCK_STREAM; /* Stream socket */
+    hints.ai_flags = AI_PASSIVE;     /* For wildcard IP address */
+    hints.ai_protocol = 0;           /* Any protocol */
+    hints.ai_canonname = NULL;
+    hints.ai_addr = NULL;
+    hints.ai_next = NULL;
+
+    addrinfo* ais = NULL;
+    int res = getaddrinfo(m_config.bindAddress.c_str(), m_config.portStr.c_str(), &hints, &ais);
+    if (res != 0)
+    {
+        m_error = "Error resolvin address '" + m_config.bindAddress + "': " + gai_strerror(res);
+        m_logger(m_error);
+        return 0;
+    }
+
+    for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
+    {
+        int fd = socket(AF_INET, SOCK_STREAM, 0);
+        if (fd == -1)
+        {
+            m_error = std::string("Error creating TCP socket: ") + strerror(errno);
+            m_logger(m_error);
+            freeaddrinfo(ais);
+            return 0;
+        }
+        if (bind(fd, ai->ai_addr, ai->ai_addrlen) == -1)
+        {
+            shutdown(fd, SHUT_RDWR);
+            close(fd);
+            m_error = std::string("Error binding TCP socket: ") + strerror(errno);
+            m_logger(m_error);
+            continue;
+        }
+        freeaddrinfo(ais);
+        return fd;
+    }
+
+    m_error = "Failed to resolve '" + m_config.bindAddress;
+    m_logger(m_error);
+
+    freeaddrinfo(ais);
+    return 0;
+}
+
 void RADIUS::runImpl()
 {
     m_running = true;
 void RADIUS::runImpl()
 {
     m_running = true;
@@ -133,6 +253,11 @@ void RADIUS::runImpl()
 
         if (res > 0)
             handleEvents(fds);
 
         if (res > 0)
             handleEvents(fds);
+        else
+        {
+            for (std::deque<Conn*>::iterator it = m_conns.begin(); it != m_conns.end(); ++it)
+                (*it)->tick();
+        }
 
         cleanupConns();
     }
 
         cleanupConns();
     }
@@ -163,7 +288,7 @@ void RADIUS::cleanupConns()
 {
     std::deque<STG::Conn *>::iterator pos;
     for (pos = m_conns.begin(); pos != m_conns.end(); ++pos)
 {
     std::deque<STG::Conn *>::iterator pos;
     for (pos = m_conns.begin(); pos != m_conns.end(); ++pos)
-        if (((*pos)->isDone() && !(*pos)->isKeepAlive()) || !(*pos)->isOk()) {
+        if (!(*pos)->isOk()) {
             delete *pos;
             *pos = NULL;
         }
             delete *pos;
             *pos = NULL;
         }
@@ -182,5 +307,46 @@ void RADIUS::handleEvents(const fd_set & fds)
         for (it = m_conns.begin(); it != m_conns.end(); ++it)
             if (FD_ISSET((*it)->sock(), &fds))
                 (*it)->read();
         for (it = m_conns.begin(); it != m_conns.end(); ++it)
             if (FD_ISSET((*it)->sock(), &fds))
                 (*it)->read();
+            else
+                (*it)->tick();
+    }
+}
+
+void RADIUS::acceptConnection()
+{
+    if (m_config.connectionType == Config::UNIX)
+        acceptUNIX();
+    else
+        acceptTCP();
+}
+
+void RADIUS::acceptUNIX()
+{
+    struct sockaddr_un addr;
+    memset(&addr, 0, sizeof(addr));
+    socklen_t size = sizeof(addr);
+    int res = accept(m_listenSocket, reinterpret_cast<sockaddr*>(&addr), &size);
+    if (res == -1)
+    {
+        m_error = std::string("Failed to accept UNIX connection: ") + strerror(errno);
+        m_logger(m_error);
+        return;
+    }
+    m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, addr.sun_path));
+}
+
+void RADIUS::acceptTCP()
+{
+    struct sockaddr_in addr;
+    memset(&addr, 0, sizeof(addr));
+    socklen_t size = sizeof(addr);
+    int res = accept(m_listenSocket, reinterpret_cast<sockaddr*>(&addr), &size);
+    if (res == -1)
+    {
+        m_error = std::string("Failed to accept TCP connection: ") + strerror(errno);
+        m_logger(m_error);
+        return;
     }
     }
+    std::string remote = inet_ntostring(addr.sin_addr.s_addr) + ":" + x2str(ntohs(addr.sin_port));
+    m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, remote));
 }
 }
index ff8ecf5e08410839cf5ab08fad3396bf5110afd7..2573cf9ec5809e87b833b5ef75daf70dd88eef48 100644 (file)
 #include "stg/module_settings.h"
 #include "stg/logger.h"
 
 #include "stg/module_settings.h"
 #include "stg/logger.h"
 
+#include "config.h"
+#include "conn.h"
+
 #include <string>
 #include <string>
+#include <deque>
 
 #include <pthread.h>
 #include <unistd.h>
 
 #include <pthread.h>
 #include <unistd.h>
@@ -43,10 +47,10 @@ public:
     RADIUS();
     virtual ~RADIUS() {}
 
     RADIUS();
     virtual ~RADIUS() {}
 
-    void SetUsers(USERS* u) { users = u; }
-    void SetStore(STORE* s) { store = s; }
+    void SetUsers(USERS* u) { m_users = u; }
+    void SetStore(STORE* s) { m_store = s; }
     void SetStgSettings(const SETTINGS*) {}
     void SetStgSettings(const SETTINGS*) {}
-    void SetSettings(const MODULE_SETTINGS& s) { settings = s; }
+    void SetSettings(const MODULE_SETTINGS& s) { m_settings = s; }
     int ParseSettings();
 
     int Start();
     int ParseSettings();
 
     int Start();
@@ -67,12 +71,17 @@ private:
 
     static void* run(void*);
 
 
     static void* run(void*);
 
-    void rumImpl();
+    bool reconnect();
+    int createUNIX() const;
+    int createTCP() const;
+    void runImpl();
     int maxFD() const;
     void buildFDSet(fd_set & fds) const;
     void cleanupConns();
     void handleEvents(const fd_set & fds);
     void acceptConnection();
     int maxFD() const;
     void buildFDSet(fd_set & fds) const;
     void cleanupConns();
     void handleEvents(const fd_set & fds);
     void acceptConnection();
+    void acceptUNIX();
+    void acceptTCP();
 
     mutable std::string m_error;
     STG::Config m_config;
 
     mutable std::string m_error;
     STG::Config m_config;
@@ -85,8 +94,10 @@ private:
     USERS* m_users;
     const STORE* m_store;
 
     USERS* m_users;
     const STORE* m_store;
 
+    int m_listenSocket;
+    std::deque<STG::Conn*> m_conns;
+
     pthread_t m_thread;
     pthread_t m_thread;
-    pthread_mutex_t m_mutex;
 
     PLUGIN_LOGGER m_logger;
 };
 
     PLUGIN_LOGGER m_logger;
 };