]> git.stg.codes - stg.git/commitdiff
Merge branch 'stg-2.409' into stg-2.409-radius stg-2.409-radius
authorMaxim Mamontov <madf@madf.info>
Tue, 20 Mar 2018 09:58:59 +0000 (11:58 +0200)
committerMaxim Mamontov <madf@madf.info>
Tue, 20 Mar 2018 09:58:59 +0000 (11:58 +0200)
37 files changed:
include/stg/locker.h
include/stg/module_settings.h
include/stg/ref.h [new file with mode: 0644]
projects/rlm_stg/Makefile
projects/rlm_stg/build
projects/rlm_stg/conn.cpp [new file with mode: 0644]
projects/rlm_stg/conn.h [new file with mode: 0644]
projects/rlm_stg/event.h [deleted file]
projects/rlm_stg/iface.cpp
projects/rlm_stg/iface.h
projects/rlm_stg/radlog.c [new file with mode: 0644]
projects/rlm_stg/radlog.h [new file with mode: 0644]
projects/rlm_stg/rlm_stg.c
projects/rlm_stg/stg_client.cpp
projects/rlm_stg/stg_client.h
projects/rlm_stg/stgpair.h
projects/rlm_stg/types.h [new file with mode: 0644]
projects/stargazer/build
projects/stargazer/plugins/other/radius/Makefile
projects/stargazer/plugins/other/radius/config.cpp [new file with mode: 0644]
projects/stargazer/plugins/other/radius/config.h [new file with mode: 0644]
projects/stargazer/plugins/other/radius/conn.cpp [new file with mode: 0644]
projects/stargazer/plugins/other/radius/conn.h [new file with mode: 0644]
projects/stargazer/plugins/other/radius/radius.cpp
projects/stargazer/plugins/other/radius/radius.h
projects/stargazer/settings_impl.cpp
projects/stargazer/settings_impl.h
stglibs/common.lib/Makefile
stglibs/common.lib/blockio.cpp [new file with mode: 0644]
stglibs/common.lib/common.cpp
stglibs/common.lib/include/stg/blockio.h [new file with mode: 0644]
stglibs/common.lib/include/stg/common.h
stglibs/json.lib/Makefile [new file with mode: 0644]
stglibs/json.lib/generator.cpp [new file with mode: 0644]
stglibs/json.lib/include/stg/json_generator.h [new file with mode: 0644]
stglibs/json.lib/include/stg/json_parser.h [new file with mode: 0644]
stglibs/json.lib/parser.cpp [new file with mode: 0644]

index 2a395f37698b6c3cfa9e8814c850e3409c572071..16b0323f3876c78941fe41c2b7faa310741eb5bb 100644 (file)
 class STG_LOCKER
 {
 public:
+    explicit STG_LOCKER(pthread_mutex_t& m)
+        : mutex(&m)
+        {
+        pthread_mutex_lock(mutex);
+        }
     explicit STG_LOCKER(pthread_mutex_t * m)
         : mutex(m)
         {
index 669cc2254fab0fb59c39d5dd53df571be89f1db8..b94de10c25d2f56b71cabbaf5beef54e4a0f4fcb 100644 (file)
 //-----------------------------------------------------------------------------
 struct PARAM_VALUE
 {
+    PARAM_VALUE() {}
+    PARAM_VALUE(const std::string& p, const std::vector<std::string>& vs)
+        : param(p),
+          value(vs)
+    {}
+    PARAM_VALUE(const std::string& p, const std::vector<std::string>& vs, const std::vector<PARAM_VALUE>& ss)
+        : param(p),
+          value(vs),
+          sections(ss)
+    {}
     bool operator==(const PARAM_VALUE & rhs) const
         { return !strcasecmp(param.c_str(), rhs.param.c_str()); }
 
@@ -22,10 +32,16 @@ struct PARAM_VALUE
 
     std::string param;
     std::vector<std::string> value;
+    std::vector<PARAM_VALUE> sections;
 };
 //-----------------------------------------------------------------------------
 struct MODULE_SETTINGS
 {
+    MODULE_SETTINGS() {}
+    MODULE_SETTINGS(const std::string& name, const std::vector<PARAM_VALUE>& params)
+        : moduleName(name),
+          moduleParams(params)
+    {}
     bool operator==(const MODULE_SETTINGS & rhs) const
         { return !strcasecmp(moduleName.c_str(), rhs.moduleName.c_str()); }
 
diff --git a/include/stg/ref.h b/include/stg/ref.h
new file mode 100644 (file)
index 0000000..e1bdb54
--- /dev/null
@@ -0,0 +1,54 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+namespace STG
+{
+
+// The implementation is derived from
+// http://en.cppreference.com/w/cpp/utility/functional/reference_wrapper
+// and
+// http://en.cppreference.com/w/cpp/memory/addressof
+
+template< class T >
+T* AddressOf(T& arg)
+{
+    return reinterpret_cast<T*>(&const_cast<char&>(reinterpret_cast<const volatile char&>(arg)));
+}
+
+template <class T>
+class RefWrapper {
+    public:
+          typedef T type;
+          RefWrapper(T& ref) throw() : m_ptr(AddressOf(ref)) {}
+          RefWrapper(const RefWrapper& rhs) : m_ptr(rhs.m_ptr) {}
+          RefWrapper& operator=(const RefWrapper& rhs) throw() { m_ptr = rhs.m_ptr; }
+          operator T& () const throw() { return *m_ptr; }
+          T& get() const throw() { return *m_ptr; }
+
+          void operator()() { (*m_ptr)(); }
+          template <typename A>
+          void operator()(A a) { (*m_ptr)(a); }
+          template <typename A1, typename A2>
+          void operator()(A1 a1, A2 a2) { (*m_ptr)(a1, a2); }
+    private:
+          T* m_ptr;
+};
+
+} // namespace STG
index 05c43d95135b559866886ac904108ce3fc074cfc..27f486b410f443ac68bf6047b9868864bb7d1651 100644 (file)
@@ -8,12 +8,15 @@ LIB_NAME = rlm_stg
 
 PROG = $(LIB_NAME).so
 
-SRCS = ./rlm_stg.c \
-       ./iface.cpp \
-       ./stg_client.cpp
+SRCS = rlm_stg.c \
+       iface.cpp \
+       stg_client.cpp \
+       conn.cpp \
+       radlog.c
 
 STGLIBS = crypto \
-          common
+          common \
+          json
 
 STGLIBS_INCS = $(addprefix -I ../../stglibs/,$(addsuffix .lib/include,$(STGLIBS)))
 STGLIBS_LIBS = $(addprefix -L ../../stglibs/,$(addsuffix .lib,$(STGLIBS)))
@@ -21,16 +24,19 @@ STGLIBS_LIBS = $(addprefix -L ../../stglibs/,$(addsuffix .lib,$(STGLIBS)))
 LIBS += $(addprefix -lstg,$(STGLIBS)) $(LIB_THREAD) $(LIBICONV)
 
 ifeq ($(OS),linux)
-LIBS += -ldl
+LIBS += -ldl \
+       -lyajl
 else
 LIBS += -lintl \
-        -lc
+        -lc \
+       -lyajl
 endif
 
 SEARCH_DIRS = -I ../../include
 
 OBJS = $(notdir $(patsubst %.cpp, %.o, $(patsubst %.c, %.o, $(SRCS))))
 
+CFLAGS += -fPIC $(DEFS) $(STGLIBS_INCS) $(SEARCH_DIRS)
 CXXFLAGS += -fPIC $(DEFS) $(STGLIBS_INCS) $(SEARCH_DIRS)
 CFLAGS += $(DEFS) $(STGLIBS_INCS) $(SEARCH_DIRS)
 LDFLAGS += -shared -Wl,-rpath,$(PREFIX)/usr/lib/stg -Wl,-E $(STGLIBS_LIBS)
index 179d526188d9aef6e1370ca024d04d01063a27e9..8d63ae64e123b58de0c4370b116c46047c736ee0 100755 (executable)
@@ -20,6 +20,7 @@ if [ "$1" = "debug" ]
 then
    DEFS="$DEFS -DDEBUG"
    MAKEOPTS="$MAKEOPTS -j1"
+   CFLAGS="$CFLAGS -ggdb3 -W -Wall -Wextra"
    CXXFLAGS="$CXXFLAGS -ggdb3 -W -Wall -Wextra"
    DEBUG="yes"
 else
@@ -27,6 +28,7 @@ else
    DEBUG="no"
 fi
 
+CFLAGS="$CFLAGS -I/usr/local/include"
 CXXFLAGS="$CXXFLAGS -I/usr/local/include"
 LDFLAGS="$LDFLAGS -L/usr/local/lib"
 
@@ -61,7 +63,7 @@ printf "########################################################################
 printf "       Building rlm_stg for $sys $release\n"
 printf "#############################################################################\n"
 
-STG_LIBS="crypto.lib common.lib"
+STG_LIBS="crypto.lib common.lib json.lib"
 
 if [ "$OS" = "linux" ]
 then
@@ -150,6 +152,68 @@ else
 fi
 rm -f fake
 
+printf "Checking for -lyajl... "
+pkg-config --version > /dev/null 2> /dev/null
+if [ "$?" = "0" ]
+then
+    pkg-config --atleast-version=2.0.0 yajl
+    if [ "$?" != "0" ]
+    then
+        CHECK_YAJL=no
+        printf "no\n"
+        exit;
+    else
+        CHECK_YAJL=yes
+        printf `pkg-config --modversion yajl`"\n"
+    fi
+else
+    printf "#include <stdio.h>\n" > build_check.c
+    printf "#include <yajl/yajl_version.h>\n" >> build_check.c
+    printf "int main() { printf(\"%%d\", yajl_version()); return 0; }\n" >> build_check.c
+    $CC $CFLAGS $LDFLAGS build_check.c -lyajl -o fake > /dev/null 2> /dev/null
+    if [ $? != 0 ]
+    then
+        CHECK_YAJL=no
+        printf "no\n"
+        exit;
+    else
+        YAJL_VERSION=`./fake`
+        if [ $YAJL_VERSION -ge 20000 ]
+        then
+            CHECK_YAJL=yes
+            printf "${YAJL_VERSION}\n"
+        else
+            CHECK_YAJL=no
+            printf "no. Need at least version 2.0.0, existing version is ${YAJL_VERSION}\n"
+            exit;
+        fi
+    fi
+    rm -f fake
+fi
+
+printf "Checking for boost::scoped_ptr... "
+printf "#include <boost/scoped_ptr.hpp>\nint main() { boost::scoped_ptr<int> test(new int(1)); return 0; }\n" > build_check.cpp
+$CXX $CXXFLAGS $LDFLAGS build_check.cpp -o fake # > /dev/null 2> /dev/null
+if [ $? != 0 ]
+then
+    CHECK_BOOST_SCOPED_PTR=no
+    printf "no\n"
+    exit;
+else
+    CHECK_BOOST_SCOPED_PTR=yes
+    printf "yes\n"
+fi
+rm -f fake
+
+rm -f build_check.c
+rm -f build_check.cpp
+
+if [ "$CHECK_YAJL" = "yes" -a "$CHECK_BOOST_SCOPED_PTR" = "yes" ]
+then
+    STG_LIBS="$STG_LIBS
+              json.lib"
+fi
+
 printf "OS=$OS\n" > $CONFFILE
 printf "STG_TIME=yes\n" >> $CONFFILE
 printf "DEBUG=$DEBUG\n" >> $CONFFILE
@@ -165,6 +229,8 @@ do
     printf "$lib " >> $CONFFILE
 done
 printf "\n" >> $CONFFILE
+printf "CC=$CC\n" >> $CONFFILE
+printf "CXX=$CXX\n" >> $CONFFILE
 printf "CXXFLAGS=$CXXFLAGS\n" >> $CONFFILE
 printf "CFLAGS=$CFLAGS\n" >> $CONFFILE
 printf "LDFLAGS=$LDFLAGS\n" >> $CONFFILE
diff --git a/projects/rlm_stg/conn.cpp b/projects/rlm_stg/conn.cpp
new file mode 100644 (file)
index 0000000..3589a90
--- /dev/null
@@ -0,0 +1,686 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#include "conn.h"
+
+#include "radlog.h"
+#include "stgpair.h"
+
+#include "stg/json_parser.h"
+#include "stg/json_generator.h"
+#include "stg/locker.h"
+
+#include <cerrno>
+#include <cstring>
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/un.h> // UNIX
+#include <netinet/in.h> // IP
+#include <netinet/tcp.h> // TCP
+#include <netdb.h>
+
+namespace RLM = STG::RLM;
+
+using RLM::Conn;
+using STG::JSON::Parser;
+using STG::JSON::PairsParser;
+using STG::JSON::EnumParser;
+using STG::JSON::NodeParser;
+using STG::JSON::Gen;
+using STG::JSON::MapGen;
+using STG::JSON::StringGen;
+
+namespace
+{
+
+double CONN_TIMEOUT = 60;
+double PING_TIMEOUT = 10;
+
+struct ChannelConfig {
+    struct Error : std::runtime_error {
+        explicit Error(const std::string& message) : runtime_error(message) {}
+    };
+
+    explicit ChannelConfig(std::string address);
+
+    std::string transport;
+    std::string key;
+    std::string address;
+    std::string portStr;
+    uint16_t port;
+};
+
+std::string toStage(RLM::REQUEST_TYPE type)
+{
+    switch (type)
+    {
+        case RLM::AUTHORIZE: return "authorize";
+        case RLM::AUTHENTICATE: return "authenticate";
+        case RLM::POST_AUTH: return "postauth";
+        case RLM::PRE_ACCT: return "preacct";
+        case RLM::ACCOUNT: return "accounting";
+    }
+    return "";
+}
+
+enum Packet
+{
+    PING,
+    PONG,
+    DATA
+};
+
+std::map<std::string, Packet> packetCodes;
+std::map<std::string, bool> resultCodes;
+std::map<std::string, int> returnCodes;
+
+class PacketParser : public EnumParser<Packet>
+{
+    public:
+        PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
+            : EnumParser(next, packet, packetStr, packetCodes)
+        {
+            if (!packetCodes.empty())
+                return;
+            packetCodes["ping"] = PING;
+            packetCodes["pong"] = PONG;
+            packetCodes["data"] = DATA;
+        }
+};
+
+class ResultParser : public EnumParser<bool>
+{
+    public:
+        ResultParser(NodeParser* next, bool& result, std::string& resultStr)
+            : EnumParser(next, result, resultStr, resultCodes)
+        {
+            if (!resultCodes.empty())
+                return;
+            resultCodes["no"] = false;
+            resultCodes["ok"] = true;
+        }
+};
+
+class ReturnCodeParser : public EnumParser<int>
+{
+    public:
+        ReturnCodeParser(NodeParser* next, int& returnCode, std::string& returnCodeStr)
+            : EnumParser(next, returnCode, returnCodeStr, returnCodes)
+        {
+            if (!returnCodes.empty())
+                return;
+            returnCodes["reject"]   = STG_REJECT;
+            returnCodes["fail"]     = STG_FAIL;
+            returnCodes["ok"]       = STG_OK;
+            returnCodes["handled"]  = STG_HANDLED;
+            returnCodes["invalid"]  = STG_INVALID;
+            returnCodes["userlock"] = STG_USERLOCK;
+            returnCodes["notfound"] = STG_NOTFOUND;
+            returnCodes["noop"]     = STG_NOOP;
+            returnCodes["updated"]  = STG_UPDATED;
+        }
+};
+
+class TopParser : public NodeParser
+{
+    public:
+        typedef void (*Callback) (void* /*data*/);
+        TopParser(Callback callback, void* data)
+            : m_packet(PING),
+              m_result(false),
+              m_returnCode(STG_REJECT),
+              m_packetParser(this, m_packet, m_packetStr),
+              m_resultParser(this, m_result, m_resultStr),
+              m_returnCodeParser(this, m_returnCode, m_returnCodeStr),
+              m_replyParser(this, m_reply),
+              m_modifyParser(this, m_modify),
+              m_callback(callback), m_data(data)
+        {}
+
+        virtual NodeParser* parseStartMap() { return this; }
+        virtual NodeParser* parseMapKey(const std::string& value)
+        {
+            std::string key = ToLower(value);
+
+            if (key == "packet")
+                return &m_packetParser;
+            else if (key == "result")
+                return &m_resultParser;
+            else if (key == "reply")
+                return &m_replyParser;
+            else if (key == "modify")
+                return &m_modifyParser;
+            else if (key == "return_code")
+                return &m_returnCodeParser;
+
+            return this;
+        }
+        virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
+
+        const std::string& packetStr() const { return m_packetStr; }
+        Packet packet() const { return m_packet; }
+        const std::string& resultStr() const { return m_resultStr; }
+        bool result() const { return m_result; }
+        const std::string& returnCodeStr() const { return m_returnCodeStr; }
+        int returnCode() const { return m_returnCode; }
+        const PairsParser::Pairs& reply() const { return m_reply; }
+        const PairsParser::Pairs& modify() const { return m_modify; }
+
+    private:
+        std::string m_packetStr;
+        Packet m_packet;
+        std::string m_resultStr;
+        bool m_result;
+        std::string m_returnCodeStr;
+        int m_returnCode;
+        PairsParser::Pairs m_reply;
+        PairsParser::Pairs m_modify;
+
+        PacketParser m_packetParser;
+        ResultParser m_resultParser;
+        ReturnCodeParser m_returnCodeParser;
+        PairsParser m_replyParser;
+        PairsParser m_modifyParser;
+
+        Callback m_callback;
+        void* m_data;
+};
+
+class ProtoParser : public Parser
+{
+    public:
+        ProtoParser(TopParser::Callback callback, void* data)
+            : Parser( &m_topParser ),
+              m_topParser(callback, data)
+        {}
+
+        const std::string& packetStr() const { return m_topParser.packetStr(); }
+        Packet packet() const { return m_topParser.packet(); }
+        const std::string& resultStr() const { return m_topParser.resultStr(); }
+        bool result() const { return m_topParser.result(); }
+        const std::string& returnCodeStr() const { return m_topParser.returnCodeStr(); }
+        int returnCode() const { return m_topParser.returnCode(); }
+        const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
+        const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
+
+    private:
+        TopParser m_topParser;
+};
+
+class PacketGen : public Gen
+{
+    public:
+        explicit PacketGen(const std::string& type)
+            : m_type(type)
+        {
+            m_gen.add("packet", m_type);
+        }
+        void run(yajl_gen_t* handle) const
+        {
+            m_gen.run(handle);
+        }
+        PacketGen& add(const std::string& key, const std::string& value)
+        {
+            m_gen.add(key, new StringGen(value));
+            return *this;
+        }
+        PacketGen& add(const std::string& key, MapGen& map)
+        {
+            m_gen.add(key, map);
+            return *this;
+        }
+    private:
+        MapGen m_gen;
+        StringGen m_type;
+};
+
+}
+
+class Conn::Impl
+{
+public:
+    Impl(const std::string& address, Callback callback, void* data);
+    ~Impl();
+
+    bool stop();
+    bool connected() const { return m_connected; }
+
+    bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
+
+private:
+    ChannelConfig m_config;
+
+    int m_sock;
+
+    bool m_running;
+    bool m_stopped;
+
+    time_t m_lastPing;
+    time_t m_lastActivity;
+
+    pthread_t m_thread;
+    pthread_mutex_t m_mutex;
+
+    Callback m_callback;
+    void* m_data;
+
+    ProtoParser m_parser;
+
+    bool m_connected;
+
+    void m_writeHeader(REQUEST_TYPE type, const std::string& userName, const std::string& password);
+    void m_writePairBlock(const PAIRS& source);
+    PAIRS m_readPairBlock();
+
+    static void* run(void* );
+
+    void runImpl();
+
+    bool start();
+
+    int connect();
+    int connectTCP();
+    int connectUNIX();
+
+    bool read();
+    bool tick();
+
+    static void process(void* data);
+    void processPing();
+    void processPong();
+    void processData();
+    bool sendPing();
+    bool sendPong();
+
+    static bool write(void* data, const char* buf, size_t size);
+};
+
+ChannelConfig::ChannelConfig(std::string addr)
+{
+    // unix:pass@/var/run/stg.sock
+    // tcp:secret@192.168.0.1:12345
+    // udp:key@isp.com.ua:54321
+
+    size_t pos = addr.find_first_of(':');
+    if (pos == std::string::npos)
+        throw Error("Missing transport name.");
+    transport = ToLower(addr.substr(0, pos));
+    addr = addr.substr(pos + 1);
+    if (addr.empty())
+        throw Error("Missing address to connect to.");
+    pos = addr.find_first_of('@');
+    if (pos != std::string::npos) {
+        key = addr.substr(0, pos);
+        addr = addr.substr(pos + 1);
+        if (addr.empty())
+            throw Error("Missing address to connect to.");
+    }
+    if (transport == "unix")
+    {
+        address = addr;
+        return;
+    }
+    pos = addr.find_first_of(':');
+    if (pos == std::string::npos)
+        throw Error("Missing port.");
+    address = addr.substr(0, pos);
+    portStr = addr.substr(pos + 1);
+    if (str2x(portStr, port))
+        throw Error("Invalid port value.");
+}
+
+Conn::Conn(const std::string& address, Callback callback, void* data)
+    : m_impl(new Impl(address, callback, data))
+{
+}
+
+Conn::~Conn()
+{
+}
+
+bool Conn::stop()
+{
+    return m_impl->stop();
+}
+
+bool Conn::connected() const
+{
+    return m_impl->connected();
+}
+
+bool Conn::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
+{
+    return m_impl->request(type, userName, password, pairs);
+}
+
+Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
+    : m_config(address),
+      m_sock(connect()),
+      m_running(false),
+      m_stopped(true),
+      m_lastPing(time(NULL)),
+      m_lastActivity(m_lastPing),
+      m_callback(callback),
+      m_data(data),
+      m_parser(&Conn::Impl::process, this),
+      m_connected(true)
+{
+    pthread_mutex_init(&m_mutex, NULL);
+}
+
+Conn::Impl::~Impl()
+{
+    stop();
+    shutdown(m_sock, SHUT_RDWR);
+    close(m_sock);
+    pthread_mutex_destroy(&m_mutex);
+}
+
+bool Conn::Impl::stop()
+{
+    m_connected = false;
+
+    if (m_stopped)
+        return true;
+
+    m_running = false;
+
+    for (size_t i = 0; i < 25 && !m_stopped; i++) {
+        struct timespec ts = {0, 200000000};
+        nanosleep(&ts, NULL);
+    }
+
+    if (m_stopped) {
+        pthread_join(m_thread, NULL);
+        return true;
+    }
+
+    return false;
+}
+
+bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
+{
+    if (!m_running)
+        if (!start())
+            return false;
+    MapGen map;
+    for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
+        map.add(it->first, new StringGen(it->second));
+    map.add("Radius-Username", new StringGen(userName));
+    map.add("Radius-Userpass", new StringGen(password));
+
+    PacketGen gen("data");
+    gen.add("stage", toStage(type))
+       .add("pairs", map);
+
+    STG_LOCKER lock(m_mutex);
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+void Conn::Impl::runImpl()
+{
+    m_running = true;
+
+    while (m_running) {
+        fd_set fds;
+
+        FD_ZERO(&fds);
+        FD_SET(m_sock, &fds);
+
+        struct timeval tv;
+        tv.tv_sec = 0;
+        tv.tv_usec = 500000;
+
+        int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
+        if (res < 0)
+        {
+            if (errno == EINTR)
+                continue;
+            RadLog("'select' is failed: %s", strerror(errno));
+            break;
+        }
+
+
+        if (!m_running)
+            break;
+
+        STG_LOCKER lock(m_mutex);
+
+        if (res > 0)
+        {
+            if (FD_ISSET(m_sock, &fds))
+                m_running = read();
+        }
+        else
+            m_running = tick();
+    }
+
+    m_connected = false;
+    m_stopped = true;
+}
+
+bool Conn::Impl::start()
+{
+    int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
+    if (res != 0)
+        return false;
+    return true;
+}
+
+int Conn::Impl::connect()
+{
+    if (m_config.transport == "tcp")
+        return connectTCP();
+    else if (m_config.transport == "unix")
+        return connectUNIX();
+    throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
+}
+
+int Conn::Impl::connectTCP()
+{
+    addrinfo hints;
+    memset(&hints, 0, sizeof(addrinfo));
+
+    hints.ai_family = AF_INET;       /* Allow IPv4 */
+    hints.ai_socktype = SOCK_STREAM; /* Stream socket */
+    hints.ai_flags = 0;     /* For wildcard IP address */
+    hints.ai_protocol = 0;           /* Any protocol */
+    hints.ai_canonname = NULL;
+    hints.ai_addr = NULL;
+    hints.ai_next = NULL;
+
+    addrinfo* ais = NULL;
+    int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
+    if (res != 0)
+        throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
+
+    for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
+    {
+        int fd = socket(AF_INET, SOCK_STREAM, 0);
+        if (fd == -1)
+        {
+            Error error(std::string("Error creating TCP socket: ") + strerror(errno));
+            freeaddrinfo(ais);
+            throw error;
+        }
+        if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
+        {
+            shutdown(fd, SHUT_RDWR);
+            close(fd);
+            RadLog("'connect' is failed: %s", strerror(errno));
+            continue;
+        }
+        freeaddrinfo(ais);
+        return fd;
+    }
+
+    freeaddrinfo(ais);
+
+    throw Error("Failed to resolve '" + m_config.address);
+};
+
+int Conn::Impl::connectUNIX()
+{
+    int fd = socket(AF_UNIX, SOCK_STREAM, 0);
+    if (fd == -1)
+        throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
+    struct sockaddr_un addr;
+    memset(&addr, 0, sizeof(addr));
+    addr.sun_family = AF_UNIX;
+    strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
+    if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
+    {
+        Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
+        shutdown(fd, SHUT_RDWR);
+        close(fd);
+        throw error;
+    }
+    return fd;
+}
+
+bool Conn::Impl::read()
+{
+    static std::vector<char> buffer(1024);
+    ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
+    if (res < 0)
+    {
+        RadLog("Failed to read data: %s", strerror(errno));
+        return false;
+    }
+    m_lastActivity = time(NULL);
+    RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
+    if (res == 0)
+    {
+        m_parser.last();
+        return false;
+    }
+    return m_parser.append(buffer.data(), res);
+}
+
+bool Conn::Impl::tick()
+{
+    time_t now = time(NULL);
+    if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
+    {
+        int delta = difftime(now, m_lastActivity);
+        RadLog("Connection timeout: %d sec.", delta);
+        //m_logger("Connection to " + m_remote + " timed out.");
+        return false;
+    }
+    if (difftime(now, m_lastPing) > PING_TIMEOUT)
+    {
+        int delta = difftime(now, m_lastPing);
+        RadLog("Ping timeout: %d sec. Sending ping...", delta);
+        sendPing();
+    }
+    return true;
+}
+
+void Conn::Impl::process(void* data)
+{
+    Impl& impl = *static_cast<Impl*>(data);
+    switch (impl.m_parser.packet())
+    {
+        case PING:
+            impl.processPing();
+            return;
+        case PONG:
+            impl.processPong();
+            return;
+        case DATA:
+            impl.processData();
+            return;
+    }
+    RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
+}
+
+void Conn::Impl::processPing()
+{
+    sendPong();
+}
+
+void Conn::Impl::processPong()
+{
+    m_lastActivity = time(NULL);
+}
+
+void Conn::Impl::processData()
+{
+    RESULT data;
+    if (m_parser.result())
+    {
+        for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
+            data.reply.push_back(std::make_pair(it->first, it->second));
+        for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
+            data.modify.push_back(std::make_pair(it->first, it->second));
+        data.returnCode = STG_UPDATED;
+    }
+    else
+        data.returnCode = m_parser.returnCode();
+    m_callback(m_data, data);
+}
+
+bool Conn::Impl::sendPing()
+{
+    PacketGen gen("ping");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::sendPong()
+{
+    PacketGen gen("pong");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::write(void* data, const char* buf, size_t size)
+{
+    std::string json(buf, size);
+    RadLog("Sending JSON: %s", json.c_str());
+    Conn::Impl& impl = *static_cast<Conn::Impl*>(data);
+    while (size > 0)
+    {
+        ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
+        if (res < 0)
+        {
+            impl.m_connected = false;
+            RadLog("Failed to write data: %s.", strerror(errno));
+            return false;
+        }
+        size -= res;
+    }
+    return true;
+}
+
+void* Conn::Impl::run(void* data)
+{
+    Impl& impl = *static_cast<Impl*>(data);
+    impl.runImpl();
+    return NULL;
+}
diff --git a/projects/rlm_stg/conn.h b/projects/rlm_stg/conn.h
new file mode 100644 (file)
index 0000000..6233b15
--- /dev/null
@@ -0,0 +1,63 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#ifndef __STG_RLM_CONN_H__
+#define __STG_RLM_CONN_H__
+
+#include "types.h"
+
+#include "stg/os_int.h"
+
+#include <boost/scoped_ptr.hpp>
+
+#include <string>
+#include <stdexcept>
+
+namespace STG
+{
+namespace RLM
+{
+
+class Conn
+{
+    public:
+        struct Error : std::runtime_error {
+            explicit Error(const std::string& message) : runtime_error(message) {}
+        };
+
+        typedef bool (*Callback)(void* /*data*/, const RESULT& /*result*/);
+
+        Conn(const std::string& address, Callback callback, void* data);
+        ~Conn();
+
+        bool stop();
+        bool connected() const;
+
+        bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
+
+    private:
+        class Impl;
+        boost::scoped_ptr<Impl> m_impl;
+};
+
+}
+}
+
+#endif
diff --git a/projects/rlm_stg/event.h b/projects/rlm_stg/event.h
deleted file mode 100644 (file)
index 83a72d0..0000000
+++ /dev/null
@@ -1,57 +0,0 @@
-#ifndef FR_EVENT_H
-#define FR_EVENT_H
-
-/*
- * event.h     Simple event queue
- *
- * Version:    $Id: event.h,v 1.1 2010/08/14 04:13:52 faust Exp $
- *
- *   This program is free software; you can redistribute it and/or modify
- *   it under the terms of the GNU General Public License as published by
- *   the Free Software Foundation; either version 2 of the License, or
- *   (at your option) any later version.
- *
- *   This program is distributed in the hope that it will be useful,
- *   but WITHOUT ANY WARRANTY; without even the implied warranty of
- *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- *   GNU General Public License for more details.
- *
- *   You should have received a copy of the GNU General Public License
- *   along with this program; if not, write to the Free Software
- *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
- *
- * Copyright 2007 The FreeRADIUS server project
- * Copyright 2007 Alan DeKok <aland@deployingradius.com>
- */
-
-//#include <freeradius/ident.h>
-//RCSIDH(event_h, "$Id: event.h,v 1.1 2010/08/14 04:13:52 faust Exp $")
-
-typedef struct fr_event_list_t fr_event_list_t;
-typedef struct fr_event_t fr_event_t;
-
-typedef        void (*fr_event_callback_t)(void *);
-typedef        void (*fr_event_status_t)(struct timeval *);
-typedef void (*fr_event_fd_handler_t)(fr_event_list_t *el, int sock, void *ctx);
-
-fr_event_list_t *fr_event_list_create(fr_event_status_t status);
-void fr_event_list_free(fr_event_list_t *el);
-
-int fr_event_list_num_elements(fr_event_list_t *el);
-
-int fr_event_insert(fr_event_list_t *el,
-                     fr_event_callback_t callback,
-                     void *ctx, struct timeval *when, fr_event_t **ev_p);
-int fr_event_delete(fr_event_list_t *el, fr_event_t **ev_p);
-
-int fr_event_run(fr_event_list_t *el, struct timeval *when);
-
-int fr_event_now(fr_event_list_t *el, struct timeval *when);
-
-int fr_event_fd_insert(fr_event_list_t *el, int type, int fd,
-                        fr_event_fd_handler_t handler, void *ctx);
-int fr_event_fd_delete(fr_event_list_t *el, int type, int fd);
-int fr_event_loop(fr_event_list_t *el);
-void fr_event_loop_exit(fr_event_list_t *el, int code);
-
-#endif /* FR_HASH_H */
index a74f32594136bc65f37084535e91e5887b680721..f97593f476a41ab0829523639f723ce9e3a55831 100644 (file)
 #include "iface.h"
 
-#include "thriftclient.h"
+#include "stg_client.h"
+#include "types.h"
+#include "radlog.h"
 
-int stgInstantiateImpl(const char * server, uint16_t port, const char * password)
+#include <stdexcept>
+#include <cstring>
+
+#include <strings.h>
+
+namespace RLM = STG::RLM;
+
+using RLM::Client;
+using RLM::PAIRS;
+using RLM::RESULT;
+using RLM::REQUEST_TYPE;
+
+namespace
 {
-    if (STG_CLIENT_ST::Get().Configure(server, port, password))
-        return 1;
 
-    return 0;
+STG_PAIR* toSTGPairs(const PAIRS& source)
+{
+    STG_PAIR * pairs = new STG_PAIR[source.size() + 1];
+    for (size_t pos = 0; pos < source.size(); ++pos) {
+        bzero(pairs[pos].key, sizeof(pairs[pos].key));
+        bzero(pairs[pos].value, sizeof(pairs[pos].value));
+        strncpy(pairs[pos].key, source[pos].first.c_str(), sizeof(pairs[pos].key));
+        strncpy(pairs[pos].value, source[pos].second.c_str(), sizeof(pairs[pos].value));
+    }
+    bzero(pairs[source.size()].key, sizeof(pairs[source.size()].key));
+    bzero(pairs[source.size()].value, sizeof(pairs[source.size()].value));
+
+    return pairs;
 }
 
-const STG_PAIR * stgAuthorizeImpl(const char * userName, const char * serviceType)
+PAIRS fromSTGPairs(const STG_PAIR* pairs)
 {
-    return STG_CLIENT_ST::Get().Authorize(userName, serviceType);
+    const STG_PAIR* pair = pairs;
+    PAIRS res;
+
+    while (!emptyPair(pair)) {
+        res.push_back(std::pair<std::string, std::string>(pair->key, pair->value));
+        ++pair;
+    }
+
+    return res;
+}
+
+STG_RESULT toResult(const RESULT& source)
+{
+    STG_RESULT result;
+    result.modify = toSTGPairs(source.modify);
+    result.reply = toSTGPairs(source.reply);
+    result.returnCode = source.returnCode;
+    return result;
 }
 
-const STG_PAIR * stgAuthenticateImpl(const char * userName, const char * serviceType)
+STG_RESULT emptyResult()
 {
-    return STG_CLIENT_ST::Get().Authenticate(userName, serviceType);
+    STG_RESULT result = {NULL, NULL, STG_REJECT};
+    return result;
 }
 
-const STG_PAIR * stgPostAuthImpl(const char * userName, const char * serviceType)
+std::string toString(const char* value)
 {
-    return STG_CLIENT_ST::Get().PostAuth(userName, serviceType);
+    if (value == NULL)
+        return "";
+    else
+        return value;
 }
 
-/*const STG_PAIR * stgPreAcctImpl(const char * userName, const char * serviceType)
+STG_RESULT stgRequest(REQUEST_TYPE type, const char* userName, const char* password, const STG_PAIR* pairs)
 {
-    return STG_CLIENT_ST::Get().PreAcct(userName, serviceType);
-}*/
+    Client* client = Client::get();
+    if (client == NULL) {
+        RadLog("Client is not configured.");
+        return emptyResult();
+    }
+    try {
+        return toResult(client->request(type, toString(userName), toString(password), fromSTGPairs(pairs)));
+    } catch (const std::runtime_error& ex) {
+        RadLog("Error: '%s'.", ex.what());
+        return emptyResult();
+    }
+}
+
+}
+
+int stgInstantiateImpl(const char* address)
+{
+    if (Client::configure(toString(address)))
+        return 1;
+
+    return 0;
+}
+
+STG_RESULT stgAuthorizeImpl(const char* userName, const char* password, const STG_PAIR* pairs)
+{
+    return stgRequest(RLM::AUTHORIZE, userName, password, pairs);
+}
+
+STG_RESULT stgAuthenticateImpl(const char* userName, const char* password, const STG_PAIR* pairs)
+{
+    return stgRequest(RLM::AUTHENTICATE, userName, password, pairs);
+}
+
+STG_RESULT stgPostAuthImpl(const char* userName, const char* password, const STG_PAIR* pairs)
+{
+    return stgRequest(RLM::POST_AUTH, userName, password, pairs);
+}
 
-const STG_PAIR * stgAccountingImpl(const char * userName, const char * serviceType, const char * statusType, const char * sessionId)
+STG_RESULT stgPreAcctImpl(const char* userName, const char* password, const STG_PAIR* pairs)
 {
-    return STG_CLIENT_ST::Get().Account(userName, serviceType, statusType, sessionId);
+    return stgRequest(RLM::PRE_ACCT, userName, password, pairs);
 }
 
-void deletePairs(const STG_PAIR * pairs)
+STG_RESULT stgAccountingImpl(const char* userName, const char* password, const STG_PAIR* pairs)
 {
-    delete[] pairs;
+    return stgRequest(RLM::ACCOUNT, userName, password, pairs);
 }
index 831c31231bae8c30e09869902c7b2c97d628655d..e863e939ce650d525c85bddf527138bd36f1b32f 100644 (file)
@@ -9,14 +9,12 @@
 extern "C" {
 #endif
 
-int stgInstantiateImpl(const char * server, uint16_t port, const char * password);
-const STG_PAIR * stgAuthorizeImpl(const char * userName, const char * serviceType);
-const STG_PAIR * stgAuthenticateImpl(const char * userName, const char * serviceType);
-const STG_PAIR * stgPostAuthImpl(const char * userName, const char * serviceType);
-/*const STG_PAIR * stgPreAcctImpl(const char * userName, const char * serviceType);*/
-const STG_PAIR * stgAccountingImpl(const char * userName, const char * serviceType, const char * statusType, const char * sessionId);
-
-void deletePairs(const STG_PAIR * pairs);
+int stgInstantiateImpl(const char* address);
+STG_RESULT stgAuthorizeImpl(const char* userName, const char* password, const STG_PAIR* vps);
+STG_RESULT stgAuthenticateImpl(const char* userName, const char* password, const STG_PAIR* vps);
+STG_RESULT stgPostAuthImpl(const char* userName, const char* password, const STG_PAIR* vps);
+STG_RESULT stgPreAcctImpl(const char* userName, const char* password, const STG_PAIR* vps);
+STG_RESULT stgAccountingImpl(const char* userName, const char* password, const STG_PAIR* vps);
 
 #ifdef __cplusplus
 }
diff --git a/projects/rlm_stg/radlog.c b/projects/rlm_stg/radlog.c
new file mode 100644 (file)
index 0000000..523dc1c
--- /dev/null
@@ -0,0 +1,23 @@
+#include "radlog.h"
+
+//#ifndef NDEBUG
+//#define NDEBUG
+#include <freeradius/ident.h>
+#include <freeradius/radiusd.h>
+#include <freeradius/modules.h>
+//#undef NDEBUG
+//#endif
+
+#include <stdarg.h>
+
+void RadLog(const char* format, ...)
+{
+    char buf[1024];
+
+    va_list vl;
+    va_start(vl, format);
+    vsnprintf(buf, sizeof(buf), format, vl);
+    va_end(vl);
+
+    DEBUG("[rlm_stg] *** %s", buf);
+}
diff --git a/projects/rlm_stg/radlog.h b/projects/rlm_stg/radlog.h
new file mode 100644 (file)
index 0000000..00a5dcb
--- /dev/null
@@ -0,0 +1,14 @@
+#ifndef __STG_RADLOG_H__
+#define __STG_RADLOG_H__
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void RadLog(const char* format, ...);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
index e1caf57504faeac210119071efb5fa963dd4f300..84df4e7b7cbf683d2a44b4e0e129f2431fedf8d9 100644 (file)
  *
  */
 
-#ifndef NDEBUG
-#define NDEBUG
+#include "iface.h"
+#include "stgpair.h"
+
 #include <freeradius/ident.h>
 #include <freeradius/radiusd.h>
 #include <freeradius/modules.h>
-#undef NDEBUG
-#endif
 
-#include "stgpair.h"
-#include "iface.h"
+#include <stddef.h> // size_t
 
 typedef struct rlm_stg_t {
-    char * server;
-    uint16_t port;
-    char * password;
+    char* address;
 } rlm_stg_t;
 
 static const CONF_PARSER module_config[] = {
-  { "server",  PW_TYPE_STRING_PTR, offsetof(rlm_stg_t,server), NULL,  "localhost"},
-  { "port",  PW_TYPE_INTEGER,     offsetof(rlm_stg_t,port), NULL,  "9091" },
-  { "password",  PW_TYPE_STRING_PTR, offsetof(rlm_stg_t,password), NULL,  "123456"},
+  { "address",  PW_TYPE_STRING_PTR, offsetof(rlm_stg_t, address), NULL,  "unix:/var/run/stg.sock"},
 
   { NULL, -1, 0, NULL, NULL }        /* end the list */
 };
 
-int emptyPair(const STG_PAIR * pair);
+static void deletePairs(STG_PAIR* pairs)
+{
+    free(pairs);
+}
+
+static size_t toVPS(const STG_PAIR* pairs, VALUE_PAIR** vps)
+{
+    const STG_PAIR* pair = pairs;
+    size_t count = 0;
+
+    while (!emptyPair(pair)) {
+        VALUE_PAIR* vp = pairmake(pair->key, pair->value, T_OP_SET);
+        if (vp != NULL) {
+            pairadd(vps, vp);
+            ++count;
+        }
+        ++pair;
+    }
+
+    return count;
+}
+
+static size_t toReply(STG_RESULT result, REQUEST* request)
+{
+    size_t count = 0;
+
+    count += toVPS(result.modify, &request->config_items);
+    pairfree(&request->reply->vps);
+    count += toVPS(result.reply, &request->reply->vps);
+
+    deletePairs(result.modify);
+    deletePairs(result.reply);
+
+    return count;
+}
+
+static int countVPS(const VALUE_PAIR* pairs)
+{
+    unsigned count = 0;
+    while (pairs != NULL) {
+        ++count;
+        pairs = pairs->next;
+    }
+    return count;
+}
+
+static STG_PAIR* fromVPS(const VALUE_PAIR* pairs)
+{
+    unsigned size = countVPS(pairs);
+    STG_PAIR* res = (STG_PAIR*)malloc(sizeof(STG_PAIR) * (size + 1));
+    size_t pos = 0;
+    while (pairs != NULL) {
+        bzero(res[pos].key, sizeof(res[0].key));
+        bzero(res[pos].value, sizeof(res[0].value));
+        strncpy(res[pos].key, pairs->name, sizeof(res[0].key));
+        vp_prints_value(res[pos].value, sizeof(res[0].value), (VALUE_PAIR*)pairs, 0);
+        ++pos;
+        pairs = pairs->next;
+    }
+    bzero(res[pos].key, sizeof(res[0].key));
+    bzero(res[pos].value, sizeof(res[0].value));
+    return res;
+}
+
+static int toRLMCode(int code)
+{
+    switch (code)
+    {
+        case STG_REJECT:   return RLM_MODULE_REJECT;
+        case STG_FAIL:     return RLM_MODULE_FAIL;
+        case STG_OK:       return RLM_MODULE_OK;
+        case STG_HANDLED:  return RLM_MODULE_HANDLED;
+        case STG_INVALID:  return RLM_MODULE_INVALID;
+        case STG_USERLOCK: return RLM_MODULE_USERLOCK;
+        case STG_NOTFOUND: return RLM_MODULE_NOTFOUND;
+        case STG_NOOP:     return RLM_MODULE_NOOP;
+        case STG_UPDATED:  return RLM_MODULE_UPDATED;
+    }
+    return RLM_MODULE_REJECT;
+}
 
 /*
  *    Do any per-module initialization that is separate to each
@@ -63,17 +136,17 @@ int emptyPair(const STG_PAIR * pair);
  *    that must be referenced in later calls, store a handle to it
  *    in *instance otherwise put a null pointer there.
  */
-static int stg_instantiate(CONF_SECTION *conf, void **instance)
+static int stg_instantiate(CONF_SECTION* conf, void** instance)
 {
-    rlm_stg_t *data;
+    rlm_stg_tdata;
 
     /*
      *    Set up a storage area for instance data
      */
     data = rad_malloc(sizeof(*data));
-    if (!data) {
+    if (!data)
         return -1;
-    }
+
     memset(data, 0, sizeof(*data));
 
     /*
@@ -85,7 +158,7 @@ static int stg_instantiate(CONF_SECTION *conf, void **instance)
         return -1;
     }
 
-    if (!stgInstantiateImpl(data->server, data->port)) {
+    if (!stgInstantiateImpl(data->address)) {
         free(data);
         return -1;
     }
@@ -101,158 +174,165 @@ static int stg_instantiate(CONF_SECTION *conf, void **instance)
  *    from the database. The authentication code only needs to check
  *    the password, the rest is done here.
  */
-static int stg_authorize(void *, REQUEST *request)
+static int stg_authorize(void* instance, REQUEST* request)
 {
-    const STG_PAIR * pairs;
-    const STG_PAIR * pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
+    const char* username = NULL;
+    const char* password = NULL;
 
     instance = instance;
 
     DEBUG("rlm_stg: stg_authorize()");
 
     if (request->username) {
-        DEBUG("rlm_stg: stg_authorize() request username field: '%s'", request->username->vp_strvalue);
+        username = request->username->data.strvalue;
+        DEBUG("rlm_stg: stg_authorize() request username field: '%s'", username);
     }
+
     if (request->password) {
-        DEBUG("rlm_stg: stg_authorize() request password field: '%s'", request->password->vp_strvalue);
-    }
-    // Here we need to define Framed-Protocol
-    VALUE_PAIR * svc = pairfind(request->packet->vps, PW_SERVICE_TYPE);
-    if (svc) {
-        DEBUG("rlm_stg: stg_authorize() Service-Type defined as '%s'", svc->vp_strvalue);
-        pairs = stgAuthorizeImpl((const char *)request->username->vp_strvalue, (const char *)svc->vp_strvalue);
-    } else {
-        DEBUG("rlm_stg: stg_authorize() Service-Type undefined");
-        pairs = stgAuthorizeImpl((const char *)request->username->vp_strvalue, "");
+        password = request->password->data.strvalue;
+        DEBUG("rlm_stg: stg_authorize() request password field: '%s'", password);
     }
-    if (!pairs) {
+
+    result = stgAuthorizeImpl(username, password, pairs);
+    deletePairs(pairs);
+
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_authorize() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR * pwd = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->config_items, pwd);
-        DEBUG("Adding pair '%s': '%s'", pair->key, pair->value);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_NOOP;
+    return toRLMCode(result.returnCode);
 }
 
 /*
  *    Authenticate the user with the given password.
  */
-static int stg_authenticate(void *, REQUEST *request)
+static int stg_authenticate(void* instance, REQUEST* request)
 {
-    const STG_PAIR * pairs;
-    const STG_PAIR * pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
+    const char* username = NULL;
+    const char* password = NULL;
 
     instance = instance;
 
     DEBUG("rlm_stg: stg_authenticate()");
 
-    VALUE_PAIR * svc = pairfind(request->packet->vps, PW_SERVICE_TYPE);
-    if (svc) {
-        DEBUG("rlm_stg: stg_authenticate() Service-Type defined as '%s'", svc->vp_strvalue);
-        pairs = stgAuthenticateImpl((const char *)request->username->vp_strvalue, (const char *)svc->vp_strvalue);
-    } else {
-        DEBUG("rlm_stg: stg_authenticate() Service-Type undefined");
-        pairs = stgAuthenticateImpl((const char *)request->username->vp_strvalue, "");
+    if (request->username) {
+        username = request->username->data.strvalue;
+        DEBUG("rlm_stg: stg_authenticate() request username field: '%s'", username);
+    }
+
+    if (request->password) {
+        password = request->password->data.strvalue;
+        DEBUG("rlm_stg: stg_authenticate() request password field: '%s'", password);
     }
-    if (!pairs) {
+
+    result = stgAuthenticateImpl(username, password, pairs);
+    deletePairs(pairs);
+
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_authenticate() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR * pwd = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->reply->vps, pwd);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_NOOP;
+    return toRLMCode(result.returnCode);
 }
 
 /*
  *    Massage the request before recording it or proxying it
  */
-static int stg_preacct(void *, REQUEST *)
+static int stg_preacct(void* instance, REQUEST* request)
 {
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
+    size_t count = 0;
+    const char* username = NULL;
+    const char* password = NULL;
+
     DEBUG("rlm_stg: stg_preacct()");
 
     instance = instance;
 
-    return RLM_MODULE_OK;
+    if (request->username) {
+        username = request->username->data.strvalue;
+        DEBUG("rlm_stg: stg_preacct() request username field: '%s'", username);
+    }
+
+    if (request->password) {
+        password = request->password->data.strvalue;
+        DEBUG("rlm_stg: stg_preacct() request password field: '%s'", password);
+    }
+
+    result = stgPreAcctImpl(username, password, pairs);
+    deletePairs(pairs);
+
+    if (!result.modify && !result.reply) {
+        DEBUG("rlm_stg: stg_preacct() failed.");
+        return RLM_MODULE_REJECT;
+    }
+
+    count = toReply(result, request);
+
+    if (count)
+        return RLM_MODULE_UPDATED;
+
+    return toRLMCode(result.returnCode);
 }
 
 /*
  *    Write accounting information to this modules database.
  */
-static int stg_accounting(void *, REQUEST * request)
+static int stg_accounting(void* instance, REQUEST* request)
 {
-    const STG_PAIR * pairs;
-    const STG_PAIR * pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
-
-    instance = instance;
+    const char* username = NULL;
+    const char* password = NULL;
 
     DEBUG("rlm_stg: stg_accounting()");
 
-    VALUE_PAIR * svc = pairfind(request->packet->vps, PW_SERVICE_TYPE);
-    VALUE_PAIR * sessid = pairfind(request->packet->vps, PW_ACCT_SESSION_ID);
-    VALUE_PAIR * sttype = pairfind(request->packet->vps, PW_ACCT_STATUS_TYPE);
+    instance = instance;
 
-    if (!sessid) {
-        DEBUG("rlm_stg: stg_accounting() Acct-Session-ID undefined");
-        return RLM_MODULE_FAIL;
+    if (request->username) {
+        username = request->username->data.strvalue;
+        DEBUG("rlm_stg: stg_accounting() request username field: '%s'", username);
     }
 
-    if (sttype) {
-        DEBUG("Acct-Status-Type := %s", sttype->vp_strvalue);
-        if (svc) {
-            DEBUG("rlm_stg: stg_accounting() Service-Type defined as '%s'", svc->vp_strvalue);
-            pairs = stgAccountingImpl((const char *)request->username->vp_strvalue, (const char *)svc->vp_strvalue, (const char *)sttype->vp_strvalue, (const char *)sessid->vp_strvalue);
-        } else {
-            DEBUG("rlm_stg: stg_accounting() Service-Type undefined");
-            pairs = stgAccountingImpl((const char *)request->username->vp_strvalue, "", (const char *)sttype->vp_strvalue, (const char *)sessid->vp_strvalue);
-        }
-    } else {
-        DEBUG("rlm_stg: stg_accounting() Acct-Status-Type := NULL");
-        return RLM_MODULE_OK;
+    if (request->password) {
+        password = request->password->data.strvalue;
+        DEBUG("rlm_stg: stg_accounting() request password field: '%s'", password);
     }
-    if (!pairs) {
+
+    result = stgAccountingImpl(username, password, pairs);
+    deletePairs(pairs);
+
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_accounting() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR * pwd = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->reply->vps, pwd);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_OK;
+    return toRLMCode(result.returnCode);
 }
 
 /*
@@ -265,86 +345,76 @@ static int stg_accounting(void *, REQUEST * request)
  *    max. number of logins, do a second pass and validate all
  *    logins by querying the terminal server (using eg. SNMP).
  */
-static int stg_checksimul(void *, REQUEST *request)
+static int stg_checksimul(void* instance, REQUEST* request)
 {
     DEBUG("rlm_stg: stg_checksimul()");
 
     instance = instance;
 
-    request->simul_count=0;
+    request->simul_count = 0;
 
     return RLM_MODULE_OK;
 }
 
-static int stg_postauth(void *, REQUEST *request)
+static int stg_postauth(void* instance, REQUEST* request)
 {
-    const STG_PAIR * pairs;
-    const STG_PAIR * pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
-
-    instance = instance;
+    const char* username = NULL;
+    const char* password = NULL;
 
     DEBUG("rlm_stg: stg_postauth()");
 
-    VALUE_PAIR * svc = pairfind(request->packet->vps, PW_SERVICE_TYPE);
+    instance = instance;
+
+    if (request->username) {
+        username = request->username->data.strvalue;
+        DEBUG("rlm_stg: stg_postauth() request username field: '%s'", username);
+    }
 
-    if (svc) {
-        DEBUG("rlm_stg: stg_postauth() Service-Type defined as '%s'", svc->vp_strvalue);
-        pairs = stgPostAuthImpl((const char *)request->username->vp_strvalue, (const char *)svc->vp_strvalue);
-    } else {
-        DEBUG("rlm_stg: stg_postauth() Service-Type undefined");
-        pairs = stgPostAuthImpl((const char *)request->username->vp_strvalue, "");
+    if (request->password) {
+        password = request->password->data.strvalue;
+        DEBUG("rlm_stg: stg_postauth() request password field: '%s'", password);
     }
-    if (!pairs) {
+
+    result = stgPostAuthImpl(username, password, pairs);
+    deletePairs(pairs);
+
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_postauth() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR * pwd = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->reply->vps, pwd);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
-    return RLM_MODULE_NOOP;
+    return toRLMCode(result.returnCode);
 }
 
-static int stg_detach(void *instance)
+static int stg_detach(voidinstance)
 {
-    free(((struct rlm_stg_t *)instance)->server);
+    free(((struct rlm_stg_t*)instance)->address);
     free(instance);
     return 0;
 }
 
-/*
- *    The module name should be the only globally exported symbol.
- *    That is, everything else should be 'static'.
- *
- *    If the module needs to temporarily modify it's instantiation
- *    data, the type should be changed to RLM_TYPE_THREAD_UNSAFE.
- *    The server will then take care of ensuring that the module
- *    is single-threaded.
- */
 module_t rlm_stg = {
     RLM_MODULE_INIT,
     "stg",
-    RLM_TYPE_THREAD_SAFE,        /* type */
-    stg_instantiate,        /* instantiation */
-    stg_detach,            /* detach */
+    RLM_TYPE_THREAD_UNSAFE, /* type */
+    stg_instantiate,      /* instantiation */
+    stg_detach,           /* detach */
     {
-        stg_authenticate,    /* authentication */
+        stg_authenticate, /* authentication */
         stg_authorize,    /* authorization */
-        stg_preacct,    /* preaccounting */
-        stg_accounting,    /* accounting */
-        stg_checksimul,    /* checksimul */
-        NULL,            /* pre-proxy */
-        NULL,            /* post-proxy */
-        stg_postauth            /* post-auth */
+        stg_preacct,      /* preaccounting */
+        stg_accounting,   /* accounting */
+        stg_checksimul,   /* checksimul */
+        NULL,    /* pre-proxy */
+        NULL,   /* post-proxy */
+        stg_postauth      /* post-auth */
     },
 };
index 113e71c97891ca42c284d73b4616f019eb9564f8..e34c50cdbaf30137125e29d50c55b1d338c47336 100644 (file)
  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
  */
 
-/*
- *  Realization of data access via Stargazer for RADIUS
- *
- *  $Revision: 1.8 $
- *  $Date: 2010/04/16 12:30:02 $
- *
- */
-
-#include <netdb.h>
-#include <sys/types.h>
-#include <unistd.h> // close
-
-#include <cerrno>
-#include <cstring>
-#include <vector>
-#include <utility>
-
-#include <stdexcept>
-
 #include "stg_client.h"
 
-typedef std::vector<std::pair<std::string, std::string> > PAIRS;
-
-//-----------------------------------------------------------------------------
+#include "conn.h"
+#include "radlog.h"
 
-STG_CLIENT::STG_CLIENT(const std::string & host, uint16_t port, uint16_t lp, const std::string & pass)
-    : password(pass),
-      framedIP(0)
-{
-/*sock = socket(AF_INET, SOCK_DGRAM, 0);
-if (sock == -1)
-    {
-    std::string message = strerror(errno);
-    message = "Socket create error: '" + message + "'";
-    throw std::runtime_error(message);
-    }
+#include "stg/locker.h"
+#include "stg/common.h"
 
-struct hostent * he = NULL;
-he = gethostbyname(host.c_str());
-if (he == NULL)
-    {
-    throw std::runtime_error("gethostbyname error");
-    }
+#include <map>
+#include <utility>
 
-outerAddr.sin_family = AF_INET;
-outerAddr.sin_port = htons(port);
-outerAddr.sin_addr.s_addr = *(uint32_t *)he->h_addr;
+using STG::RLM::Client;
+using STG::RLM::Conn;
+using STG::RLM::RESULT;
 
-InitEncrypt(&ctx, password);
+namespace {
 
-PrepareNet();*/
-}
+Client* stgClient = NULL;
 
-STG_CLIENT::~STG_CLIENT()
-{
-/*close(sock);*/
 }
 
-int STG_CLIENT::PrepareNet()
+class Client::Impl
 {
-return 0;
-}
-
-int STG_CLIENT::Send(const RAD_PACKET & packet)
+    public:
+        explicit Impl(const std::string& address);
+        ~Impl();
+
+        bool stop() { return m_conn ? m_conn->stop() : true; }
+
+        RESULT request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
+
+    private:
+        std::string m_address;
+        boost::scoped_ptr<Conn> m_conn;
+
+        pthread_mutex_t m_mutex;
+        pthread_cond_t m_cond;
+        bool m_done;
+        RESULT m_result;
+
+        static bool callback(void* data, const RESULT& result)
+        {
+            Impl& impl = *static_cast<Impl*>(data);
+            STG_LOCKER lock(impl.m_mutex);
+            impl.m_result = result;
+            impl.m_done = true;
+            pthread_cond_signal(&impl.m_cond);
+            return true;
+        }
+};
+
+Client::Impl::Impl(const std::string& address)
+    : m_address(address)
 {
-/*char buf[RAD_MAX_PACKET_LEN];
-    
-Encrypt(&ctx, buf, (char *)&packet, sizeof(RAD_PACKET) / 8);
-
-int res = sendto(sock, buf, sizeof(RAD_PACKET), 0, (struct sockaddr *)&outerAddr, sizeof(outerAddr));
-
-if (res == -1)
-    errorStr = "Error sending data";
-
-return res;*/
-}
-
-int STG_CLIENT::RecvData(RAD_PACKET * packet)
-{
-/*char buf[RAD_MAX_PACKET_LEN];
-int res;
-
-struct sockaddr_in addr;
-socklen_t len = sizeof(struct sockaddr_in);
-
-res = recvfrom(sock, buf, RAD_MAX_PACKET_LEN, 0, reinterpret_cast<struct sockaddr *>(&addr), &len);
-if (res == -1)
+    try
     {
-    errorStr = "Error receiving data";
-    return -1;
+        m_conn.reset(new Conn(m_address, &Impl::callback, this));
     }
-
-Decrypt(&ctx, (char *)packet, buf, res / 8);
-
-return 0;*/
-}
-
-int STG_CLIENT::Request(RAD_PACKET * packet, const std::string & login, const std::string & svc, uint8_t packetType)
-{
-/*int res;
-
-memcpy((void *)&packet->magic, (void *)RAD_ID, RAD_MAGIC_LEN);
-packet->protoVer[0] = '0';
-packet->protoVer[1] = '1';
-packet->packetType = packetType;
-packet->ip = 0;
-strncpy((char *)packet->login, login.c_str(), RAD_LOGIN_LEN);
-strncpy((char *)packet->service, svc.c_str(), RAD_SERVICE_LEN);
-
-res = Send(*packet);
-if (res == -1)
-    return -1;
-
-res = RecvData(packet);
-if (res == -1)
-    return -1;
-
-if (strncmp((char *)packet->magic, RAD_ID, RAD_MAGIC_LEN))
+    catch (const std::runtime_error& ex)
     {
-    errorStr = "Magic invalid. Wanted: '";
-    errorStr += RAD_ID;
-    errorStr += "', got: '";
-    errorStr += (char *)packet->magic;
-    errorStr += "'";
-    return -1;
+        RadLog("Connection error: %s.", ex.what());
     }
-
-return 0;*/
+    pthread_mutex_init(&m_mutex, NULL);
+    pthread_cond_init(&m_cond, NULL);
+    m_done = false;
 }
 
-//-----------------------------------------------------------------------------
-
-const STG_PAIRS * STG_CLIENT::Authorize(const std::string & login, const std::string & svc)
+Client::Impl::~Impl()
 {
-/*RAD_PACKET packet;
-
-userPassword = "";
-
-if (Request(&packet, login, svc, RAD_AUTZ_PACKET))
-    return -1;
-
-if (packet.packetType != RAD_ACCEPT_PACKET)
-    return -1;
-
-userPassword = (char *)packet.password;*/
-
-PAIRS pairs;
-pairs.push_back(std::make_pair("Cleartext-Password", userPassword));
-
-return ToSTGPairs(pairs);
+    pthread_cond_destroy(&m_cond);
+    pthread_mutex_destroy(&m_mutex);
 }
 
-const STG_PAIRS * STG_CLIENT::Authenticate(const std::string & login, const std::string & svc)
+RESULT Client::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
-/*RAD_PACKET packet;
-
-userPassword = "";
-
-if (Request(&packet, login, svc, RAD_AUTH_PACKET))
-    return -1;
-
-if (packet.packetType != RAD_ACCEPT_PACKET)
-    return -1;*/
-
-PAIRS pairs;
-
-return ToSTGPairs(pairs);
+    STG_LOCKER lock(m_mutex);
+    if (!m_conn || !m_conn->connected())
+        m_conn.reset(new Conn(m_address, &Impl::callback, this));
+    if (!m_conn->connected())
+        throw Conn::Error("Failed to create connection to '" + m_address + "'.");
+
+    m_done = false;
+    m_conn->request(type, userName, password, pairs);
+    timespec ts;
+    clock_gettime(CLOCK_REALTIME, &ts);
+    ts.tv_sec += 5;
+    int res = 0;
+    while (!m_done && res == 0)
+        res = pthread_cond_timedwait(&m_cond, &m_mutex, &ts);
+    if (res != 0)
+        throw Conn::Error("Request failed.");
+    return m_result;
 }
 
-const STG_PAIRS * STG_CLIENT::PostAuth(const std::string & login, const std::string & svc)
+Client::Client(const std::string& address)
+    : m_impl(new Impl(address))
 {
-/*RAD_PACKET packet;
-
-userPassword = "";
-
-if (Request(&packet, login, svc, RAD_POST_AUTH_PACKET))
-    return -1;
-
-if (packet.packetType != RAD_ACCEPT_PACKET)
-    return -1;
-
-if (svc == "Framed-User")
-    framedIP = packet.ip;
-else
-    framedIP = 0;*/
-
-PAIRS pairs;
-pairs.push_back(std::make_pair("Framed-IP-Address", inet_ntostring(framedIP)));
-
-return ToSTGPairs(pairs);
 }
 
-const STG_PAIRS * STG_CLIENT::PreAcct(const std::string & login, const std::String & service)
+Client::~Client()
 {
-PAIRS pairs;
-
-return ToSTGPairs(pairs);
 }
 
-const STG_PAIRS * STG_CLIENT::Account(const std::string & type, const std::string & login, const std::string & svc, const std::string & sessid)
+bool Client::stop()
 {
-/*RAD_PACKET packet;
-
-userPassword = "";
-strncpy((char *)packet.sessid, sessid.c_str(), RAD_SESSID_LEN);
-
-if (type == "Start")
-    {
-    if (Request(&packet, login, svc, RAD_ACCT_START_PACKET))
-        return -1;
-    }
-else if (type == "Stop")
-    {
-    if (Request(&packet, login, svc, RAD_ACCT_STOP_PACKET))
-        return -1;
-    }
-else if (type == "Interim-Update")
-    {
-    if (Request(&packet, login, svc, RAD_ACCT_UPDATE_PACKET))
-        return -1;
-    }
-else
-    {
-    if (Request(&packet, login, svc, RAD_ACCT_OTHER_PACKET))
-        return -1;
-    }
-
-if (packet.packetType != RAD_ACCEPT_PACKET)
-    return -1;*/
-
-PAIRS pairs;
-
-return ToSTGPairs(pairs);
+    return m_impl->stop();
 }
 
-//-----------------------------------------------------------------------------
-
-std::string STG_CLIENT_ST::m_host;
-uint16_t STG_CLIENT_ST::m_port(6666);
-std::string STG_CLIENT_ST::m_password;
-
-//-----------------------------------------------------------------------------
-
-STG_CLIENT * STG_CLIENT_ST::Get()
+RESULT Client::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
 {
-    static STG_CLIENT * stgClient = NULL;
-    if ( stgClient == NULL )
-        stgClient = new STG_CLIENT(m_host, m_port, m_password);
-    return stgClient;
+    return m_impl->request(type, userName, password, pairs);
 }
 
-void STG_CLIENT_ST::Configure(const std::string & host, uint16_t port, const std::string & password)
+Client* Client::get()
 {
-    m_host = host;
-    m_port = port;
-    m_password = password;
+    return stgClient;
 }
 
-//-----------------------------------------------------------------------------
-
-const STG_PAIR * ToSTGPairs(const PAIRS & source)
+bool Client::configure(const std::string& address)
 {
-    STG_PAIR * pairs = new STG_PAIR[source.size() + 1];
-    for (size_t pos = 0; pos < source.size(); ++pos) {
-        bzero(pairs[pos].key, sizeof(STG_PAIR::key));
-        bzero(pairs[pos].value, sizeof(STG_PAIR::value));
-        strncpy(pairs[pos].key, source[pos].first.c_str(), sizeof(STG_PAIR::key));
-        strncpy(pairs[pos].value, source[pos].second.c_str(), sizeof(STG_PAIR::value));
-        ++pos;
+    if ( stgClient != NULL )
+        return stgClient->configure(address);
+    try {
+        stgClient = new Client(address);
+        return true;
+    } catch (const std::exception& ex) {
+        RadLog("Client configuration error: %s.", ex.what());
     }
-    bzero(pairs[sources.size()].key, sizeof(STG_PAIR::key));
-    bzero(pairs[sources.size()].value, sizeof(STG_PAIR::value));
-
-    return pairs;
+    return false;
 }
index 5ee000c7e949b1e8735815c122db7d510036e10a..917d0e511fa18c68bef09587725a765dff3ab191 100644 (file)
  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
  */
 
-/*
- *  Header file for client part of data access via Stargazer for RADIUS
- *
- *  $Revision: 1.4 $
- *  $Date: 2010/04/16 12:30:02 $
- *
- */
+#ifndef __STG_RLM_CLIENT_H__
+#define __STG_RLM_CLIENT_H__
 
-#ifndef STG_CLIENT_H
-#define STG_CLIENT_H
+#include "types.h"
 
-#include <string>
+#include "stg/os_int.h"
 
-#include <netinet/in.h>
-#include <arpa/inet.h>
-#include <sys/socket.h> // socklen_t
+#include <boost/scoped_ptr.hpp>
 
-#include "stg/blowfish.h"
-#include "stg/rad_packets.h"
+#include <string>
 
-#include "stgpair.h"
+namespace STG
+{
+namespace RLM
+{
 
-class STG_CLIENT
+class Client
 {
 public:
-    STG_CLIENT(const std::string & host, uint16_t port, const std::string & password);
-    ~STG_CLIENT();
-
-    const STG_PAIR * Authorize(const std::string & login, const std::string & service);
-    const STG_PAIR * Authenticate(const std::string & login, const std::string & service);
-    const STG_PAIR * PostAuth(const std::string & login, const std::string & service);
-    const STG_PAIR * PreAcct(const std::string & login, const std::string & service);
-    const STG_PAIR * Account(const std::string & type, const std::string & login, const std::string & service, const std::string & sessionId);
+    explicit Client(const std::string& address);
+    ~Client();
 
-private:
-    std::string password;
+    bool stop();
 
-    int PrepareNet();
+    static Client* get();
+    static bool configure(const std::string& address);
 
-    int Request(RAD_PACKET * packet, const std::string & login, const std::string & svc, uint8_t packetType);
+    RESULT request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
 
-    int RecvData(RAD_PACKET * packet);
-    int Send(const RAD_PACKET & packet);
+private:
+    class Impl;
+    boost::scoped_ptr<Impl> m_impl;
 };
 
-struct STG_CLIENT_ST
-{
-    public:
-        static void Configure(const std::string & host, uint16_t port, const std::string & password);
-        static STG_CLIENT * Get();
-
-    private:
-        static std::string m_host;
-        static uint16_t m_port;
-        static std::string m_password;
-};
+} // namespace RLM
+} // namespace STG
 
 #endif
index 19b42bc182fe724ea1d83db6faf9639372b2e18f..ef7ab4b778897fdd3e738e36346b70a7ebe7633a 100644 (file)
@@ -1,12 +1,47 @@
 #ifndef __STG_STGPAIR_H__
 #define __STG_STGPAIR_H__
 
+#include <stddef.h>
+
 #define STGPAIR_KEYLENGTH 64
 #define STGPAIR_VALUELENGTH 256
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+
 typedef struct STG_PAIR {
     char key[STGPAIR_KEYLENGTH];
     char value[STGPAIR_VALUELENGTH];
 } STG_PAIR;
 
+typedef struct STG_RESULT {
+    STG_PAIR* modify;
+    STG_PAIR* reply;
+    int returnCode;
+} STG_RESULT;
+
+inline
+int emptyPair(const STG_PAIR* pair)
+{
+    return pair == NULL || pair->key[0] == '\0' || pair->value[0] == '\0';
+}
+
+enum
+{
+    STG_REJECT,
+    STG_FAIL,
+    STG_OK,
+    STG_HANDLED,
+    STG_INVALID,
+    STG_USERLOCK,
+    STG_NOTFOUND,
+    STG_NOOP,
+    STG_UPDATED
+};
+
+#ifdef __cplusplus
+}
+#endif
+
 #endif
diff --git a/projects/rlm_stg/types.h b/projects/rlm_stg/types.h
new file mode 100644 (file)
index 0000000..2bc721f
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#ifndef __STG_RLM_CLIENT_CONN_H__
+#define __STG_RLM_CLIENT_CONN_H__
+
+#include <vector>
+#include <string>
+
+namespace STG
+{
+namespace RLM
+{
+
+typedef std::vector<std::pair<std::string, std::string> > PAIRS;
+
+struct RESULT
+{
+    PAIRS modify;
+    PAIRS reply;
+    int returnCode;
+};
+
+enum REQUEST_TYPE {
+    AUTHORIZE,
+    AUTHENTICATE,
+    POST_AUTH,
+    PRE_ACCT,
+    ACCOUNT
+};
+
+} // namespace RLM
+} // namespace STG
+
+#endif
index f54e5ffb38b0960e8ac282cad9d55107e34536f0..9c63e65636e5526d644d88fa5b963e5050a59b60 100755 (executable)
@@ -105,7 +105,6 @@ PLUGINS="authorization/ao
          configuration/sgconfig
          other/ping
          other/rscript
-         other/radius
          other/smux
          store/files
          capture/cap_nf"
@@ -418,6 +417,55 @@ else
 fi
 rm -f fake
 
+printf "Checking for -lyajl... "
+pkg-config --version > /dev/null 2> /dev/null
+if [ "$?" = "0" ]
+then
+    pkg-config --atleast-version=2.0.0 yajl
+    if [ "$?" != "0" ]
+    then
+        CHECK_YAJL=no
+        printf "no\n"
+    else
+        CHECK_YAJL=yes
+        printf `pkg-config --modversion yajl`"\n"
+    fi
+else
+    printf "#include <stdio.h>\n" > build_check.c
+    printf "#include <yajl/yajl_version.h>\n" >> build_check.c
+    printf "int main() { printf(\"%%d\", yajl_version()); return 0; }\n" >> build_check.c
+    $CC $CFLAGS $LDFLAGS build_check.c -lyajl -o fake > /dev/null 2> /dev/null
+    if [ $? != 0 ]
+    then
+        CHECK_YAJL=no
+        printf "no\n"
+    else
+        YAJL_VERSION=`./fake`
+        if [ $YAJL_VERSION -ge 20000 ]
+        then
+            CHECK_YAJL=yes
+            printf "${YAJL_VERSION}\n"
+        else
+            CHECK_YAJL=no
+            printf "no. Need at least version 2.0.0, existing version is ${YAJL_VERSION}\n"
+        fi
+    fi
+    rm -f fake
+fi
+
+printf "Checking for boost::scoped_ptr... "
+printf "#include <boost/scoped_ptr.hpp>\nint main() { boost::scoped_ptr<int> test(new int(1)); return 0; }\n" > build_check.cpp
+$CXX $CXXFLAGS $LDFLAGS build_check.cpp -o fake # > /dev/null 2> /dev/null
+if [ $? != 0 ]
+then
+    CHECK_BOOST_SCOPED_PTR=no
+    printf "no\n"
+else
+    CHECK_BOOST_SCOPED_PTR=yes
+    printf "yes\n"
+fi
+rm -f fake
+
 if [ "$OS" = "linux" ]
 then
     printf "Checking for linux/netfilter_ipv4/ip_queue.h... "
@@ -440,6 +488,7 @@ then
 fi
 
 rm -f build_check.c
+rm -f build_check.cpp
 
 if [ "$CHECK_EXPAT" != "yes" ]
 then
@@ -479,6 +528,14 @@ then
              capture/nfqueue"
 fi
 
+if [ "$CHECK_YAJL" = "yes" -a "$CHECK_BOOST_SCOPED_PTR" = "yes" ]
+then
+    PLUGINS="$PLUGINS
+             other/radius"
+    STG_LIBS="$STG_LIBS
+              json.lib"
+fi
+
 printf "OS=$OS\n" > $CONFFILE
 printf "STG_TIME=yes\n" >> $CONFFILE
 printf "DEBUG=$DEBUG\n" >> $CONFFILE
index 62a05183acbb016994b74df5a7e8a34b9c6dbc85..14428914abb392e53d24058dbf00e26fc56b50e6 100644 (file)
@@ -4,16 +4,19 @@
 
 include ../../../../../Makefile.conf
 
-LIBS += $(LIB_THREAD)
+LIBS += $(LIB_THREAD) -lyajl
 
 PROG = mod_radius.so
 
-SRCS = ./radius.cpp
+SRCS = radius.cpp \
+       config.cpp \
+       conn.cpp
 
 STGLIBS = common \
          crypto \
          logger \
-         scriptexecuter
+         scriptexecuter \
+         json
 
 include ../../Makefile.in
 
diff --git a/projects/stargazer/plugins/other/radius/config.cpp b/projects/stargazer/plugins/other/radius/config.cpp
new file mode 100644 (file)
index 0000000..7982285
--- /dev/null
@@ -0,0 +1,386 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#include "config.h"
+
+#include "stg/user.h"
+#include "stg/common.h"
+
+#include <vector>
+#include <stdexcept>
+
+#include <strings.h> // strncasecmp
+
+using STG::Config;
+
+namespace
+{
+
+struct ParserError : public std::runtime_error
+{
+    ParserError(const std::string& message)
+        : runtime_error("Config is not valid. " + message),
+          position(0),
+          error(message)
+    {}
+    ParserError(size_t pos, const std::string& message)
+        : runtime_error("Parsing error at position " + x2str(pos) + ". " + message),
+          position(pos),
+          error(message)
+    {}
+    virtual ~ParserError() throw() {}
+
+    size_t position;
+    std::string error;
+};
+
+size_t skipSpaces(const std::string& value, size_t start)
+{
+    while (start < value.length() && std::isspace(value[start]))
+        ++start;
+    return start;
+}
+
+size_t checkChar(const std::string& value, size_t start, char ch)
+{
+    if (start >= value.length())
+        throw ParserError(start, "Unexpected end of string. Expected '" + std::string(1, ch) + "'.");
+    if (value[start] != ch)
+        throw ParserError(start, "Expected '" + std::string(1, ch) + "', got '" + std::string(1, value[start]) + "'.");
+    return start + 1;
+}
+
+std::pair<size_t, std::string> readString(const std::string& value, size_t start)
+{
+    std::string dest;
+    while (start < value.length() && !std::isspace(value[start]) &&
+           value[start] != ',' && value[start] != '(' && value[start] != ')')
+        dest.push_back(value[start++]);
+    if (dest.empty()) {
+        if (start == value.length())
+            throw ParserError(start, "Unexpected end of string. Expected string.");
+        else
+            throw ParserError(start, "Unexpected whitespace. Expected string.");
+    }
+    return std::make_pair(start, dest);
+}
+
+Config::Pairs toPairs(const std::vector<std::string>& values)
+{
+    if (values.empty())
+        return Config::Pairs();
+    std::string value(values[0]);
+    Config::Pairs res;
+    size_t start = 0;
+    while (start < value.size()) {
+        Config::Pair pair;
+        start = skipSpaces(value, start);
+        if (!res.empty())
+        {
+            start = checkChar(value, start, ',');
+            start = skipSpaces(value, start);
+        }
+        size_t pairStart = start;
+        start = checkChar(value, start, '(');
+        const std::pair<size_t, std::string> key = readString(value, start);
+        start = key.first;
+        pair.first = key.second;
+        start = skipSpaces(value, start);
+        start = checkChar(value, start, ',');
+        start = skipSpaces(value, start);
+        const std::pair<size_t, std::string> val = readString(value, start);
+        start = val.first;
+        pair.second = val.second;
+        start = skipSpaces(value, start);
+        start = checkChar(value, start, ')');
+        if (res.find(pair.first) != res.end())
+            throw ParserError(pairStart, "Duplicate field.");
+        res.insert(pair);
+    }
+    return res;
+}
+
+bool toBool(const std::vector<std::string>& values)
+{
+    if (values.empty())
+        return false;
+    std::string value(values[0]);
+    return strncasecmp(value.c_str(), "yes", 3) == 0;
+}
+
+std::string toString(const std::vector<std::string>& values)
+{
+    if (values.empty())
+        return "";
+    return values[0];
+}
+
+uid_t toUID(const std::vector<std::string>& values)
+{
+    if (values.empty())
+        return -1;
+    uid_t res = str2uid(values[0]);
+    if (res == static_cast<uid_t>(-1))
+        throw ParserError("Invalid user name: '" + values[0] + "'");
+    return res;
+}
+
+gid_t toGID(const std::vector<std::string>& values)
+{
+    if (values.empty())
+        return -1;
+    gid_t res = str2gid(values[0]);
+    if (res == static_cast<gid_t>(-1))
+        throw ParserError("Invalid group name: '" + values[0] + "'");
+    return res;
+}
+
+mode_t toMode(const std::vector<std::string>& values)
+{
+    if (values.empty())
+        return -1;
+    mode_t res = str2mode(values[0]);
+    if (res == static_cast<mode_t>(-1))
+        throw ParserError("Invalid mode: '" + values[0] + "'");
+    return res;
+}
+
+template <typename T>
+T toInt(const std::vector<std::string>& values)
+{
+    if (values.empty())
+        return 0;
+    T res = 0;
+    if (str2x(values[0], res) == 0)
+        return res;
+    return 0;
+}
+
+uint16_t toPort(const std::string& value)
+{
+    if (value.empty())
+        return 0;
+    uint16_t res = 0;
+    if (str2x(value, res) == 0)
+        return res;
+    throw ParserError("'" + value + "' is not a valid port number.");
+}
+
+typedef std::map<std::string, Config::ReturnCode> Codes;
+
+// One-time call to initialize the list of codes.
+Codes getCodes()
+{
+    Codes res;
+    res["reject"]   = Config::REJECT;
+    res["fail"]     = Config::FAIL;
+    res["ok"]       = Config::OK;
+    res["handled"]  = Config::HANDLED;
+    res["invalid"]  = Config::INVALID;
+    res["userlock"] = Config::USERLOCK;
+    res["notfound"] = Config::NOTFOUND;
+    res["noop"]     = Config::NOOP;
+    res["updated"]  = Config::UPDATED;
+    return res;
+}
+
+Config::ReturnCode toReturnCode(const std::vector<std::string>& values)
+{
+    static Codes codes(getCodes());
+    if (values.empty())
+        return Config::REJECT;
+    std::string code = ToLower(values[0]);
+    const Codes::const_iterator it = codes.find(code);
+    if (it == codes.end())
+        return Config::REJECT;
+    return it->second;
+}
+
+Config::Pairs parseVector(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toPairs(params[i].value);
+    return Config::Pairs();
+}
+
+Config::Authorize parseAuthorize(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return Config::Authorize(toPairs(params[i].value));
+    return Config::Authorize();
+}
+
+Config::ReturnCode parseReturnCode(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toReturnCode(params[i].value);
+    return Config::REJECT;
+}
+
+bool parseBool(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toBool(params[i].value);
+    return false;
+}
+
+std::string parseString(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toString(params[i].value);
+    return "";
+}
+
+std::string parseAddress(Config::Type connectionType, const std::string& value)
+{
+    size_t pos = value.find_first_of(':');
+    if (pos == std::string::npos)
+        throw ParserError("Connection type is not specified. Should be either 'unix' or 'tcp'.");
+    if (connectionType == Config::UNIX)
+        return value.substr(pos + 1);
+    std::string address(value.substr(pos + 1));
+    pos = address.find_first_of(':', pos + 1);
+    if (pos == std::string::npos)
+        throw ParserError("Port is not specified.");
+    return address.substr(0, pos - 1);
+}
+
+std::string parsePort(Config::Type connectionType, const std::string& value)
+{
+    size_t pos = value.find_first_of(':');
+    if (pos == std::string::npos)
+        throw ParserError("Connection type is not specified. Should be either 'unix' or 'tcp'.");
+    if (connectionType == Config::UNIX)
+        return "";
+    std::string address(value.substr(pos + 1));
+    pos = address.find_first_of(':', pos + 1);
+    if (pos == std::string::npos)
+        throw ParserError("Port is not specified.");
+    return address.substr(pos + 1);
+}
+
+Config::Type parseConnectionType(const std::string& address)
+{
+    size_t pos = address.find_first_of(':');
+    if (pos == std::string::npos)
+        throw ParserError("Connection type is not specified. Should be either 'unix' or 'tcp'.");
+    std::string type = ToLower(address.substr(0, pos));
+    if (type == "unix")
+        return Config::UNIX;
+    else if (type == "tcp")
+        return Config::TCP;
+    throw ParserError("Invalid connection type. Should be either 'unix' or 'tcp', got '" + type + "'");
+}
+
+Config::Section parseSection(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return Config::Section(parseVector("match", params[i].sections),
+                                   parseVector("modify", params[i].sections),
+                                   parseVector("reply", params[i].sections),
+                                   parseReturnCode("no_match", params[i].sections),
+                                   parseAuthorize("authorize", params[i].sections));
+    return Config::Section();
+}
+
+uid_t parseUID(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toUID(params[i].value);
+    return -1;
+}
+
+gid_t parseGID(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toGID(params[i].value);
+    return -1;
+}
+
+mode_t parseMode(const std::string& paramName, const std::vector<PARAM_VALUE>& params)
+{
+    for (size_t i = 0; i < params.size(); ++i)
+        if (params[i].param == paramName)
+            return toMode(params[i].value);
+    return -1;
+}
+
+} // namespace anonymous
+
+bool Config::Authorize::check(const USER& user, const Config::Pairs& radiusData) const
+{
+    if (!m_auth)
+        return false; // No flag - no authorization.
+
+    if (m_cond.empty())
+        return true; // Empty parameter - always authorize.
+
+    Config::Pairs::const_iterator it = m_cond.begin();
+    for (; it != m_cond.end(); ++it)
+    {
+        const Config::Pairs::const_iterator pos = radiusData.find(it->first);
+        if (pos == radiusData.end())
+            return false; // No required Radius parameter.
+        if (user.GetParamValue(it->second) != pos->second)
+            return false; // No match with the user.
+    }
+
+    return true;
+}
+
+Config::Config(const MODULE_SETTINGS& settings)
+    : autz(parseSection("autz", settings.moduleParams)),
+      auth(parseSection("auth", settings.moduleParams)),
+      postauth(parseSection("postauth", settings.moduleParams)),
+      preacct(parseSection("preacct", settings.moduleParams)),
+      acct(parseSection("acct", settings.moduleParams)),
+      verbose(parseBool("verbose", settings.moduleParams)),
+      address(parseString("bind_address", settings.moduleParams)),
+      connectionType(parseConnectionType(address)),
+      bindAddress(parseAddress(connectionType, address)),
+      portStr(parsePort(connectionType, address)),
+      port(toPort(portStr)),
+      key(parseString("key", settings.moduleParams)),
+      sockUID(parseUID("sock_owner", settings.moduleParams)),
+      sockGID(parseGID("sock_group", settings.moduleParams)),
+      sockMode(parseMode("sock_mode", settings.moduleParams))
+{
+    size_t count = 0;
+    if (autz.authorize.exists())
+        ++count;
+    if (auth.authorize.exists())
+        ++count;
+    if (postauth.authorize.exists())
+        ++count;
+    if (preacct.authorize.exists())
+        ++count;
+    if (acct.authorize.exists())
+        ++count;
+    if (count > 0)
+        throw ParserError("Authorization flag is specified in more than one section.");
+}
diff --git a/projects/stargazer/plugins/other/radius/config.h b/projects/stargazer/plugins/other/radius/config.h
new file mode 100644 (file)
index 0000000..44d5ed8
--- /dev/null
@@ -0,0 +1,107 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#ifndef __STG_RADIUS_CONFIG_H__
+#define __STG_RADIUS_CONFIG_H__
+
+#include "stg/module_settings.h"
+
+#include "stg/os_int.h"
+
+#include <map>
+#include <string>
+
+#include <unistd.h> // uid_t, gid_t
+#include <sys/stat.h> // mode_t
+
+class USER;
+
+namespace STG
+{
+
+struct Config
+{
+    typedef std::map<std::string, std::string> Pairs;
+    typedef std::pair<std::string, std::string> Pair;
+    enum Type { UNIX, TCP };
+    enum ReturnCode
+    {
+        REJECT,   // Reject the request immediately.
+        FAIL,     // Module failed.
+        OK,       // Module is OK, continue.
+        HANDLED,  // The request is handled, no further handling.
+        INVALID,  // The request is invalud.
+        USERLOCK, // Reject the request, user is locked.
+        NOTFOUND, // User not found.
+        NOOP,     // Module performed no action.
+        UPDATED   // Module sends some updates.
+    };
+
+    class Authorize
+    {
+        public:
+            Authorize() : m_auth(false) {}
+            Authorize(const Pairs& cond) : m_auth(true), m_cond(cond) {}
+
+            bool check(const USER& user, const Pairs& radiusData) const;
+            bool exists() const { return m_auth; }
+        private:
+            bool m_auth;
+            Pairs m_cond;
+    };
+
+    struct Section
+    {
+        Section() {}
+        Section(const Pairs& ma, const Pairs& mo, const Pairs& re, ReturnCode code, const Authorize& auth)
+            : match(ma), modify(mo), reply(re), returnCode(code), authorize(auth) {}
+        Pairs match;
+        Pairs modify;
+        Pairs reply;
+        ReturnCode returnCode;
+        Authorize authorize;
+    };
+
+    Config() {}
+    Config(const MODULE_SETTINGS& settings);
+
+    Section autz;
+    Section auth;
+    Section postauth;
+    Section preacct;
+    Section acct;
+
+    bool verbose;
+
+    std::string address;
+    Type connectionType;
+    std::string bindAddress;
+    std::string portStr;
+    uint16_t port;
+    std::string key;
+
+    uid_t sockUID;
+    gid_t sockGID;
+    mode_t sockMode;
+};
+
+} // namespace STG
+
+#endif
diff --git a/projects/stargazer/plugins/other/radius/conn.cpp b/projects/stargazer/plugins/other/radius/conn.cpp
new file mode 100644 (file)
index 0000000..a209409
--- /dev/null
@@ -0,0 +1,527 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#include "conn.h"
+
+#include "radius.h"
+#include "config.h"
+
+#include "stg/json_parser.h"
+#include "stg/json_generator.h"
+#include "stg/users.h"
+#include "stg/user.h"
+#include "stg/logger.h"
+#include "stg/common.h"
+
+#include <yajl/yajl_gen.h>
+
+#include <map>
+#include <stdexcept>
+#include <cstring>
+#include <cerrno>
+
+#include <unistd.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+
+using STG::Conn;
+using STG::Config;
+using STG::JSON::Parser;
+using STG::JSON::PairsParser;
+using STG::JSON::EnumParser;
+using STG::JSON::NodeParser;
+using STG::JSON::Gen;
+using STG::JSON::MapGen;
+using STG::JSON::StringGen;
+
+namespace
+{
+
+double CONN_TIMEOUT = 60;
+double PING_TIMEOUT = 10;
+
+enum Packet
+{
+    PING,
+    PONG,
+    DATA
+};
+
+enum Stage
+{
+    AUTHORIZE,
+    AUTHENTICATE,
+    PREACCT,
+    ACCOUNTING,
+    POSTAUTH
+};
+
+std::map<std::string, Packet> packetCodes;
+std::map<std::string, Stage> stageCodes;
+
+class PacketParser : public EnumParser<Packet>
+{
+    public:
+        PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
+            : EnumParser(next, packet, packetStr, packetCodes)
+        {
+            if (!packetCodes.empty())
+                return;
+            packetCodes["ping"] = PING;
+            packetCodes["pong"] = PONG;
+            packetCodes["data"] = DATA;
+        }
+};
+
+class StageParser : public EnumParser<Stage>
+{
+    public:
+        StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
+            : EnumParser(next, stage, stageStr, stageCodes)
+        {
+            if (!stageCodes.empty())
+                return;
+            stageCodes["authorize"] = AUTHORIZE;
+            stageCodes["authenticate"] = AUTHENTICATE;
+            stageCodes["preacct"] = PREACCT;
+            stageCodes["accounting"] = ACCOUNTING;
+            stageCodes["postauth"] = POSTAUTH;
+        }
+};
+
+class TopParser : public NodeParser
+{
+    public:
+        typedef void (*Callback) (void* /*data*/);
+        TopParser(Callback callback, void* data)
+            : m_packetParser(this, m_packet, m_packetStr),
+              m_stageParser(this, m_stage, m_stageStr),
+              m_pairsParser(this, m_data),
+              m_callback(callback), m_callbackData(data)
+        {}
+
+        virtual NodeParser* parseStartMap() { return this; }
+        virtual NodeParser* parseMapKey(const std::string& value)
+        {
+            std::string key = ToLower(value);
+
+            if (key == "packet")
+                return &m_packetParser;
+            else if (key == "stage")
+                return &m_stageParser;
+            else if (key == "pairs")
+                return &m_pairsParser;
+
+            return this;
+        }
+        virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; }
+
+        const std::string& packetStr() const { return m_packetStr; }
+        Packet packet() const { return m_packet; }
+        const std::string& stageStr() const { return m_stageStr; }
+        Stage stage() const { return m_stage; }
+        const Config::Pairs& data() const { return m_data; }
+
+    private:
+        std::string m_packetStr;
+        Packet m_packet;
+        std::string m_stageStr;
+        Stage m_stage;
+        Config::Pairs m_data;
+
+        PacketParser m_packetParser;
+        StageParser m_stageParser;
+        PairsParser m_pairsParser;
+
+        Callback m_callback;
+        void* m_callbackData;
+};
+
+class ProtoParser : public Parser
+{
+    public:
+        ProtoParser(TopParser::Callback callback, void* data)
+            : Parser( &m_topParser ),
+              m_topParser(callback, data)
+        {}
+
+        const std::string& packetStr() const { return m_topParser.packetStr(); }
+        Packet packet() const { return m_topParser.packet(); }
+        const std::string& stageStr() const { return m_topParser.stageStr(); }
+        Stage stage() const { return m_topParser.stage(); }
+        const Config::Pairs& data() const { return m_topParser.data(); }
+
+    private:
+        TopParser m_topParser;
+};
+
+class PacketGen : public Gen
+{
+    public:
+        PacketGen(const std::string& type)
+            : m_type(type)
+        {
+            m_gen.add("packet", m_type);
+        }
+        void run(yajl_gen_t* handle) const
+        {
+            m_gen.run(handle);
+        }
+        PacketGen& add(const std::string& key, const std::string& value)
+        {
+            m_gen.add(key, new StringGen(value));
+            return *this;
+        }
+        PacketGen& add(const std::string& key, MapGen* map)
+        {
+            m_gen.add(key, map);
+            return *this;
+        }
+        PacketGen& add(const std::string& key, MapGen& map)
+        {
+            m_gen.add(key, map);
+            return *this;
+        }
+    private:
+        MapGen m_gen;
+        StringGen m_type;
+};
+
+std::string toString(Config::ReturnCode code)
+{
+    switch (code)
+    {
+        case Config::REJECT:   return "reject";
+        case Config::FAIL:     return "fail";
+        case Config::OK:       return "ok";
+        case Config::HANDLED:  return "handled";
+        case Config::INVALID:  return "invalid";
+        case Config::USERLOCK: return "userlock";
+        case Config::NOTFOUND: return "notfound";
+        case Config::NOOP:     return "noop";
+        case Config::UPDATED:  return "noop";
+    }
+    return "reject";
+}
+
+}
+
+class Conn::Impl
+{
+    public:
+        Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
+        ~Impl();
+
+        int sock() const { return m_sock; }
+
+        bool read();
+        bool tick();
+
+        bool isOk() const { return m_ok; }
+
+    private:
+        USERS& m_users;
+        PLUGIN_LOGGER& m_logger;
+        RADIUS& m_plugin;
+        const Config& m_config;
+        int m_sock;
+        std::string m_remote;
+        bool m_ok;
+        time_t m_lastPing;
+        time_t m_lastActivity;
+        ProtoParser m_parser;
+        std::set<std::string> m_authorized;
+
+        template <typename T>
+        const T& stageMember(T Config::Section::* member) const
+        {
+            switch (m_parser.stage())
+            {
+                case AUTHORIZE: return m_config.autz.*member;
+                case AUTHENTICATE: return m_config.auth.*member;
+                case POSTAUTH: return m_config.postauth.*member;
+                case PREACCT: return m_config.preacct.*member;
+                case ACCOUNTING: return m_config.acct.*member;
+            }
+            throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
+        }
+
+        const Config::Pairs& match() const { return stageMember(&Config::Section::match); }
+        const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
+        const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
+        Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
+        const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); }
+
+        static void process(void* data);
+        void processPing();
+        void processPong();
+        void processData();
+        bool answer(const USER& user);
+        bool answerNo();
+        bool sendPing();
+        bool sendPong();
+
+        static bool write(void* data, const char* buf, size_t size);
+};
+
+Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
+    : m_impl(new Impl(users, logger, plugin, config, fd, remote))
+{
+}
+
+Conn::~Conn()
+{
+}
+
+int Conn::sock() const
+{
+    return m_impl->sock();
+}
+
+bool Conn::read()
+{
+    return m_impl->read();
+}
+
+bool Conn::tick()
+{
+    return m_impl->tick();
+}
+
+bool Conn::isOk() const
+{
+    return m_impl->isOk();
+}
+
+Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
+    : m_users(users),
+      m_logger(logger),
+      m_plugin(plugin),
+      m_config(config),
+      m_sock(fd),
+      m_remote(remote),
+      m_ok(true),
+      m_lastPing(time(NULL)),
+      m_lastActivity(m_lastPing),
+      m_parser(&Conn::Impl::process, this)
+{
+}
+
+Conn::Impl::~Impl()
+{
+    close(m_sock);
+
+    std::set<std::string>::const_iterator it = m_authorized.begin();
+    for (; it != m_authorized.end(); ++it)
+        m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + ".");
+}
+
+bool Conn::Impl::read()
+{
+    static std::vector<char> buffer(1024);
+    ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
+    if (res < 0)
+    {
+        m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
+        m_ok = false;
+        return false;
+    }
+    printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
+    m_lastActivity = time(NULL);
+    if (res == 0)
+    {
+        m_ok = false;
+        return true;
+    }
+    return m_parser.append(buffer.data(), res);
+}
+
+bool Conn::Impl::tick()
+{
+    time_t now = time(NULL);
+    if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
+    {
+        int delta = difftime(now, m_lastActivity);
+        printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta);
+        m_logger("Connection to " + m_remote + " timed out.");
+        m_ok = false;
+        return false;
+    }
+    if (difftime(now, m_lastPing) > PING_TIMEOUT)
+    {
+        int delta = difftime(now, m_lastPing);
+        printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta);
+        sendPing();
+    }
+    return true;
+}
+
+void Conn::Impl::process(void* data)
+{
+    Impl& impl = *static_cast<Impl*>(data);
+    try
+    {
+        switch (impl.m_parser.packet())
+        {
+            case PING:
+                impl.processPing();
+                return;
+            case PONG:
+                impl.processPong();
+                return;
+            case DATA:
+                impl.processData();
+                return;
+        }
+    }
+    catch (const std::exception& ex)
+    {
+        printfd(__FILE__, "Processing error. %s", ex.what());
+        impl.m_logger("Processing error. %s", ex.what());
+    }
+    printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
+    impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
+}
+
+void Conn::Impl::processPing()
+{
+    printfd(__FILE__, "Got ping. Sending pong...\n");
+    sendPong();
+}
+
+void Conn::Impl::processPong()
+{
+    printfd(__FILE__, "Got pong.\n");
+    m_lastActivity = time(NULL);
+}
+
+void Conn::Impl::processData()
+{
+    printfd(__FILE__, "Got data.\n");
+    int handle = m_users.OpenSearch();
+
+    USER_PTR user = NULL;
+    bool matched = false;
+    while (m_users.SearchNext(handle, &user) == 0)
+    {
+        if (user == NULL)
+            continue;
+
+        matched = true;
+        for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
+        {
+            Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
+            if (pos == m_parser.data().end())
+            {
+                matched = false;
+                break;
+            }
+            if (user->GetParamValue(it->second) != pos->second)
+            {
+                matched = false;
+                break;
+            }
+        }
+        if (!matched)
+            continue;
+        answer(*user);
+        if (authorize().check(*user, m_parser.data()))
+        {
+            m_plugin.authorize(*user);
+            m_authorized.insert(user->GetLogin());
+        }
+        break;
+    }
+
+    if (!matched)
+        answerNo();
+
+    m_users.CloseSearch(handle);
+}
+
+bool Conn::Impl::answer(const USER& user)
+{
+    printfd(__FILE__, "Got match. Sending answer...\n");
+    MapGen replyData;
+    for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
+        replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
+
+    MapGen modifyData;
+    for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
+        modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
+
+    PacketGen gen("data");
+    gen.add("result", "ok")
+       .add("reply", replyData)
+       .add("modify", modifyData);
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::answerNo()
+{
+    printfd(__FILE__, "No match. Sending answer...\n");
+    PacketGen gen("data");
+    gen.add("result", "no");
+    gen.add("return_code", toString(returnCode()));
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::sendPing()
+{
+    PacketGen gen("ping");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::sendPong()
+{
+    PacketGen gen("pong");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+bool Conn::Impl::write(void* data, const char* buf, size_t size)
+{
+    std::string json(buf, size);
+    printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str());
+    Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
+    while (size > 0)
+    {
+        ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL);
+        if (res < 0)
+        {
+            conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
+            conn.m_ok = false;
+            return false;
+        }
+        size -= res;
+    }
+    return true;
+}
diff --git a/projects/stargazer/plugins/other/radius/conn.h b/projects/stargazer/plugins/other/radius/conn.h
new file mode 100644 (file)
index 0000000..96e7430
--- /dev/null
@@ -0,0 +1,58 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#ifndef __STG_SGCONFIG_CONN_H__
+#define __STG_SGCONFIG_CONN_H__
+
+#include <boost/scoped_ptr.hpp>
+
+#include <string>
+
+class USER;
+class USERS;
+class PLUGIN_LOGGER;
+class RADIUS;
+
+namespace STG
+{
+
+struct Config;
+
+class Conn
+{
+    public:
+        Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
+        ~Conn();
+
+        int sock() const;
+
+        bool read();
+        bool tick();
+
+        bool isOk() const;
+
+    private:
+        class Impl;
+        boost::scoped_ptr<Impl> m_impl;
+};
+
+}
+
+#endif
index 8e52cdb48ba203f8763082f67ef58f711de00672..45a9d0e19446f80054df4c9bdf972b0ba79ef73f 100644 (file)
  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
  */
 
-/*
- *  This file contains a realization of radius data access plugin for Stargazer
- *
- *  $Revision: 1.14 $
- *  $Date: 2009/12/13 14:17:13 $
- *
- */
+#include "radius.h"
+
+#include "stg/store.h"
+#include "stg/users.h"
+#include "stg/plugin_creator.h"
+#include "stg/common.h"
 
+#include <algorithm>
+#include <stdexcept>
 #include <csignal>
 #include <cerrno>
-#include <algorithm>
+#include <cstring>
 
-#include "stg/store.h"
-#include "stg/common.h"
-#include "stg/user_conf.h"
-#include "stg/user_property.h"
-#include "stg/plugin_creator.h"
-#include "radius.h"
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/un.h> // UNIX
+#include <netinet/in.h> // IP
+#include <netinet/tcp.h> // TCP
+#include <netdb.h>
 
-extern volatile time_t stgTime;
+using STG::Config;
+using STG::Conn;
 
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
 namespace
 {
-PLUGIN_CREATOR<RADIUS> radc;
 
-void InitEncrypt(BLOWFISH_CTX * ctx, const std::string & password);
-void Decrypt(BLOWFISH_CTX * ctx, void * dst, const void * src, unsigned long len8);
-void Encrypt(BLOWFISH_CTX * ctx, void * dst, const void * src, unsigned long len8);
-}
-extern "C" PLUGIN * GetPlugin();
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
-PLUGIN * GetPlugin()
-{
-return radc.GetPlugin();
+PLUGIN_CREATOR<RADIUS> creator;
+
 }
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
-int RAD_SETTINGS::ParseServices(const std::vector<std::string> & str, std::list<std::string> * lst)
+
+extern "C" PLUGIN * GetPlugin()
 {
-std::copy(str.begin(), str.end(), std::back_inserter(*lst));
-std::list<std::string>::iterator it(std::find(lst->begin(),
-                               lst->end(),
-                               "empty"));
-if (it != lst->end())
-    *it = "";
-
-return 0;
+    return creator.GetPlugin();
 }
-//-----------------------------------------------------------------------------
-int RAD_SETTINGS::ParseSettings(const MODULE_SETTINGS & s)
-{
-int p;
-PARAM_VALUE pv;
-std::vector<PARAM_VALUE>::const_iterator pvi;
-///////////////////////////
-pv.param = "Port";
-pvi = std::find(s.moduleParams.begin(), s.moduleParams.end(), pv);
-if (pvi == s.moduleParams.end())
-    {
-    errorStr = "Parameter \'Port\' not found.";
-    printfd(__FILE__, "Parameter 'Port' not found\n");
-    return -1;
-    }
-if (ParseIntInRange(pvi->value[0], 2, 65535, &p))
-    {
-    errorStr = "Cannot parse parameter \'Port\': " + errorStr;
-    printfd(__FILE__, "Cannot parse parameter 'Port'\n");
-    return -1;
-    }
-port = static_cast<uint16_t>(p);
-///////////////////////////
-pv.param = "Password";
-pvi = std::find(s.moduleParams.begin(), s.moduleParams.end(), pv);
-if (pvi == s.moduleParams.end())
-    {
-    errorStr = "Parameter \'Password\' not found.";
-    printfd(__FILE__, "Parameter 'Password' not found\n");
-    return -1;
-    }
-password = pvi->value[0];
-///////////////////////////
-pv.param = "AuthServices";
-pvi = std::find(s.moduleParams.begin(), s.moduleParams.end(), pv);
-if (pvi != s.moduleParams.end())
-    {
-    ParseServices(pvi->value, &authServices);
-    }
-///////////////////////////
-pv.param = "AcctServices";
-pvi = std::find(s.moduleParams.begin(), s.moduleParams.end(), pv);
-if (pvi != s.moduleParams.end())
-    {
-    ParseServices(pvi->value, &acctServices);
-    }
 
-return 0;
-}
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
 RADIUS::RADIUS()
-    : ctx(),
-      errorStr(),
-      radSettings(),
-      settings(),
-      authServices(),
-      acctServices(),
-      sessions(),
-      nonstop(false),
-      isRunning(false),
-      users(NULL),
-      stgSettings(NULL),
-      store(NULL),
-      thread(),
-      mutex(),
-      sock(-1),
-      packet(),
-      logger(GetPluginLogger(GetStgLogger(), "radius"))
+    : m_config(),
+      m_running(false),
+      m_stopped(true),
+      m_users(NULL),
+      m_store(NULL),
+      m_listenSocket(0),
+      m_logger(GetPluginLogger(GetStgLogger(), "radius"))
 {
-InitEncrypt(&ctx, "");
 }
-//-----------------------------------------------------------------------------
+
 int RADIUS::ParseSettings()
 {
-int ret = radSettings.ParseSettings(settings);
-if (ret)
-    errorStr = radSettings.GetStrError();
-return ret;
-}
-//-----------------------------------------------------------------------------
-int RADIUS::PrepareNet()
-{
-sock = socket(AF_INET, SOCK_DGRAM, 0);
-
-if (sock < 0)
-    {
-    errorStr = "Cannot create socket.";
-    logger("Cannot create a socket: %s", strerror(errno));
-    printfd(__FILE__, "Cannot create socket\n");
-    return -1;
-    }
-
-struct sockaddr_in inAddr;
-inAddr.sin_family = AF_INET;
-inAddr.sin_port = htons(radSettings.GetPort());
-inAddr.sin_addr.s_addr = inet_addr("0.0.0.0");
-
-if (bind(sock, (struct sockaddr*)&inAddr, sizeof(inAddr)) < 0)
-    {
-    errorStr = "RADIUS: Bind failed.";
-    logger("Cannot bind the socket: %s", strerror(errno));
-    printfd(__FILE__, "Cannot bind socket\n");
-    return -1;
+    try {
+        m_config = STG::Config(m_settings);
+        return reconnect() ? 0 : -1;
+    } catch (const std::runtime_error& ex) {
+        m_logger("Failed to parse settings. %s", ex.what());
+        return -1;
     }
-
-return 0;
 }
-//-----------------------------------------------------------------------------
-int RADIUS::FinalizeNet()
-{
-close(sock);
-return 0;
-}
-//-----------------------------------------------------------------------------
+
 int RADIUS::Start()
 {
-std::string password(radSettings.GetPassword());
-
-authServices = radSettings.GetAuthServices();
-acctServices = radSettings.GetAcctServices();
-
-InitEncrypt(&ctx, password);
+    if (m_running)
+        return 0;
 
-nonstop = true;
+    int res = pthread_create(&m_thread, NULL, run, this);
+    if (res == 0)
+        return 0;
 
-if (PrepareNet())
-    {
+    m_error = strerror(res);
+    m_logger("Failed to create thread: '" + m_error + "'.");
     return -1;
-    }
-
-if (!isRunning)
-    {
-    if (pthread_create(&thread, NULL, Run, this))
-        {
-        errorStr = "Cannot create thread.";
-       logger("Cannot create thread.");
-        printfd(__FILE__, "Cannot create thread\n");
-        return -1;
-        }
-    }
-
-errorStr = "";
-return 0;
 }
-//-----------------------------------------------------------------------------
+
 int RADIUS::Stop()
 {
-if (!IsRunning())
-    return 0;
+    std::set<std::string>::const_iterator it = m_logins.begin();
+    for (; it != m_logins.end(); ++it)
+        m_users->Unauthorize(*it, this, "Stopping RADIUS plugin.");
+    m_logins.clear();
 
-nonstop = false;
+    if (m_stopped)
+        return 0;
 
-std::map<std::string, RAD_SESSION>::iterator it;
-for (it = sessions.begin(); it != sessions.end(); ++it)
-    {
-    USER_PTR ui;
-    if (users->FindByName(it->second.userName, &ui))
-        {
-        users->Unauthorize(ui->GetLogin(), this);
-        }
-    }
-sessions.erase(sessions.begin(), sessions.end());
+    m_running = false;
 
-FinalizeNet();
-
-if (isRunning)
-    {
-    //5 seconds to thread stops itself
-    for (int i = 0; i < 25 && isRunning; i++)
-        {
+    for (size_t i = 0; i < 25 && !m_stopped; i++) {
         struct timespec ts = {0, 200000000};
         nanosleep(&ts, NULL);
-        }
     }
 
-if (isRunning)
-    return -1;
-
-return 0;
-}
-//-----------------------------------------------------------------------------
-void * RADIUS::Run(void * d)
-{
-sigset_t signalSet;
-sigfillset(&signalSet);
-pthread_sigmask(SIG_BLOCK, &signalSet, NULL);
-
-RADIUS * rad = static_cast<RADIUS *>(d);
-RAD_PACKET packet;
-
-rad->isRunning = true;
-
-while (rad->nonstop)
-    {
-    if (!WaitPackets(rad->sock))
-        {
-        continue;
-        }
-    struct sockaddr_in outerAddr;
-    if (rad->RecvData(&packet, &outerAddr))
-        {
-        printfd(__FILE__, "RADIUS::Run Error on RecvData\n");
-        }
-    else
-        {
-        if (rad->ProcessData(&packet))
-            {
-            packet.packetType = RAD_REJECT_PACKET;
-            }
-        rad->Send(packet, &outerAddr);
-        }
+    if (m_stopped) {
+        pthread_join(m_thread, NULL);
+        return 0;
     }
 
-rad->isRunning = false;
-
-return NULL;
-}
-//-----------------------------------------------------------------------------
-int RADIUS::RecvData(RAD_PACKET * packet, struct sockaddr_in * outerAddr)
-{
-    int8_t buf[RAD_MAX_PACKET_LEN];
-    socklen_t outerAddrLen = sizeof(struct sockaddr_in);
-    ssize_t dataLen = recvfrom(sock, buf, RAD_MAX_PACKET_LEN, 0, reinterpret_cast<struct sockaddr *>(outerAddr), &outerAddrLen);
-    if (dataLen < 0)
-       {
-       logger("recvfrom error: %s", strerror(errno));
-       return -1;
-       }
-    if (dataLen == 0)
-       return -1;
-
-    Decrypt(&ctx, (char *)packet, (const char *)buf, dataLen / 8);
-
-    if (strncmp((char *)packet->magic, RAD_ID, RAD_MAGIC_LEN))
-        {
-        printfd(__FILE__, "RADIUS::RecvData Error magic. Wanted: '%s', got: '%s'\n", RAD_ID, packet->magic);
-        return -1;
-        }
+    if (m_config.connectionType == Config::UNIX)
+        unlink(m_config.bindAddress.c_str());
 
-    return 0;
-}
-//-----------------------------------------------------------------------------
-ssize_t RADIUS::Send(const RAD_PACKET & packet, struct sockaddr_in * outerAddr)
-{
-size_t len = sizeof(RAD_PACKET);
-char buf[1032];
-
-Encrypt(&ctx, buf, (char *)&packet, len / 8);
-ssize_t res = sendto(sock, buf, len, 0, reinterpret_cast<struct sockaddr *>(outerAddr), sizeof(struct sockaddr_in));
-if (res < 0)
-    logger("sendto error: %s", strerror(errno));
-return res;
-}
-//-----------------------------------------------------------------------------
-int RADIUS::ProcessData(RAD_PACKET * packet)
-{
-if (strncmp((const char *)packet->protoVer, "01", 2))
-    {
-    printfd(__FILE__, "RADIUS::ProcessData packet.protoVer incorrect\n");
+    m_error = "Failed to stop thread.";
+    m_logger(m_error);
     return -1;
-    }
-switch (packet->packetType)
-    {
-    case RAD_AUTZ_PACKET:
-        return ProcessAutzPacket(packet);
-    case RAD_AUTH_PACKET:
-        return ProcessAuthPacket(packet);
-    case RAD_POST_AUTH_PACKET:
-        return ProcessPostAuthPacket(packet);
-    case RAD_ACCT_START_PACKET:
-        return ProcessAcctStartPacket(packet);
-    case RAD_ACCT_STOP_PACKET:
-        return ProcessAcctStopPacket(packet);
-    case RAD_ACCT_UPDATE_PACKET:
-        return ProcessAcctUpdatePacket(packet);
-    case RAD_ACCT_OTHER_PACKET:
-        return ProcessAcctOtherPacket(packet);
-    default:
-        printfd(__FILE__, "RADIUS::ProcessData Unsupported packet type: %d\n", packet->packetType);
-        return -1;
-    };
 }
 //-----------------------------------------------------------------------------
-int RADIUS::ProcessAutzPacket(RAD_PACKET * packet)
+void* RADIUS::run(void* d)
 {
-USER_CONF conf;
-
-if (!IsAllowedService((char *)packet->service))
-    {
-    printfd(__FILE__, "RADIUS::ProcessAutzPacket service '%s' is not allowed to authorize\n", packet->service);
-    packet->packetType = RAD_REJECT_PACKET;
-    return 0;
-    }
-
-if (store->RestoreUserConf(&conf, (char *)packet->login))
-    {
-    packet->packetType = RAD_REJECT_PACKET;
-    printfd(__FILE__, "RADIUS::ProcessAutzPacket cannot restore conf for user '%s'\n", packet->login);
-    return 0;
-    }
+    sigset_t signalSet;
+    sigfillset(&signalSet);
+    pthread_sigmask(SIG_BLOCK, &signalSet, NULL);
 
-// At this point service can be authorized at least
-// So we send a plain-text password
+    static_cast<RADIUS *>(d)->runImpl();
 
-packet->packetType = RAD_ACCEPT_PACKET;
-strncpy((char *)packet->password, conf.password.c_str(), RAD_PASSWORD_LEN);
-
-return 0;
+    return NULL;
 }
-//-----------------------------------------------------------------------------
-int RADIUS::ProcessAuthPacket(RAD_PACKET * packet)
-{
-USER_PTR ui;
 
-if (!CanAcctService((char *)packet->service))
-    {
-
-    // There are no sense to check for allowed service
-    // It has allready checked at previous stage (authorization)
-
-    printfd(__FILE__, "RADIUS::ProcessAuthPacket service '%s' neednot stargazer authentication\n", (char *)packet->service);
-    packet->packetType = RAD_ACCEPT_PACKET;
-    return 0;
-    }
-
-// At this point we have an accountable service
-// All other services got a password if allowed or rejected
-
-if (!FindUser(&ui, (char *)packet->login))
+bool RADIUS::reconnect()
+{
+    if (!m_conns.empty())
     {
-    packet->packetType = RAD_REJECT_PACKET;
-    printfd(__FILE__, "RADIUS::ProcessAuthPacket user '%s' not found\n", (char *)packet->login);
-    return 0;
+        std::deque<STG::Conn *>::const_iterator it;
+        for (it = m_conns.begin(); it != m_conns.end(); ++it)
+            delete(*it);
+        m_conns.clear();
     }
-
-if (ui->IsInetable())
+    if (m_listenSocket != 0)
     {
-    packet->packetType = RAD_ACCEPT_PACKET;
+        shutdown(m_listenSocket, SHUT_RDWR);
+        close(m_listenSocket);
     }
-else
+    if (m_config.connectionType == Config::UNIX)
+        m_listenSocket = createUNIX();
+    else
+        m_listenSocket = createTCP();
+    if (m_listenSocket == 0)
+        return false;
+    if (listen(m_listenSocket, 100) == -1)
     {
-    packet->packetType = RAD_REJECT_PACKET;
+        m_error = std::string("Error starting to listen socket: ") + strerror(errno);
+        m_logger(m_error);
+        return false;
     }
-
-packet->packetType = RAD_ACCEPT_PACKET;
-return 0;
+    return true;
 }
-//-----------------------------------------------------------------------------
-int RADIUS::ProcessPostAuthPacket(RAD_PACKET * packet)
-{
-USER_PTR ui;
 
-if (!CanAcctService((char *)packet->service))
+int RADIUS::createUNIX() const
+{
+    int fd = socket(AF_UNIX, SOCK_STREAM, 0);
+    if (fd == -1)
     {
-
-    // There are no sense to check for allowed service
-    // It has allready checked at previous stage (authorization)
-
-    packet->packetType = RAD_ACCEPT_PACKET;
-    return 0;
+        m_error = std::string("Error creating UNIX socket: ") + strerror(errno);
+        m_logger(m_error);
+        return 0;
     }
-
-if (!FindUser(&ui, (char *)packet->login))
+    struct sockaddr_un addr;
+    memset(&addr, 0, sizeof(addr));
+    addr.sun_family = AF_UNIX;
+    strncpy(addr.sun_path, m_config.bindAddress.c_str(), m_config.bindAddress.length());
+    unlink(m_config.bindAddress.c_str());
+    if (bind(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
     {
-    packet->packetType = RAD_REJECT_PACKET;
-    printfd(__FILE__, "RADIUS::ProcessPostAuthPacket user '%s' not found\n", (char *)packet->login);
-    return 0;
+        shutdown(fd, SHUT_RDWR);
+        close(fd);
+        m_error = std::string("Error binding UNIX socket: ") + strerror(errno);
+        m_logger(m_error);
+        return 0;
     }
-
-// I think that only Framed-User services has sense to be accountable
-// So we have to supply a Framed-IP
-
-USER_IPS ips = ui->GetProperty().ips;
-packet->packetType = RAD_ACCEPT_PACKET;
-
-// Additional checking for Framed-User service
-
-if (!strncmp((char *)packet->service, "Framed-User", RAD_SERVICE_LEN))
-    packet->ip = ips[0].ip;
-else
-    packet->ip = 0;
-
-return 0;
+    chown(m_config.bindAddress.c_str(), m_config.sockUID, m_config.sockGID);
+    if (m_config.sockMode != static_cast<mode_t>(-1))
+        chmod(m_config.bindAddress.c_str(), m_config.sockMode);
+    return fd;
 }
-//-----------------------------------------------------------------------------
-int RADIUS::ProcessAcctStartPacket(RAD_PACKET * packet)
-{
-USER_PTR ui;
 
-if (!FindUser(&ui, (char *)packet->login))
+int RADIUS::createTCP() const
+{
+    addrinfo hints;
+    memset(&hints, 0, sizeof(addrinfo));
+
+    hints.ai_family = AF_INET;       /* Allow IPv4 */
+    hints.ai_socktype = SOCK_STREAM; /* Stream socket */
+    hints.ai_flags = AI_PASSIVE;     /* For wildcard IP address */
+    hints.ai_protocol = 0;           /* Any protocol */
+    hints.ai_canonname = NULL;
+    hints.ai_addr = NULL;
+    hints.ai_next = NULL;
+
+    addrinfo* ais = NULL;
+    int res = getaddrinfo(m_config.bindAddress.c_str(), m_config.portStr.c_str(), &hints, &ais);
+    if (res != 0)
     {
-    packet->packetType = RAD_REJECT_PACKET;
-    printfd(__FILE__, "RADIUS::ProcessAcctStartPacket user '%s' not found\n", (char *)packet->login);
-    return 0;
+        m_error = "Error resolving address '" + m_config.bindAddress + "': " + gai_strerror(res);
+        m_logger(m_error);
+        return 0;
     }
 
-// At this point we have to unauthorize user only if it is an accountable service
-
-if (CanAcctService((char *)packet->service))
+    for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
     {
-    if (sessions.find((const char *)packet->sessid) != sessions.end())
+        int fd = socket(AF_INET, SOCK_STREAM, 0);
+        if (fd == -1)
         {
-        printfd(__FILE__, "RADIUS::ProcessAcctStartPacket session already started!\n");
-        packet->packetType = RAD_REJECT_PACKET;
-        return -1;
+            m_error = std::string("Error creating TCP socket: ") + strerror(errno);
+            m_logger(m_error);
+            freeaddrinfo(ais);
+            return 0;
         }
-    USER_IPS ips = ui->GetProperty().ips;
-    if (!users->Authorize(ui->GetLogin(), ips[0].ip, 0xffFFffFF, this))
+        if (bind(fd, ai->ai_addr, ai->ai_addrlen) == -1)
         {
-        printfd(__FILE__, "RADIUS::ProcessAcctStartPacket cannot authorize user '%s'\n", packet->login);
-        packet->packetType = RAD_REJECT_PACKET;
-        return -1;
+            shutdown(fd, SHUT_RDWR);
+            close(fd);
+            m_error = std::string("Error binding TCP socket: ") + strerror(errno);
+            m_logger(m_error);
+            continue;
         }
-    sessions[(const char *)packet->sessid].userName = (const char *)packet->login;
-    sessions[(const char *)packet->sessid].serviceType = (const char *)packet->service;
-    for_each(sessions.begin(), sessions.end(), SPrinter());
-    }
-else
-    {
-    printfd(__FILE__, "RADIUS::ProcessAcctStartPacket service '%s' can not be accounted\n", (char *)packet->service);
+        freeaddrinfo(ais);
+        return fd;
     }
 
-packet->packetType = RAD_ACCEPT_PACKET;
-return 0;
+    m_error = "Failed to resolve '" + m_config.bindAddress;
+    m_logger(m_error);
+
+    freeaddrinfo(ais);
+    return 0;
 }
-//-----------------------------------------------------------------------------
-int RADIUS::ProcessAcctStopPacket(RAD_PACKET * packet)
+
+void RADIUS::runImpl()
 {
-std::map<std::string, RAD_SESSION>::iterator sid;
+    m_running = true;
+    m_stopped = false;
 
-if ((sid = sessions.find((const char *)packet->sessid)) == sessions.end())
-    {
-    printfd(__FILE__, "RADIUS::ProcessAcctStopPacket session had not started yet\n");
-    packet->packetType = RAD_REJECT_PACKET;
-    return -1;
-    }
+    while (m_running) {
+        fd_set fds;
 
-USER_PTR ui;
+        buildFDSet(fds);
 
-if (!FindUser(&ui, sid->second.userName))
-    {
-    packet->packetType = RAD_REJECT_PACKET;
-    printfd(__FILE__, "RADIUS::ProcessPostAuthPacket user '%s' not found\n", sid->second.userName.c_str());
-    return 0;
-    }
+        struct timeval tv;
+        tv.tv_sec = 0;
+        tv.tv_usec = 500000;
+
+        int res = select(maxFD() + 1, &fds, NULL, NULL, &tv);
+        if (res < 0)
+        {
+            if (errno == EINTR)
+                continue;
+            m_error = std::string("'select' is failed: '") + strerror(errno) + "'.";
+            m_logger(m_error);
+            break;
+        }
 
-sessions.erase(sid);
+        if (!m_running)
+            break;
 
-users->Unauthorize(ui->GetLogin(), this);
+        if (res > 0)
+            handleEvents(fds);
+        else
+        {
+            for (std::deque<Conn*>::iterator it = m_conns.begin(); it != m_conns.end(); ++it)
+                (*it)->tick();
+        }
 
-packet->packetType = RAD_ACCEPT_PACKET;
-return 0;
+        cleanupConns();
+    }
+
+    m_stopped = true;
 }
-//-----------------------------------------------------------------------------
-int RADIUS::ProcessAcctUpdatePacket(RAD_PACKET * packet)
+
+int RADIUS::maxFD() const
 {
-// Fake. May be use it later
-packet->packetType = RAD_ACCEPT_PACKET;
-return 0;
+    int maxFD = m_listenSocket;
+    std::deque<STG::Conn *>::const_iterator it;
+    for (it = m_conns.begin(); it != m_conns.end(); ++it)
+        if (maxFD < (*it)->sock())
+            maxFD = (*it)->sock();
+    return maxFD;
 }
-//-----------------------------------------------------------------------------
-int RADIUS::ProcessAcctOtherPacket(RAD_PACKET * packet)
+
+void RADIUS::buildFDSet(fd_set & fds) const
 {
-// Fake. May be use it later
-packet->packetType = RAD_ACCEPT_PACKET;
-return 0;
+    FD_ZERO(&fds);
+    FD_SET(m_listenSocket, &fds);
+    std::deque<STG::Conn *>::const_iterator it;
+    for (it = m_conns.begin(); it != m_conns.end(); ++it)
+        FD_SET((*it)->sock(), &fds);
 }
-//-----------------------------------------------------------------------------
-bool RADIUS::FindUser(USER_PTR * ui, const std::string & login) const
+
+void RADIUS::cleanupConns()
 {
-if (users->FindByName(login, ui))
+    std::deque<STG::Conn *>::iterator pos;
+    for (pos = m_conns.begin(); pos != m_conns.end(); ++pos)
+        if (!(*pos)->isOk()) {
+            delete *pos;
+            *pos = NULL;
+        }
+
+    pos = std::remove(m_conns.begin(), m_conns.end(), static_cast<STG::Conn *>(NULL));
+    m_conns.erase(pos, m_conns.end());
+}
+
+void RADIUS::handleEvents(const fd_set & fds)
+{
+    if (FD_ISSET(m_listenSocket, &fds))
+        acceptConnection();
+    else
     {
-    return false;
+        std::deque<STG::Conn *>::iterator it;
+        for (it = m_conns.begin(); it != m_conns.end(); ++it)
+            if (FD_ISSET((*it)->sock(), &fds))
+                (*it)->read();
+            else
+                (*it)->tick();
     }
-return true;
 }
-//-----------------------------------------------------------------------------
-bool RADIUS::CanAuthService(const std::string & svc) const
+
+void RADIUS::acceptConnection()
 {
-return find(authServices.begin(), authServices.end(), svc) != authServices.end();
+    if (m_config.connectionType == Config::UNIX)
+        acceptUNIX();
+    else
+        acceptTCP();
 }
-//-----------------------------------------------------------------------------
-bool RADIUS::CanAcctService(const std::string & svc) const
+
+void RADIUS::acceptUNIX()
 {
-return find(acctServices.begin(), acctServices.end(), svc) != acctServices.end();
+    struct sockaddr_un addr;
+    memset(&addr, 0, sizeof(addr));
+    socklen_t size = sizeof(addr);
+    int res = accept(m_listenSocket, reinterpret_cast<sockaddr*>(&addr), &size);
+    if (res == -1)
+    {
+        m_error = std::string("Failed to accept UNIX connection: ") + strerror(errno);
+        m_logger(m_error);
+        return;
+    }
+    printfd(__FILE__, "New UNIX connection: '%s'\n", addr.sun_path);
+    m_conns.push_back(new Conn(*m_users, m_logger, *this, m_config, res, addr.sun_path));
 }
-//-----------------------------------------------------------------------------
-bool RADIUS::IsAllowedService(const std::string & svc) const
+
+void RADIUS::acceptTCP()
 {
-return CanAuthService(svc) || CanAcctService(svc);
+    struct sockaddr_in addr;
+    memset(&addr, 0, sizeof(addr));
+    socklen_t size = sizeof(addr);
+    int res = accept(m_listenSocket, reinterpret_cast<sockaddr*>(&addr), &size);
+    if (res == -1)
+    {
+        m_error = std::string("Failed to accept TCP connection: ") + strerror(errno);
+        m_logger(m_error);
+        return;
+    }
+    std::string remote = inet_ntostring(addr.sin_addr.s_addr) + ":" + x2str(ntohs(addr.sin_port));
+    printfd(__FILE__, "New TCP connection: '%s'\n", remote.c_str());
+    m_conns.push_back(new Conn(*m_users, m_logger, *this, m_config, res, remote));
 }
-//-----------------------------------------------------------------------------
-namespace
-{
 
-inline
-void InitEncrypt(BLOWFISH_CTX * ctx, const std::string & password)
+void RADIUS::authorize(const USER& user)
 {
-unsigned char keyL[RAD_PASSWORD_LEN];  // Пароль для шифровки
-memset(keyL, 0, RAD_PASSWORD_LEN);
-strncpy((char *)keyL, password.c_str(), RAD_PASSWORD_LEN);
-Blowfish_Init(ctx, keyL, RAD_PASSWORD_LEN);
+    uint32_t ip = 0;
+    const std::string& login(user.GetLogin());
+    if (!m_users->Authorize(login, ip, 0xffFFffFF, this))
+    {
+        m_error = "Unable to authorize user '" + login + "' with ip " + inet_ntostring(ip) + ".";
+        m_logger(m_error);
+    }
+    else
+        m_logins.insert(login);
 }
-//-----------------------------------------------------------------------------
-inline
-void Encrypt(BLOWFISH_CTX * ctx, void * dst, const void * src, unsigned long len8)
-{
-// len8 - длина в 8-ми байтовых блоках
-if (dst != src)
-    memcpy(dst, src, len8 * 8);
 
-for (size_t i = 0; i < len8; i++)
-    Blowfish_Encrypt(ctx, static_cast<uint32_t *>(dst) + i * 2, static_cast<uint32_t *>(dst) + i * 2 + 1);
-}
-//-----------------------------------------------------------------------------
-inline
-void Decrypt(BLOWFISH_CTX * ctx, void * dst, const void * src, unsigned long len8)
+void RADIUS::unauthorize(const std::string& login, const std::string& reason)
 {
-// len8 - длина в 8-ми байтовых блоках
-if (dst != src)
-    memcpy(dst, src, len8 * 8);
-
-for (size_t i = 0; i < len8; i++)
-    Blowfish_Decrypt(ctx, static_cast<uint32_t *>(dst) + i * 2, static_cast<uint32_t *>(dst) + i * 2 + 1);
+    const std::set<std::string>::const_iterator it = m_logins.find(login);
+    if (it == m_logins.end())
+        return;
+    m_logins.erase(it);
+    m_users->Unauthorize(login, this, reason);
 }
-
-} // namespace anonymous
index 0f1c95fcda66563c592f3238e611b337ed645b54..52da138ec6eeab1a21c25e71e3e2eed31d14740a 100644 (file)
  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
  */
 
-/*
- *  Radius data access plugin for Stargazer
- *
- *  $Revision: 1.10 $
- *  $Date: 2009/12/13 14:17:13 $
- *
- */
-
-#ifndef RADIUS_H
-#define RADIUS_H
-
-#include <pthread.h>
-
-#include <cstring>
-#include <cstdlib>
-#include <string>
-#include <list>
-#include <map>
-#include <vector>
+#ifndef __STG_RADIUS_H__
+#define __STG_RADIUS_H__
 
 #include "stg/os_int.h"
 #include "stg/auth.h"
 #include "stg/module_settings.h"
-#include "stg/notifer.h"
-#include "stg/user_ips.h"
-#include "stg/user.h"
-#include "stg/users.h"
-#include "stg/blowfish.h"
-#include "stg/rad_packets.h"
 #include "stg/logger.h"
 
-extern "C" PLUGIN * GetPlugin();
+#include "config.h"
+#include "conn.h"
 
-#define RAD_DEBUG (1)
+#include <string>
+#include <deque>
+#include <set>
 
-class RADIUS;
-//-----------------------------------------------------------------------------
-class RAD_SETTINGS {
-public:
-    RAD_SETTINGS()
-        : port(0), errorStr(), password(),
-          authServices(), acctServices()
-    {}
-    virtual ~RAD_SETTINGS() {}
-    const std::string & GetStrError() const { return errorStr; }
-    int ParseSettings(const MODULE_SETTINGS & s);
-    uint16_t GetPort() const { return port; }
-    const std::string & GetPassword() const { return password; }
-    const std::list<std::string> & GetAuthServices() const { return authServices; }
-    const std::list<std::string> & GetAcctServices() const { return acctServices; }
+#include <pthread.h>
+#include <unistd.h>
+#include <sys/select.h>
+#include <sys/types.h>
 
-private:
-    int ParseServices(const std::vector<std::string> & str, std::list<std::string> * lst);
+extern "C" PLUGIN * GetPlugin();
 
-    uint16_t port;
-    std::string errorStr;
-    std::string password;
-    std::list<std::string> authServices;
-    std::list<std::string> acctServices;
-};
-//-----------------------------------------------------------------------------
-struct RAD_SESSION {
-    RAD_SESSION() : userName(), serviceType() {}
-    std::string userName;
-    std::string serviceType;
-};
-//-----------------------------------------------------------------------------
-class RADIUS :public AUTH {
+class STORE;
+class USERS;
+
+class RADIUS : public AUTH {
 public:
-                        RADIUS();
-    virtual             ~RADIUS() {}
+    RADIUS();
+    virtual ~RADIUS() {}
 
-    void                SetUsers(USERS * u) { users = u; }
-    void                SetStore(STORE * s) { store = s; }
-    void                SetStgSettings(const SETTINGS *) {}
-    void                SetSettings(const MODULE_SETTINGS & s) { settings = s; }
-    int                 ParseSettings();
+    void SetUsers(USERS* u) { m_users = u; }
+    void SetStore(STORE* s) { m_store = s; }
+    void SetStgSettings(const SETTINGS*) {}
+    void SetSettings(const MODULE_SETTINGS& s) { m_settings = s; }
+    int ParseSettings();
 
-    int                 Start();
-    int                 Stop();
-    int                 Reload(const MODULE_SETTINGS & /*ms*/) { return 0; }
-    bool                IsRunning() { return isRunning; }
+    int Start();
+    int Stop();
+    int Reload(const MODULE_SETTINGS & /*ms*/) { return 0; }
+    bool IsRunning() { return m_running; }
 
-    const std::string & GetStrError() const { return errorStr; }
-    std::string         GetVersion() const { return "RADIUS data access plugin v 0.6"; }
-    uint16_t            GetStartPosition() const { return 30; }
-    uint16_t            GetStopPosition() const { return 30; }
+    const std::string& GetStrError() const { return m_error; }
+    std::string GetVersion() const { return "RADIUS data access plugin v. 2.0"; }
+    uint16_t GetStartPosition() const { return 30; }
+    uint16_t GetStopPosition() const { return 30; }
 
-    int SendMessage(const STG_MSG &, uint32_t) const { return 0; }
+    int SendMessage(const STG_MSG&, uint32_t) const { return 0; }
+
+    void authorize(const USER& user);
+    void unauthorize(const std::string& login, const std::string& reason);
 
 private:
     RADIUS(const RADIUS & rvalue);
     RADIUS & operator=(const RADIUS & rvalue);
 
-    static void *       Run(void *);
-    int                 PrepareNet();
-    int                 FinalizeNet();
-
-    ssize_t             Send(const RAD_PACKET & packet, struct sockaddr_in * outerAddr);
-    int                 RecvData(RAD_PACKET * packet, struct sockaddr_in * outerAddr);
-    int                 ProcessData(RAD_PACKET * packet);
-
-    int                 ProcessAutzPacket(RAD_PACKET * packet);
-    int                 ProcessAuthPacket(RAD_PACKET * packet);
-    int                 ProcessPostAuthPacket(RAD_PACKET * packet);
-    int                 ProcessAcctStartPacket(RAD_PACKET * packet);
-    int                 ProcessAcctStopPacket(RAD_PACKET * packet);
-    int                 ProcessAcctUpdatePacket(RAD_PACKET * packet);
-    int                 ProcessAcctOtherPacket(RAD_PACKET * packet);
-
-    bool                FindUser(USER_PTR * ui, const std::string & login) const;
-    bool                CanAuthService(const std::string & svc) const;
-    bool                CanAcctService(const std::string & svc) const;
-    bool                IsAllowedService(const std::string & svc) const;
-
-    struct SPrinter : public std::unary_function<std::pair<std::string, RAD_SESSION>, void>
-    {
-        void operator()(const std::pair<std::string, RAD_SESSION> & it)
-        {
-            printfd("radius.cpp", "%s - ('%s', '%s')\n", it.first.c_str(), it.second.userName.c_str(), it.second.serviceType.c_str());
-        }
-    };
+    static void* run(void*);
 
-    BLOWFISH_CTX        ctx;
+    bool reconnect();
+    int createUNIX() const;
+    int createTCP() const;
+    void runImpl();
+    int maxFD() const;
+    void buildFDSet(fd_set & fds) const;
+    void cleanupConns();
+    void handleEvents(const fd_set & fds);
+    void acceptConnection();
+    void acceptUNIX();
+    void acceptTCP();
 
-    mutable std::string errorStr;
-    RAD_SETTINGS        radSettings;
-    MODULE_SETTINGS     settings;
-    std::list<std::string> authServices;
-    std::list<std::string> acctServices;
-    std::map<std::string, RAD_SESSION> sessions;
+    mutable std::string m_error;
+    STG::Config m_config;
 
-    bool                nonstop;
-    bool                isRunning;
+    MODULE_SETTINGS m_settings;
 
-    USERS *             users;
-    const SETTINGS *    stgSettings;
-    const STORE *       store;
+    bool m_running;
+    bool m_stopped;
 
-    pthread_t           thread;
-    pthread_mutex_t     mutex;
+    USERS* m_users;
+    const STORE* m_store;
 
-    int                 sock;
+    int m_listenSocket;
+    std::deque<STG::Conn*> m_conns;
+    std::set<std::string> m_logins;
 
-    RAD_PACKET          packet;
+    pthread_t m_thread;
 
-    PLUGIN_LOGGER       logger;
+    PLUGIN_LOGGER m_logger;
 };
-//-----------------------------------------------------------------------------
 
 #endif
index 0cf9501bced5db222ec6413721f9cf70f1b51a6d..c925177ffea8c37089d4adc9ee0f77fa378154dd 100644 (file)
  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
  */
 
-/*
- *    Date: 27.10.2002
- */
-
 /*
  *    Author : Boris Mikhailenko <stg34@stargazer.dp.ua>
  */
 
-/*
-$Revision: 1.45 $
-$Date: 2010/08/19 13:42:30 $
-$Author: faust $
-*/
+#include "settings_impl.h"
 
+#include "stg/logger.h"
+#include "stg/dotconfpp.h"
+#include "stg/common.h"
+
+#include <stdexcept>
 #include <cstring>
 #include <cerrno>
-#include <string>
 
-#include "stg/logger.h"
-#include "stg/dotconfpp.h"
-#include "settings_impl.h"
+namespace
+{
+
+struct Error : public std::runtime_error
+{
+    Error(const std::string& message) : runtime_error(message) {}
+};
+
+std::vector<std::string> toValues(const DOTCONFDocumentNode& node)
+{
+    std::vector<std::string> values;
+
+    size_t i = 0;
+    const char* value = NULL;
+    while ((value = node.getValue(i++)) != NULL)
+        values.push_back(value);
+
+    return values;
+}
+
+std::vector<PARAM_VALUE> toPVS(const DOTCONFDocumentNode& node)
+{
+    std::vector<PARAM_VALUE> pvs;
+
+    const DOTCONFDocumentNode* child = node.getChildNode();
+    while (child != NULL)
+        {
+        if (child->getName() == NULL)
+            continue;
+
+        if (child->getChildNode() == NULL)
+            pvs.push_back(PARAM_VALUE(child->getName(), toValues(*child)));
+        else
+            pvs.push_back(PARAM_VALUE(child->getName(), toValues(*child), toPVS(*child)));
+
+        child = child->getNextNode();
+        }
+
+    return pvs;
+}
+
+unsigned toPeriod(const char* value)
+{
+    if (value == NULL)
+        throw Error("No detail stat period value.");
+
+    std::string period(value);
+    if (period == "1")
+        return dsPeriod_1;
+    else if (period == "1/2")
+        return dsPeriod_1_2;
+    else if (period == "1/4")
+        return dsPeriod_1_4;
+    else if (period == "1/6")
+        return dsPeriod_1_6;
+
+    throw Error("Invalid detail stat period value: '" + period + "'. Should be one of '1', '1/2', '1/4' or '1/6'.");
+}
+
+}
 
 //-----------------------------------------------------------------------------
 SETTINGS_IMPL::SETTINGS_IMPL(const std::string & cd)
@@ -138,45 +191,6 @@ SETTINGS_IMPL & SETTINGS_IMPL::operator=(const SETTINGS_IMPL & rhs)
     return *this;
 }
 //-----------------------------------------------------------------------------
-int SETTINGS_IMPL::ParseModuleSettings(const DOTCONFDocumentNode * node, std::vector<PARAM_VALUE> * params)
-{
-const DOTCONFDocumentNode * childNode;
-PARAM_VALUE pv;
-const char * value;
-
-pv.param = node->getName();
-
-if (node->getValue(1))
-    {
-    strError = "Unexpected value \'" + std::string(node->getValue(1)) + "\'.";
-    return -1;
-    }
-
-value = node->getValue(0);
-
-if (!value)
-    {
-    strError = "Module name expected.";
-    return -1;
-    }
-
-childNode = node->getChildNode();
-while (childNode)
-    {
-    pv.param = childNode->getName();
-    int i = 0;
-    while ((value = childNode->getValue(i++)) != NULL)
-        {
-        pv.value.push_back(value);
-        }
-    params->push_back(pv);
-    pv.value.clear();
-    childNode = childNode->getNextNode();
-    }
-
-return 0;
-}
-//-----------------------------------------------------------------------------
 void SETTINGS_IMPL::ErrorCallback(void * data, const char * buf)
 {
     printfd(__FILE__, "SETTINGS_IMPL::ErrorCallback() - %s\n", buf);
@@ -246,11 +260,15 @@ while (node)
 
     if (strcasecmp(node->getName(), "DetailStatWritePeriod") == 0)
         {
-        if (ParseDetailStatWritePeriod(node->getValue(0)) != 0)
-            {
-            strError = "Incorrect DetailStatWritePeriod value: \'" + std::string(node->getValue(0)) + "\'";
+        try
+        {
+            detailStatWritePeriod = toPeriod(node->getValue(0));
+        }
+        catch (const Error& error)
+        {
+            strError = error.what();
             return -1;
-            }
+        }
         }
 
     if (strcasecmp(node->getName(), "StatWritePeriod") == 0)
@@ -442,8 +460,13 @@ while (node)
             }
         storeModulesCount++;
 
+        if (node->getValue(0) == NULL)
+            {
+            strError = "No module name in the StoreModule section.";
+            return -1;
+            }
         storeModuleSettings.moduleName = node->getValue(0);
-        ParseModuleSettings(node, &storeModuleSettings.moduleParams);
+        storeModuleSettings.moduleParams = toPVS(*node);
         }
 
     if (strcasecmp(node->getName(), "Modules") == 0)
@@ -461,13 +484,14 @@ while (node)
                 child = child->getNextNode();
                 continue;
                 }
-            MODULE_SETTINGS modSettings;
-            modSettings.moduleParams.clear();
-            modSettings.moduleName = child->getValue();
 
-            ParseModuleSettings(child, &modSettings.moduleParams);
+            if (child->getValue(0) == NULL)
+                {
+                strError = "No module name in the Module section.";
+                return -1;
+                }
 
-            modulesSettings.push_back(modSettings);
+            modulesSettings.push_back(MODULE_SETTINGS(child->getValue(0), toPVS(*child)));
 
             child = child->getNextNode();
             }
@@ -484,29 +508,3 @@ while (node)
 return 0;
 }
 //-----------------------------------------------------------------------------
-int SETTINGS_IMPL::ParseDetailStatWritePeriod(const std::string & detailStatPeriodStr)
-{
-if (detailStatPeriodStr == "1")
-    {
-    detailStatWritePeriod = dsPeriod_1;
-    return 0;
-    }
-else if (detailStatPeriodStr == "1/2")
-    {
-    detailStatWritePeriod = dsPeriod_1_2;
-    return 0;
-    }
-else if (detailStatPeriodStr == "1/4")
-    {
-    detailStatWritePeriod = dsPeriod_1_4;
-    return 0;
-    }
-else if (detailStatPeriodStr == "1/6")
-    {
-    detailStatWritePeriod = dsPeriod_1_6;
-    return 0;
-    }
-
-return -1;
-}
-//-----------------------------------------------------------------------------
index 9bcce5b0f30bed2dbb9a259d6fa061202c06b40e..1e6e2b0df6e0cf04b237e35986fdecef16484799 100644 (file)
@@ -1,9 +1,3 @@
- /*
- $Revision: 1.27 $
- $Date: 2010/08/19 13:42:30 $
- $Author: faust $
- */
-
 /*
  *    This program is free software; you can redistribute it and/or modify
  *    it under the terms of the GNU General Public License as published by
  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
  */
 
-/*
- *    Date: 27.10.2002
- */
-
 /*
  *    Author : Boris Mikhailenko <stg34@stargazer.dp.ua>
  */
 
- /*
- $Revision: 1.27 $
- $Date: 2010/08/19 13:42:30 $
- */
-
-
 #ifndef SETTINGS_IMPL_H
 #define SETTINGS_IMPL_H
 
 #include "stg/settings.h"
 #include "stg/common.h"
 #include "stg/module_settings.h"
+#include "stg/ref.h"
 
 #include <string>
 #include <vector>
 
 //-----------------------------------------------------------------------------
 enum DETAIL_STAT_PERIOD {
-dsPeriod_1,
-dsPeriod_1_2,
-dsPeriod_1_4,
-dsPeriod_1_6
+    dsPeriod_1,
+    dsPeriod_1_2,
+    dsPeriod_1_4,
+    dsPeriod_1_6
 };
 //-----------------------------------------------------------------------------
 class STG_LOGGER;
@@ -58,7 +43,7 @@ class DOTCONFDocumentNode;
 class SETTINGS_IMPL : public SETTINGS {
 public:
     explicit SETTINGS_IMPL(const std::string &);
-    SETTINGS_IMPL(const SETTINGS_IMPL &);
+    SETTINGS_IMPL(const SETTINGS_IMPL & rhs);
     virtual ~SETTINGS_IMPL() {}
     SETTINGS_IMPL & operator=(const SETTINGS_IMPL &);
 
@@ -75,7 +60,7 @@ public:
     const std::string & GetRulesFileName() const { return rules; }
     const std::string & GetLogFileName() const { return logFile; }
     const std::string & GetPIDFileName() const { return pidFile; }
-    unsigned            GetDetailStatWritePeriod() const 
+    unsigned            GetDetailStatWritePeriod() const
         { return detailStatWritePeriod; }
     unsigned            GetStatWritePeriod() const { return statWritePeriod * 60; }
     unsigned            GetDayFee() const { return dayFee; }
@@ -104,9 +89,6 @@ public:
 
 private:
 
-    int ParseDetailStatWritePeriod(const std::string & str);
-    int ParseModuleSettings(const DOTCONFDocumentNode * dirNameNode, std::vector<PARAM_VALUE> * params);
-
     static void ErrorCallback(void * data, const char * buf);
 
     std::string strError;
@@ -115,7 +97,7 @@ private:
     std::string modulesPath;
     std::vector<std::string> dirName;
     std::string confDir;
-    std::string        scriptsDir;
+    std::string scriptsDir;
     std::string rules;
     std::string logFile;
     std::string pidFile;
@@ -142,7 +124,7 @@ private:
 
     std::vector<MODULE_SETTINGS> modulesSettings;
     MODULE_SETTINGS storeModuleSettings;
-    STG_LOGGER & logger;
+    STG::RefWrapper<STG_LOGGER> logger;
 };
 //-----------------------------------------------------------------------------
 
index 1f662c25c3357985ec83d49a11fbb5129b7a4539..932e270948b62d3b78fed5186e6889e53dd304b2 100644 (file)
@@ -7,9 +7,11 @@ include ../../Makefile.conf
 LIB_NAME = stgcommon
 
 SRCS = common.cpp \
-       strptime.cpp
+       strptime.cpp \
+       blockio.cpp
 
-INCS = common.h
+INCS = common.h \
+       blockio.h
 
 LIBS += $(LIBICONV)
 
diff --git a/stglibs/common.lib/blockio.cpp b/stglibs/common.lib/blockio.cpp
new file mode 100644 (file)
index 0000000..04fd1d8
--- /dev/null
@@ -0,0 +1,102 @@
+#include "stg/blockio.h"
+
+namespace
+{
+
+void* adjust(void* base, size_t shift)
+{
+    char* ptr = static_cast<char*>(base);
+    return ptr + shift;
+}
+
+} // namspace anonymous
+
+using STG::BlockReader;
+using STG::BlockWriter;
+
+BlockReader::BlockReader(const IOVec& ioVec)
+    : m_dest(ioVec),
+      m_remainder(0)
+{
+    for (size_t i = 0; i < m_dest.size(); ++i)
+        m_remainder += m_dest[i].iov_len;
+}
+
+bool BlockReader::read(int socket)
+{
+    if (m_remainder == 0)
+        return true;
+
+    size_t offset = m_dest.size() - 1;
+    size_t toRead = m_remainder;
+    while (offset > 0) {
+        if (toRead < m_dest[offset].iov_len)
+            break;
+        toRead -= m_dest[offset].iov_len;
+        --offset;
+    }
+
+    IOVec dest(m_dest.size() - offset);
+    for (size_t i = 0; i < dest.size(); ++i) {
+        if (i == 0) {
+            dest[0].iov_len = toRead;
+            dest[0].iov_base = adjust(m_dest[offset].iov_base, m_dest[offset].iov_len - toRead);
+        } else {
+            dest[i] = m_dest[offset + i];
+        }
+    }
+
+    ssize_t res = readv(socket, dest.data(), dest.size());
+    if (res < 0)
+        return false;
+    if (res == 0)
+        return m_remainder == 0;
+    if (res < static_cast<ssize_t>(m_remainder))
+        m_remainder -= res;
+    else
+        m_remainder = 0;
+    return true;
+}
+
+BlockWriter::BlockWriter(const IOVec& ioVec)
+    : m_source(ioVec),
+      m_remainder(0)
+{
+    for (size_t i = 0; i < m_source.size(); ++i)
+        m_remainder += m_source[i].iov_len;
+}
+
+bool BlockWriter::write(int socket)
+{
+    if (m_remainder == 0)
+        return true;
+
+    size_t offset = m_source.size() - 1;
+    size_t toWrite = m_remainder;
+    while (offset > 0) {
+        if (toWrite < m_source[offset].iov_len)
+            break;
+        toWrite -= m_source[offset].iov_len;
+        --offset;
+    }
+
+    IOVec source(m_source.size() - offset);
+    for (size_t i = 0; i < source.size(); ++i) {
+        if (i == 0) {
+            source[0].iov_len = toWrite;
+            source[0].iov_base = adjust(m_source[offset].iov_base, m_source[offset].iov_len - toWrite);
+        } else {
+            source[i] = m_source[offset + i];
+        }
+    }
+    ssize_t res = writev(socket, source.data(), source.size());
+    if (res < 0)
+        return false;
+    if (res == 0)
+        return m_remainder == 0;
+    if (res < static_cast<ssize_t>(m_remainder))
+        m_remainder -= res;
+    else
+        m_remainder = 0;
+    return true;
+}
index fc7c35ce86c5b19eebc1a569c3120ae2e1fc2ce2..7bf27397f38e419a2904f2311ce6b29f64968a40 100644 (file)
@@ -32,7 +32,8 @@
 // Like FreeBSD4
 #include <sys/types.h>
 #include <sys/time.h>
-#include <unistd.h>
+#include <pwd.h>
+#include <grp.h>
 
 #include <sys/select.h>
 
@@ -1111,3 +1112,37 @@ std::string ToPrintable(const std::string & src)
 
     return dest;
 }
+
+uid_t str2uid(const std::string& name)
+{
+    const passwd* res = getpwnam(name.c_str());
+    if (res == NULL)
+        return -1;
+    return res->pw_uid;
+}
+
+gid_t str2gid(const std::string& name)
+{
+    const group* res = getgrnam(name.c_str());
+    if (res == NULL)
+        return -1;
+    return res->gr_gid;
+}
+
+mode_t str2mode(const std::string& name)
+{
+    if (name.length() < 3 || name.length() > 4)
+        return -1;
+
+    if (name.length() == 4 && name[0] != '0')
+        return -1;
+
+    mode_t res = 0;
+    for (size_t i = 0; i < name.length(); ++i)
+    {
+        if (name[i] > '7' || name[i] < '0')
+            return -1;
+        res = (res << 3) + (name[i] - '0');
+    }
+    return res;
+}
diff --git a/stglibs/common.lib/include/stg/blockio.h b/stglibs/common.lib/include/stg/blockio.h
new file mode 100644 (file)
index 0000000..3879e39
--- /dev/null
@@ -0,0 +1,43 @@
+#ifndef __STG_STGLIBS_BLOCK_IO_H__
+#define __STG_STGLIBS_BLOCK_IO_H__
+
+#include <vector>
+
+#include <sys/uio.h>
+
+namespace STG
+{
+
+typedef std::vector<iovec> IOVec;
+
+class BlockReader
+{
+    public:
+        BlockReader(const IOVec& ioVec);
+
+        bool read(int socket);
+        bool done() const { return m_remainder == 0; }
+        size_t remainder() const { return m_remainder; }
+
+    private:
+        IOVec m_dest;
+        size_t m_remainder;
+};
+
+class BlockWriter
+{
+    public:
+        BlockWriter(const IOVec& ioVec);
+
+        bool write(int socket);
+        bool done() const { return m_remainder == 0; }
+        size_t remainder() const { return m_remainder; }
+
+    private:
+        IOVec m_source;
+        size_t m_remainder;
+};
+
+} // namespace STG
+
+#endif
index 8e82d2a84e17b099ef63a8b0ad288cc3bdc66c06..d404e013a824332976699956b036392f701d579b 100644 (file)
 #ifndef common_h
 #define common_h
 
-#ifdef __BORLANDC__
-#include <time.h>
-#else
-#include <ctime>
-#include <climits> // NAME_MAX
-#endif
+#include "stg/os_int.h"
+#include "stg/const.h"
+
 #include <string>
 #include <sstream>
+#include <ctime>
+#include <climits> // NAME_MAX
 
-#include "stg/os_int.h"
-#include "stg/const.h"
+#include <unistd.h> // uid_t, gid_t
+#include <sys/stat.h> // mode_t
 
 #define STAT_TIME_3         (1)
 #define STAT_TIME_2         (2)
@@ -303,4 +302,8 @@ const std::string & unsigned2str(varT x, std::string & s)
 char * stg_strptime(const char *, const char *, struct tm *);
 time_t stg_timegm(struct tm *);
 
+uid_t str2uid(const std::string& name);
+gid_t str2gid(const std::string& name);
+mode_t str2mode(const std::string& mode);
+
 #endif
diff --git a/stglibs/json.lib/Makefile b/stglibs/json.lib/Makefile
new file mode 100644 (file)
index 0000000..947fce4
--- /dev/null
@@ -0,0 +1,18 @@
+###############################################################################
+# $Id: Makefile,v 1.9 2010/08/18 07:47:03 faust Exp $
+###############################################################################
+
+LIB_NAME = stgjson
+
+STGLIBS = -lstgcommon
+LIBS =
+
+SRCS =  parser.cpp \
+        generator.cpp
+
+INCS = json_parser.h \
+       json_generator.h
+
+LIB_INCS = -I ../common.lib/include
+
+include ../Makefile.in
diff --git a/stglibs/json.lib/generator.cpp b/stglibs/json.lib/generator.cpp
new file mode 100644 (file)
index 0000000..d18ef04
--- /dev/null
@@ -0,0 +1,82 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#include "stg/json_generator.h"
+
+#include <yajl/yajl_gen.h>
+
+using STG::JSON::NullGen;
+using STG::JSON::BoolGen;
+using STG::JSON::StringGen;
+using STG::JSON::NumberGen;
+using STG::JSON::MapGen;
+using STG::JSON::ArrayGen;
+using STG::JSON::Callback;
+
+namespace
+{
+
+void genString(yajl_gen_t* handle, const std::string& value)
+{
+    yajl_gen_string(handle, reinterpret_cast<const unsigned char*>(value.c_str()), value.length());
+}
+
+}
+
+void NullGen::run(yajl_gen_t* handle) const { yajl_gen_null(handle); }
+void BoolGen::run(yajl_gen_t* handle) const { yajl_gen_bool(handle, m_value); }
+void StringGen::run(yajl_gen_t* handle) const { genString(handle, m_value); }
+void NumberGen::run(yajl_gen_t* handle) const { yajl_gen_number(handle, m_value.c_str(), m_value.length()); }
+
+void MapGen::run(yajl_gen_t* handle) const
+{
+    yajl_gen_map_open(handle);
+    for (Value::const_iterator it = m_value.begin(); it != m_value.end(); ++it)
+    {
+        genString(handle, it->first);
+        it->second.first->run(handle);
+    }
+    yajl_gen_map_close(handle);
+}
+
+void ArrayGen::run(yajl_gen_t* handle) const
+{
+    yajl_gen_array_open(handle);
+    for (Value::const_iterator it = m_value.begin(); it != m_value.end(); ++it)
+        it->first->run(handle);
+    yajl_gen_array_close(handle);
+}
+
+bool STG::JSON::generate(Gen& gen, Callback callback, void* data)
+{
+    yajl_gen handle = yajl_gen_alloc(NULL);
+
+    gen.run(handle);
+
+    const unsigned char* buf = NULL;
+    size_t size = 0;
+    yajl_gen_get_buf(handle, &buf, &size);
+
+    bool res = callback(data, reinterpret_cast<const char*>(buf), size);
+
+    yajl_gen_free(handle);
+
+    return res;
+}
diff --git a/stglibs/json.lib/include/stg/json_generator.h b/stglibs/json.lib/include/stg/json_generator.h
new file mode 100644 (file)
index 0000000..4f1523f
--- /dev/null
@@ -0,0 +1,122 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#ifndef __STG_STGLIBS_JSON_GENERATOR_H__
+#define __STG_STGLIBS_JSON_GENERATOR_H__
+
+#include <string>
+#include <map>
+#include <vector>
+#include <utility>
+
+#include <boost/scoped_ptr.hpp>
+
+struct yajl_gen_t;
+
+namespace STG
+{
+namespace JSON
+{
+
+struct Gen
+{
+    virtual ~Gen() {}
+    virtual void run(yajl_gen_t* handle) const = 0;
+};
+
+struct NullGen : public Gen
+{
+    virtual void run(yajl_gen_t* handle) const;
+};
+
+class BoolGen : public Gen
+{
+    public:
+        explicit BoolGen(bool value) : m_value(value) {}
+        virtual void run(yajl_gen_t* handle) const;
+    private:
+        bool m_value;
+};
+
+class StringGen : public Gen
+{
+    public:
+        explicit StringGen(const std::string& value) : m_value(value) {}
+        virtual void run(yajl_gen_t* handle) const;
+    private:
+        std::string m_value;
+};
+
+class NumberGen : public Gen
+{
+    public:
+        explicit NumberGen(const std::string& value) : m_value(value) {}
+        template <typename T>
+        explicit NumberGen(const T& value) : m_value(x2str(value)) {}
+        virtual void run(yajl_gen_t* handle) const;
+    private:
+        std::string m_value;
+};
+
+class MapGen : public Gen
+{
+    public:
+        MapGen() {}
+        virtual ~MapGen()
+        {
+            for (Value::iterator it = m_value.begin(); it != m_value.end(); ++it)
+                if (it->second.second)
+                    delete it->second.first;
+        }
+        MapGen& add(const std::string& key, Gen* value) { m_value[key] = std::make_pair(value, true); return *this; }
+        MapGen& add(const std::string& key, Gen& value) { m_value[key] = std::make_pair(&value, false); return *this; }
+        virtual void run(yajl_gen_t* handle) const;
+    private:
+        typedef std::pair<Gen*, bool> SmartGen;
+        typedef std::map<std::string, SmartGen> Value;
+        Value m_value;
+};
+
+class ArrayGen : public Gen
+{
+    public:
+        ArrayGen() {}
+        virtual ~ArrayGen()
+        {
+            for (Value::iterator it = m_value.begin(); it != m_value.end(); ++it)
+                if (it->second)
+                    delete it->first;
+        }
+        void add(Gen* value) { m_value.push_back(std::make_pair(value, true)); }
+        void add(Gen& value) { m_value.push_back(std::make_pair(&value, false)); }
+        virtual void run(yajl_gen_t* handle) const;
+    private:
+        typedef std::pair<Gen*, bool> SmartGen;
+        typedef std::vector<SmartGen> Value;
+        Value m_value;
+};
+
+typedef bool (*Callback)(void* /*data*/, const char* /*buf*/, size_t /*size*/);
+bool generate(Gen& gen, Callback callback, void* data);
+
+}
+}
+
+#endif
diff --git a/stglibs/json.lib/include/stg/json_parser.h b/stglibs/json.lib/include/stg/json_parser.h
new file mode 100644 (file)
index 0000000..a614257
--- /dev/null
@@ -0,0 +1,107 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#ifndef __STG_STGLIBS_JSON_PARSER_H__
+#define __STG_STGLIBS_JSON_PARSER_H__
+
+#include "stg/common.h"
+
+#include <string>
+#include <map>
+
+#include <boost/scoped_ptr.hpp>
+
+namespace STG
+{
+namespace JSON
+{
+
+struct NodeParser
+{
+    virtual ~NodeParser() {}
+
+    virtual NodeParser* parseNull() { return this; }
+    virtual NodeParser* parseBoolean(const bool& /*value*/) { return this; }
+    virtual NodeParser* parseNumber(const std::string& /*value*/) { return this; }
+    virtual NodeParser* parseString(const std::string& /*value*/) { return this; }
+    virtual NodeParser* parseStartMap() { return this; }
+    virtual NodeParser* parseMapKey(const std::string& /*value*/) { return this; }
+    virtual NodeParser* parseEndMap() { return this; }
+    virtual NodeParser* parseStartArray() { return this; }
+    virtual NodeParser* parseEndArray() { return this; }
+};
+
+class Parser
+{
+    public:
+        explicit Parser(NodeParser* topParser);
+        virtual ~Parser();
+
+        bool append(const char* data, size_t size);
+        bool last();
+
+    private:
+        class Impl;
+        boost::scoped_ptr<Impl> m_impl;
+};
+
+template <typename T>
+class EnumParser : public NodeParser
+{
+    public:
+        typedef std::map<std::string, T> Codes;
+        EnumParser(NodeParser* next, T& data, std::string& dataStr, const Codes& codes)
+            : m_next(next), m_data(data), m_dataStr(dataStr), m_codes(codes) {}
+        virtual NodeParser* parseString(const std::string& value)
+        {
+            m_dataStr = value;
+            const typename Codes::const_iterator it = m_codes.find(ToLower(value));
+            if (it != m_codes.end())
+                m_data = it->second;
+            return m_next;
+        }
+    private:
+        NodeParser* m_next;
+        T& m_data;
+        std::string& m_dataStr;
+        const Codes& m_codes;
+};
+
+class PairsParser : public NodeParser
+{
+    public:
+        typedef std::map<std::string, std::string> Pairs;
+
+        PairsParser(NodeParser* next, Pairs& pairs) : m_next(next), m_pairs(pairs) {}
+
+        virtual NodeParser* parseStartMap() { return this; }
+        virtual NodeParser* parseString(const std::string& value) { m_pairs[m_key] = value; return this; }
+        virtual NodeParser* parseMapKey(const std::string& value) { m_key = value; return this; }
+        virtual NodeParser* parseEndMap() { return m_next; }
+    private:
+        NodeParser* m_next;
+        Pairs& m_pairs;
+        std::string m_key;
+};
+
+}
+}
+
+#endif
diff --git a/stglibs/json.lib/parser.cpp b/stglibs/json.lib/parser.cpp
new file mode 100644 (file)
index 0000000..5711504
--- /dev/null
@@ -0,0 +1,123 @@
+/*
+ *    This program is free software; you can redistribute it and/or modify
+ *    it under the terms of the GNU General Public License as published by
+ *    the Free Software Foundation; either version 2 of the License, or
+ *    (at your option) any later version.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *    GNU General Public License for more details.
+ *
+ *    You should have received a copy of the GNU General Public License
+ *    along with this program; if not, write to the Free Software
+ *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+ */
+
+/*
+ *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
+ */
+
+#include "stg/json_parser.h"
+
+#include <yajl/yajl_parse.h>
+
+using STG::JSON::Parser;
+using STG::JSON::NodeParser;
+
+class Parser::Impl
+{
+    public:
+        Impl(NodeParser* topParser);
+        ~Impl()
+        {
+            yajl_free(m_handle);
+        }
+
+        bool append(const char* data, size_t size) { return yajl_parse(m_handle, reinterpret_cast<const unsigned char*>(data), size) == yajl_status_ok; }
+        bool last() { return yajl_complete_parse(m_handle) == yajl_status_ok; }
+
+        static int parseNull(void* ctx)
+        { return runParser(ctx, &NodeParser::parseNull); }
+        static int parseBoolean(void* ctx, int value)
+        { return runParser(ctx, &NodeParser::parseBoolean, value != 0); }
+        static int parseNumber(void* ctx, const char* value, size_t size)
+        { return runParser(ctx, &NodeParser::parseNumber, std::string(value, size)); }
+        static int parseString(void* ctx, const unsigned char* value, size_t size)
+        { return runParser(ctx, &NodeParser::parseString, std::string(reinterpret_cast<const char*>(value), size)); }
+        static int parseStartMap(void* ctx)
+        { return runParser(ctx, &NodeParser::parseStartMap); }
+        static int parseMapKey(void* ctx, const unsigned char* value, size_t size)
+        { return runParser(ctx, &NodeParser::parseMapKey, std::string(reinterpret_cast<const char*>(value), size)); }
+        static int parseEndMap(void* ctx)
+        { return runParser(ctx, &NodeParser::parseEndMap); }
+        static int parseStartArray(void* ctx)
+        { return runParser(ctx, &NodeParser::parseStartArray); }
+        static int parseEndArray(void* ctx)
+        { return runParser(ctx, &NodeParser::parseEndArray); }
+
+    private:
+        yajl_handle m_handle;
+        NodeParser* m_parser;
+
+        static yajl_callbacks callbacks;
+
+        static NodeParser& getParser(void* ctx) { return *static_cast<Impl*>(ctx)->m_parser; }
+        static bool runParser(void* ctx, NodeParser* (NodeParser::*func)())
+        {
+            Impl& p = *static_cast<Impl*>(ctx);
+            NodeParser* next = (p.m_parser->*func)();
+            if (next != NULL)
+                p.m_parser = next;
+            return next != NULL;
+        }
+        template <typename T>
+        static bool runParser(void* ctx, NodeParser* (NodeParser::*func)(const T&), const T& value)
+        {
+            Impl& p = *static_cast<Impl*>(ctx);
+            NodeParser* next = (p.m_parser->*func)(value);
+            if (next != NULL)
+                p.m_parser = next;
+            return next != NULL;
+        }
+};
+
+yajl_callbacks Parser::Impl::callbacks = {
+    Parser::Impl::parseNull,
+    Parser::Impl::parseBoolean,
+    NULL, // parsing of integer is done using parseNumber
+    NULL, // parsing of double is done using parseNumber
+    Parser::Impl::parseNumber,
+    Parser::Impl::parseString,
+    Parser::Impl::parseStartMap,
+    Parser::Impl::parseMapKey,
+    Parser::Impl::parseEndMap,
+    Parser::Impl::parseStartArray,
+    Parser::Impl::parseEndArray
+};
+
+Parser::Impl::Impl(NodeParser* topParser)
+    : m_handle(yajl_alloc(&callbacks, NULL, this)),
+      m_parser(topParser)
+{
+    yajl_config(m_handle, yajl_allow_multiple_values, 1);
+}
+
+Parser::Parser(NodeParser* topParser)
+    : m_impl(new Impl(topParser))
+{
+}
+
+Parser::~Parser()
+{
+}
+
+bool Parser::append(const char* data, size_t size)
+{
+    return m_impl->append(data, size);
+}
+
+bool Parser::last()
+{
+    return m_impl->last();
+}