From: Maxim Mamontov Date: Sun, 18 Oct 2015 14:49:24 +0000 (+0300) Subject: Merge branch 'stg-2.409' into stg-2.409-radius X-Git-Url: https://git.stg.codes/stg.git/commitdiff_plain/dd226cecb0cb5dad554ed3adf7bf60d2eb059a30?hp=5056c7c504cad88d663296454d9f67a96ea2fbdd Merge branch 'stg-2.409' into stg-2.409-radius --- diff --git a/include/stg/locker.h b/include/stg/locker.h index fe1014d7..16b0323f 100644 --- a/include/stg/locker.h +++ b/include/stg/locker.h @@ -34,7 +34,12 @@ class STG_LOCKER { public: - STG_LOCKER(pthread_mutex_t * m) + explicit STG_LOCKER(pthread_mutex_t& m) + : mutex(&m) + { + pthread_mutex_lock(mutex); + } + explicit STG_LOCKER(pthread_mutex_t * m) : mutex(m) { pthread_mutex_lock(mutex); diff --git a/include/stg/module_settings.h b/include/stg/module_settings.h index a1baf33c..5af5eefc 100644 --- a/include/stg/module_settings.h +++ b/include/stg/module_settings.h @@ -14,9 +14,15 @@ //----------------------------------------------------------------------------- struct PARAM_VALUE { - PARAM_VALUE() - : param(), - value() + PARAM_VALUE() {} + PARAM_VALUE(const std::string& p, const std::vector& vs) + : param(p), + value(vs) + {} + PARAM_VALUE(const std::string& p, const std::vector& vs, const std::vector& ss) + : param(p), + value(vs), + sections(ss) {} bool operator==(const PARAM_VALUE & rhs) const { return !strcasecmp(param.c_str(), rhs.param.c_str()); } @@ -26,6 +32,7 @@ struct PARAM_VALUE std::string param; std::vector value; + std::vector sections; }; //----------------------------------------------------------------------------- struct MODULE_SETTINGS @@ -34,9 +41,9 @@ struct MODULE_SETTINGS : moduleName(), moduleParams() {} - MODULE_SETTINGS(const MODULE_SETTINGS & rvalue) - : moduleName(rvalue.moduleName), - moduleParams(rvalue.moduleParams) + MODULE_SETTINGS(const std::string& name, const std::vector& params) + : moduleName(name), + moduleParams(params) {} bool operator==(const MODULE_SETTINGS & rhs) const { return !strcasecmp(moduleName.c_str(), rhs.moduleName.c_str()); } diff --git a/include/stg/ref.h b/include/stg/ref.h new file mode 100644 index 00000000..e1bdb54d --- /dev/null +++ b/include/stg/ref.h @@ -0,0 +1,54 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +namespace STG +{ + +// The implementation is derived from +// http://en.cppreference.com/w/cpp/utility/functional/reference_wrapper +// and +// http://en.cppreference.com/w/cpp/memory/addressof + +template< class T > +T* AddressOf(T& arg) +{ + return reinterpret_cast(&const_cast(reinterpret_cast(arg))); +} + +template +class RefWrapper { + public: + typedef T type; + RefWrapper(T& ref) throw() : m_ptr(AddressOf(ref)) {} + RefWrapper(const RefWrapper& rhs) : m_ptr(rhs.m_ptr) {} + RefWrapper& operator=(const RefWrapper& rhs) throw() { m_ptr = rhs.m_ptr; } + operator T& () const throw() { return *m_ptr; } + T& get() const throw() { return *m_ptr; } + + void operator()() { (*m_ptr)(); } + template + void operator()(A a) { (*m_ptr)(a); } + template + void operator()(A1 a1, A2 a2) { (*m_ptr)(a1, a2); } + private: + T* m_ptr; +}; + +} // namespace STG diff --git a/projects/rlm_stg/Makefile b/projects/rlm_stg/Makefile index 05c43d95..27f486b4 100644 --- a/projects/rlm_stg/Makefile +++ b/projects/rlm_stg/Makefile @@ -8,12 +8,15 @@ LIB_NAME = rlm_stg PROG = $(LIB_NAME).so -SRCS = ./rlm_stg.c \ - ./iface.cpp \ - ./stg_client.cpp +SRCS = rlm_stg.c \ + iface.cpp \ + stg_client.cpp \ + conn.cpp \ + radlog.c STGLIBS = crypto \ - common + common \ + json STGLIBS_INCS = $(addprefix -I ../../stglibs/,$(addsuffix .lib/include,$(STGLIBS))) STGLIBS_LIBS = $(addprefix -L ../../stglibs/,$(addsuffix .lib,$(STGLIBS))) @@ -21,16 +24,19 @@ STGLIBS_LIBS = $(addprefix -L ../../stglibs/,$(addsuffix .lib,$(STGLIBS))) LIBS += $(addprefix -lstg,$(STGLIBS)) $(LIB_THREAD) $(LIBICONV) ifeq ($(OS),linux) -LIBS += -ldl +LIBS += -ldl \ + -lyajl else LIBS += -lintl \ - -lc + -lc \ + -lyajl endif SEARCH_DIRS = -I ../../include OBJS = $(notdir $(patsubst %.cpp, %.o, $(patsubst %.c, %.o, $(SRCS)))) +CFLAGS += -fPIC $(DEFS) $(STGLIBS_INCS) $(SEARCH_DIRS) CXXFLAGS += -fPIC $(DEFS) $(STGLIBS_INCS) $(SEARCH_DIRS) CFLAGS += $(DEFS) $(STGLIBS_INCS) $(SEARCH_DIRS) LDFLAGS += -shared -Wl,-rpath,$(PREFIX)/usr/lib/stg -Wl,-E $(STGLIBS_LIBS) diff --git a/projects/rlm_stg/build b/projects/rlm_stg/build index 179d5261..8d63ae64 100755 --- a/projects/rlm_stg/build +++ b/projects/rlm_stg/build @@ -20,6 +20,7 @@ if [ "$1" = "debug" ] then DEFS="$DEFS -DDEBUG" MAKEOPTS="$MAKEOPTS -j1" + CFLAGS="$CFLAGS -ggdb3 -W -Wall -Wextra" CXXFLAGS="$CXXFLAGS -ggdb3 -W -Wall -Wextra" DEBUG="yes" else @@ -27,6 +28,7 @@ else DEBUG="no" fi +CFLAGS="$CFLAGS -I/usr/local/include" CXXFLAGS="$CXXFLAGS -I/usr/local/include" LDFLAGS="$LDFLAGS -L/usr/local/lib" @@ -61,7 +63,7 @@ printf "######################################################################## printf " Building rlm_stg for $sys $release\n" printf "#############################################################################\n" -STG_LIBS="crypto.lib common.lib" +STG_LIBS="crypto.lib common.lib json.lib" if [ "$OS" = "linux" ] then @@ -150,6 +152,68 @@ else fi rm -f fake +printf "Checking for -lyajl... " +pkg-config --version > /dev/null 2> /dev/null +if [ "$?" = "0" ] +then + pkg-config --atleast-version=2.0.0 yajl + if [ "$?" != "0" ] + then + CHECK_YAJL=no + printf "no\n" + exit; + else + CHECK_YAJL=yes + printf `pkg-config --modversion yajl`"\n" + fi +else + printf "#include \n" > build_check.c + printf "#include \n" >> build_check.c + printf "int main() { printf(\"%%d\", yajl_version()); return 0; }\n" >> build_check.c + $CC $CFLAGS $LDFLAGS build_check.c -lyajl -o fake > /dev/null 2> /dev/null + if [ $? != 0 ] + then + CHECK_YAJL=no + printf "no\n" + exit; + else + YAJL_VERSION=`./fake` + if [ $YAJL_VERSION -ge 20000 ] + then + CHECK_YAJL=yes + printf "${YAJL_VERSION}\n" + else + CHECK_YAJL=no + printf "no. Need at least version 2.0.0, existing version is ${YAJL_VERSION}\n" + exit; + fi + fi + rm -f fake +fi + +printf "Checking for boost::scoped_ptr... " +printf "#include \nint main() { boost::scoped_ptr test(new int(1)); return 0; }\n" > build_check.cpp +$CXX $CXXFLAGS $LDFLAGS build_check.cpp -o fake # > /dev/null 2> /dev/null +if [ $? != 0 ] +then + CHECK_BOOST_SCOPED_PTR=no + printf "no\n" + exit; +else + CHECK_BOOST_SCOPED_PTR=yes + printf "yes\n" +fi +rm -f fake + +rm -f build_check.c +rm -f build_check.cpp + +if [ "$CHECK_YAJL" = "yes" -a "$CHECK_BOOST_SCOPED_PTR" = "yes" ] +then + STG_LIBS="$STG_LIBS + json.lib" +fi + printf "OS=$OS\n" > $CONFFILE printf "STG_TIME=yes\n" >> $CONFFILE printf "DEBUG=$DEBUG\n" >> $CONFFILE @@ -165,6 +229,8 @@ do printf "$lib " >> $CONFFILE done printf "\n" >> $CONFFILE +printf "CC=$CC\n" >> $CONFFILE +printf "CXX=$CXX\n" >> $CONFFILE printf "CXXFLAGS=$CXXFLAGS\n" >> $CONFFILE printf "CFLAGS=$CFLAGS\n" >> $CONFFILE printf "LDFLAGS=$LDFLAGS\n" >> $CONFFILE diff --git a/projects/rlm_stg/conn.cpp b/projects/rlm_stg/conn.cpp new file mode 100644 index 00000000..3589a901 --- /dev/null +++ b/projects/rlm_stg/conn.cpp @@ -0,0 +1,686 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#include "conn.h" + +#include "radlog.h" +#include "stgpair.h" + +#include "stg/json_parser.h" +#include "stg/json_generator.h" +#include "stg/locker.h" + +#include +#include + +#include +#include +#include // UNIX +#include // IP +#include // TCP +#include + +namespace RLM = STG::RLM; + +using RLM::Conn; +using STG::JSON::Parser; +using STG::JSON::PairsParser; +using STG::JSON::EnumParser; +using STG::JSON::NodeParser; +using STG::JSON::Gen; +using STG::JSON::MapGen; +using STG::JSON::StringGen; + +namespace +{ + +double CONN_TIMEOUT = 60; +double PING_TIMEOUT = 10; + +struct ChannelConfig { + struct Error : std::runtime_error { + explicit Error(const std::string& message) : runtime_error(message) {} + }; + + explicit ChannelConfig(std::string address); + + std::string transport; + std::string key; + std::string address; + std::string portStr; + uint16_t port; +}; + +std::string toStage(RLM::REQUEST_TYPE type) +{ + switch (type) + { + case RLM::AUTHORIZE: return "authorize"; + case RLM::AUTHENTICATE: return "authenticate"; + case RLM::POST_AUTH: return "postauth"; + case RLM::PRE_ACCT: return "preacct"; + case RLM::ACCOUNT: return "accounting"; + } + return ""; +} + +enum Packet +{ + PING, + PONG, + DATA +}; + +std::map packetCodes; +std::map resultCodes; +std::map returnCodes; + +class PacketParser : public EnumParser +{ + public: + PacketParser(NodeParser* next, Packet& packet, std::string& packetStr) + : EnumParser(next, packet, packetStr, packetCodes) + { + if (!packetCodes.empty()) + return; + packetCodes["ping"] = PING; + packetCodes["pong"] = PONG; + packetCodes["data"] = DATA; + } +}; + +class ResultParser : public EnumParser +{ + public: + ResultParser(NodeParser* next, bool& result, std::string& resultStr) + : EnumParser(next, result, resultStr, resultCodes) + { + if (!resultCodes.empty()) + return; + resultCodes["no"] = false; + resultCodes["ok"] = true; + } +}; + +class ReturnCodeParser : public EnumParser +{ + public: + ReturnCodeParser(NodeParser* next, int& returnCode, std::string& returnCodeStr) + : EnumParser(next, returnCode, returnCodeStr, returnCodes) + { + if (!returnCodes.empty()) + return; + returnCodes["reject"] = STG_REJECT; + returnCodes["fail"] = STG_FAIL; + returnCodes["ok"] = STG_OK; + returnCodes["handled"] = STG_HANDLED; + returnCodes["invalid"] = STG_INVALID; + returnCodes["userlock"] = STG_USERLOCK; + returnCodes["notfound"] = STG_NOTFOUND; + returnCodes["noop"] = STG_NOOP; + returnCodes["updated"] = STG_UPDATED; + } +}; + +class TopParser : public NodeParser +{ + public: + typedef void (*Callback) (void* /*data*/); + TopParser(Callback callback, void* data) + : m_packet(PING), + m_result(false), + m_returnCode(STG_REJECT), + m_packetParser(this, m_packet, m_packetStr), + m_resultParser(this, m_result, m_resultStr), + m_returnCodeParser(this, m_returnCode, m_returnCodeStr), + m_replyParser(this, m_reply), + m_modifyParser(this, m_modify), + m_callback(callback), m_data(data) + {} + + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseMapKey(const std::string& value) + { + std::string key = ToLower(value); + + if (key == "packet") + return &m_packetParser; + else if (key == "result") + return &m_resultParser; + else if (key == "reply") + return &m_replyParser; + else if (key == "modify") + return &m_modifyParser; + else if (key == "return_code") + return &m_returnCodeParser; + + return this; + } + virtual NodeParser* parseEndMap() { m_callback(m_data); return this; } + + const std::string& packetStr() const { return m_packetStr; } + Packet packet() const { return m_packet; } + const std::string& resultStr() const { return m_resultStr; } + bool result() const { return m_result; } + const std::string& returnCodeStr() const { return m_returnCodeStr; } + int returnCode() const { return m_returnCode; } + const PairsParser::Pairs& reply() const { return m_reply; } + const PairsParser::Pairs& modify() const { return m_modify; } + + private: + std::string m_packetStr; + Packet m_packet; + std::string m_resultStr; + bool m_result; + std::string m_returnCodeStr; + int m_returnCode; + PairsParser::Pairs m_reply; + PairsParser::Pairs m_modify; + + PacketParser m_packetParser; + ResultParser m_resultParser; + ReturnCodeParser m_returnCodeParser; + PairsParser m_replyParser; + PairsParser m_modifyParser; + + Callback m_callback; + void* m_data; +}; + +class ProtoParser : public Parser +{ + public: + ProtoParser(TopParser::Callback callback, void* data) + : Parser( &m_topParser ), + m_topParser(callback, data) + {} + + const std::string& packetStr() const { return m_topParser.packetStr(); } + Packet packet() const { return m_topParser.packet(); } + const std::string& resultStr() const { return m_topParser.resultStr(); } + bool result() const { return m_topParser.result(); } + const std::string& returnCodeStr() const { return m_topParser.returnCodeStr(); } + int returnCode() const { return m_topParser.returnCode(); } + const PairsParser::Pairs& reply() const { return m_topParser.reply(); } + const PairsParser::Pairs& modify() const { return m_topParser.modify(); } + + private: + TopParser m_topParser; +}; + +class PacketGen : public Gen +{ + public: + explicit PacketGen(const std::string& type) + : m_type(type) + { + m_gen.add("packet", m_type); + } + void run(yajl_gen_t* handle) const + { + m_gen.run(handle); + } + PacketGen& add(const std::string& key, const std::string& value) + { + m_gen.add(key, new StringGen(value)); + return *this; + } + PacketGen& add(const std::string& key, MapGen& map) + { + m_gen.add(key, map); + return *this; + } + private: + MapGen m_gen; + StringGen m_type; +}; + +} + +class Conn::Impl +{ +public: + Impl(const std::string& address, Callback callback, void* data); + ~Impl(); + + bool stop(); + bool connected() const { return m_connected; } + + bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs); + +private: + ChannelConfig m_config; + + int m_sock; + + bool m_running; + bool m_stopped; + + time_t m_lastPing; + time_t m_lastActivity; + + pthread_t m_thread; + pthread_mutex_t m_mutex; + + Callback m_callback; + void* m_data; + + ProtoParser m_parser; + + bool m_connected; + + void m_writeHeader(REQUEST_TYPE type, const std::string& userName, const std::string& password); + void m_writePairBlock(const PAIRS& source); + PAIRS m_readPairBlock(); + + static void* run(void* ); + + void runImpl(); + + bool start(); + + int connect(); + int connectTCP(); + int connectUNIX(); + + bool read(); + bool tick(); + + static void process(void* data); + void processPing(); + void processPong(); + void processData(); + bool sendPing(); + bool sendPong(); + + static bool write(void* data, const char* buf, size_t size); +}; + +ChannelConfig::ChannelConfig(std::string addr) +{ + // unix:pass@/var/run/stg.sock + // tcp:secret@192.168.0.1:12345 + // udp:key@isp.com.ua:54321 + + size_t pos = addr.find_first_of(':'); + if (pos == std::string::npos) + throw Error("Missing transport name."); + transport = ToLower(addr.substr(0, pos)); + addr = addr.substr(pos + 1); + if (addr.empty()) + throw Error("Missing address to connect to."); + pos = addr.find_first_of('@'); + if (pos != std::string::npos) { + key = addr.substr(0, pos); + addr = addr.substr(pos + 1); + if (addr.empty()) + throw Error("Missing address to connect to."); + } + if (transport == "unix") + { + address = addr; + return; + } + pos = addr.find_first_of(':'); + if (pos == std::string::npos) + throw Error("Missing port."); + address = addr.substr(0, pos); + portStr = addr.substr(pos + 1); + if (str2x(portStr, port)) + throw Error("Invalid port value."); +} + +Conn::Conn(const std::string& address, Callback callback, void* data) + : m_impl(new Impl(address, callback, data)) +{ +} + +Conn::~Conn() +{ +} + +bool Conn::stop() +{ + return m_impl->stop(); +} + +bool Conn::connected() const +{ + return m_impl->connected(); +} + +bool Conn::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) +{ + return m_impl->request(type, userName, password, pairs); +} + +Conn::Impl::Impl(const std::string& address, Callback callback, void* data) + : m_config(address), + m_sock(connect()), + m_running(false), + m_stopped(true), + m_lastPing(time(NULL)), + m_lastActivity(m_lastPing), + m_callback(callback), + m_data(data), + m_parser(&Conn::Impl::process, this), + m_connected(true) +{ + pthread_mutex_init(&m_mutex, NULL); +} + +Conn::Impl::~Impl() +{ + stop(); + shutdown(m_sock, SHUT_RDWR); + close(m_sock); + pthread_mutex_destroy(&m_mutex); +} + +bool Conn::Impl::stop() +{ + m_connected = false; + + if (m_stopped) + return true; + + m_running = false; + + for (size_t i = 0; i < 25 && !m_stopped; i++) { + struct timespec ts = {0, 200000000}; + nanosleep(&ts, NULL); + } + + if (m_stopped) { + pthread_join(m_thread, NULL); + return true; + } + + return false; +} + +bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) +{ + if (!m_running) + if (!start()) + return false; + MapGen map; + for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it) + map.add(it->first, new StringGen(it->second)); + map.add("Radius-Username", new StringGen(userName)); + map.add("Radius-Userpass", new StringGen(password)); + + PacketGen gen("data"); + gen.add("stage", toStage(type)) + .add("pairs", map); + + STG_LOCKER lock(m_mutex); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +void Conn::Impl::runImpl() +{ + m_running = true; + + while (m_running) { + fd_set fds; + + FD_ZERO(&fds); + FD_SET(m_sock, &fds); + + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 500000; + + int res = select(m_sock + 1, &fds, NULL, NULL, &tv); + if (res < 0) + { + if (errno == EINTR) + continue; + RadLog("'select' is failed: %s", strerror(errno)); + break; + } + + + if (!m_running) + break; + + STG_LOCKER lock(m_mutex); + + if (res > 0) + { + if (FD_ISSET(m_sock, &fds)) + m_running = read(); + } + else + m_running = tick(); + } + + m_connected = false; + m_stopped = true; +} + +bool Conn::Impl::start() +{ + int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this); + if (res != 0) + return false; + return true; +} + +int Conn::Impl::connect() +{ + if (m_config.transport == "tcp") + return connectTCP(); + else if (m_config.transport == "unix") + return connectUNIX(); + throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'."); +} + +int Conn::Impl::connectTCP() +{ + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + + hints.ai_family = AF_INET; /* Allow IPv4 */ + hints.ai_socktype = SOCK_STREAM; /* Stream socket */ + hints.ai_flags = 0; /* For wildcard IP address */ + hints.ai_protocol = 0; /* Any protocol */ + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + + addrinfo* ais = NULL; + int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais); + if (res != 0) + throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res)); + + for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next) + { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) + { + Error error(std::string("Error creating TCP socket: ") + strerror(errno)); + freeaddrinfo(ais); + throw error; + } + if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1) + { + shutdown(fd, SHUT_RDWR); + close(fd); + RadLog("'connect' is failed: %s", strerror(errno)); + continue; + } + freeaddrinfo(ais); + return fd; + } + + freeaddrinfo(ais); + + throw Error("Failed to resolve '" + m_config.address); +}; + +int Conn::Impl::connectUNIX() +{ + int fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd == -1) + throw Error(std::string("Error creating UNIX socket: ") + strerror(errno)); + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length()); + if (::connect(fd, reinterpret_cast(&addr), sizeof(addr)) == -1) + { + Error error(std::string("Error connecting UNIX socket: ") + strerror(errno)); + shutdown(fd, SHUT_RDWR); + close(fd); + throw error; + } + return fd; +} + +bool Conn::Impl::read() +{ + static std::vector buffer(1024); + ssize_t res = ::read(m_sock, buffer.data(), buffer.size()); + if (res < 0) + { + RadLog("Failed to read data: %s", strerror(errno)); + return false; + } + m_lastActivity = time(NULL); + RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str()); + if (res == 0) + { + m_parser.last(); + return false; + } + return m_parser.append(buffer.data(), res); +} + +bool Conn::Impl::tick() +{ + time_t now = time(NULL); + if (difftime(now, m_lastActivity) > CONN_TIMEOUT) + { + int delta = difftime(now, m_lastActivity); + RadLog("Connection timeout: %d sec.", delta); + //m_logger("Connection to " + m_remote + " timed out."); + return false; + } + if (difftime(now, m_lastPing) > PING_TIMEOUT) + { + int delta = difftime(now, m_lastPing); + RadLog("Ping timeout: %d sec. Sending ping...", delta); + sendPing(); + } + return true; +} + +void Conn::Impl::process(void* data) +{ + Impl& impl = *static_cast(data); + switch (impl.m_parser.packet()) + { + case PING: + impl.processPing(); + return; + case PONG: + impl.processPong(); + return; + case DATA: + impl.processData(); + return; + } + RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str()); +} + +void Conn::Impl::processPing() +{ + sendPong(); +} + +void Conn::Impl::processPong() +{ + m_lastActivity = time(NULL); +} + +void Conn::Impl::processData() +{ + RESULT data; + if (m_parser.result()) + { + for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it) + data.reply.push_back(std::make_pair(it->first, it->second)); + for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it) + data.modify.push_back(std::make_pair(it->first, it->second)); + data.returnCode = STG_UPDATED; + } + else + data.returnCode = m_parser.returnCode(); + m_callback(m_data, data); +} + +bool Conn::Impl::sendPing() +{ + PacketGen gen("ping"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPong() +{ + PacketGen gen("pong"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::write(void* data, const char* buf, size_t size) +{ + std::string json(buf, size); + RadLog("Sending JSON: %s", json.c_str()); + Conn::Impl& impl = *static_cast(data); + while (size > 0) + { + ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL); + if (res < 0) + { + impl.m_connected = false; + RadLog("Failed to write data: %s.", strerror(errno)); + return false; + } + size -= res; + } + return true; +} + +void* Conn::Impl::run(void* data) +{ + Impl& impl = *static_cast(data); + impl.runImpl(); + return NULL; +} diff --git a/projects/rlm_stg/conn.h b/projects/rlm_stg/conn.h new file mode 100644 index 00000000..6233b154 --- /dev/null +++ b/projects/rlm_stg/conn.h @@ -0,0 +1,63 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#ifndef __STG_RLM_CONN_H__ +#define __STG_RLM_CONN_H__ + +#include "types.h" + +#include "stg/os_int.h" + +#include + +#include +#include + +namespace STG +{ +namespace RLM +{ + +class Conn +{ + public: + struct Error : std::runtime_error { + explicit Error(const std::string& message) : runtime_error(message) {} + }; + + typedef bool (*Callback)(void* /*data*/, const RESULT& /*result*/); + + Conn(const std::string& address, Callback callback, void* data); + ~Conn(); + + bool stop(); + bool connected() const; + + bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs); + + private: + class Impl; + boost::scoped_ptr m_impl; +}; + +} +} + +#endif diff --git a/projects/rlm_stg/event.h b/projects/rlm_stg/event.h deleted file mode 100644 index 83a72d01..00000000 --- a/projects/rlm_stg/event.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef FR_EVENT_H -#define FR_EVENT_H - -/* - * event.h Simple event queue - * - * Version: $Id: event.h,v 1.1 2010/08/14 04:13:52 faust Exp $ - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA - * - * Copyright 2007 The FreeRADIUS server project - * Copyright 2007 Alan DeKok - */ - -//#include -//RCSIDH(event_h, "$Id: event.h,v 1.1 2010/08/14 04:13:52 faust Exp $") - -typedef struct fr_event_list_t fr_event_list_t; -typedef struct fr_event_t fr_event_t; - -typedef void (*fr_event_callback_t)(void *); -typedef void (*fr_event_status_t)(struct timeval *); -typedef void (*fr_event_fd_handler_t)(fr_event_list_t *el, int sock, void *ctx); - -fr_event_list_t *fr_event_list_create(fr_event_status_t status); -void fr_event_list_free(fr_event_list_t *el); - -int fr_event_list_num_elements(fr_event_list_t *el); - -int fr_event_insert(fr_event_list_t *el, - fr_event_callback_t callback, - void *ctx, struct timeval *when, fr_event_t **ev_p); -int fr_event_delete(fr_event_list_t *el, fr_event_t **ev_p); - -int fr_event_run(fr_event_list_t *el, struct timeval *when); - -int fr_event_now(fr_event_list_t *el, struct timeval *when); - -int fr_event_fd_insert(fr_event_list_t *el, int type, int fd, - fr_event_fd_handler_t handler, void *ctx); -int fr_event_fd_delete(fr_event_list_t *el, int type, int fd); -int fr_event_loop(fr_event_list_t *el); -void fr_event_loop_exit(fr_event_list_t *el, int code); - -#endif /* FR_HASH_H */ diff --git a/projects/rlm_stg/iface.cpp b/projects/rlm_stg/iface.cpp index a74f3259..f97593f4 100644 --- a/projects/rlm_stg/iface.cpp +++ b/projects/rlm_stg/iface.cpp @@ -1,41 +1,121 @@ #include "iface.h" -#include "thriftclient.h" +#include "stg_client.h" +#include "types.h" +#include "radlog.h" -int stgInstantiateImpl(const char * server, uint16_t port, const char * password) +#include +#include + +#include + +namespace RLM = STG::RLM; + +using RLM::Client; +using RLM::PAIRS; +using RLM::RESULT; +using RLM::REQUEST_TYPE; + +namespace { - if (STG_CLIENT_ST::Get().Configure(server, port, password)) - return 1; - return 0; +STG_PAIR* toSTGPairs(const PAIRS& source) +{ + STG_PAIR * pairs = new STG_PAIR[source.size() + 1]; + for (size_t pos = 0; pos < source.size(); ++pos) { + bzero(pairs[pos].key, sizeof(pairs[pos].key)); + bzero(pairs[pos].value, sizeof(pairs[pos].value)); + strncpy(pairs[pos].key, source[pos].first.c_str(), sizeof(pairs[pos].key)); + strncpy(pairs[pos].value, source[pos].second.c_str(), sizeof(pairs[pos].value)); + } + bzero(pairs[source.size()].key, sizeof(pairs[source.size()].key)); + bzero(pairs[source.size()].value, sizeof(pairs[source.size()].value)); + + return pairs; } -const STG_PAIR * stgAuthorizeImpl(const char * userName, const char * serviceType) +PAIRS fromSTGPairs(const STG_PAIR* pairs) { - return STG_CLIENT_ST::Get().Authorize(userName, serviceType); + const STG_PAIR* pair = pairs; + PAIRS res; + + while (!emptyPair(pair)) { + res.push_back(std::pair(pair->key, pair->value)); + ++pair; + } + + return res; +} + +STG_RESULT toResult(const RESULT& source) +{ + STG_RESULT result; + result.modify = toSTGPairs(source.modify); + result.reply = toSTGPairs(source.reply); + result.returnCode = source.returnCode; + return result; } -const STG_PAIR * stgAuthenticateImpl(const char * userName, const char * serviceType) +STG_RESULT emptyResult() { - return STG_CLIENT_ST::Get().Authenticate(userName, serviceType); + STG_RESULT result = {NULL, NULL, STG_REJECT}; + return result; } -const STG_PAIR * stgPostAuthImpl(const char * userName, const char * serviceType) +std::string toString(const char* value) { - return STG_CLIENT_ST::Get().PostAuth(userName, serviceType); + if (value == NULL) + return ""; + else + return value; } -/*const STG_PAIR * stgPreAcctImpl(const char * userName, const char * serviceType) +STG_RESULT stgRequest(REQUEST_TYPE type, const char* userName, const char* password, const STG_PAIR* pairs) { - return STG_CLIENT_ST::Get().PreAcct(userName, serviceType); -}*/ + Client* client = Client::get(); + if (client == NULL) { + RadLog("Client is not configured."); + return emptyResult(); + } + try { + return toResult(client->request(type, toString(userName), toString(password), fromSTGPairs(pairs))); + } catch (const std::runtime_error& ex) { + RadLog("Error: '%s'.", ex.what()); + return emptyResult(); + } +} + +} + +int stgInstantiateImpl(const char* address) +{ + if (Client::configure(toString(address))) + return 1; + + return 0; +} + +STG_RESULT stgAuthorizeImpl(const char* userName, const char* password, const STG_PAIR* pairs) +{ + return stgRequest(RLM::AUTHORIZE, userName, password, pairs); +} + +STG_RESULT stgAuthenticateImpl(const char* userName, const char* password, const STG_PAIR* pairs) +{ + return stgRequest(RLM::AUTHENTICATE, userName, password, pairs); +} + +STG_RESULT stgPostAuthImpl(const char* userName, const char* password, const STG_PAIR* pairs) +{ + return stgRequest(RLM::POST_AUTH, userName, password, pairs); +} -const STG_PAIR * stgAccountingImpl(const char * userName, const char * serviceType, const char * statusType, const char * sessionId) +STG_RESULT stgPreAcctImpl(const char* userName, const char* password, const STG_PAIR* pairs) { - return STG_CLIENT_ST::Get().Account(userName, serviceType, statusType, sessionId); + return stgRequest(RLM::PRE_ACCT, userName, password, pairs); } -void deletePairs(const STG_PAIR * pairs) +STG_RESULT stgAccountingImpl(const char* userName, const char* password, const STG_PAIR* pairs) { - delete[] pairs; + return stgRequest(RLM::ACCOUNT, userName, password, pairs); } diff --git a/projects/rlm_stg/iface.h b/projects/rlm_stg/iface.h index 831c3123..e863e939 100644 --- a/projects/rlm_stg/iface.h +++ b/projects/rlm_stg/iface.h @@ -9,14 +9,12 @@ extern "C" { #endif -int stgInstantiateImpl(const char * server, uint16_t port, const char * password); -const STG_PAIR * stgAuthorizeImpl(const char * userName, const char * serviceType); -const STG_PAIR * stgAuthenticateImpl(const char * userName, const char * serviceType); -const STG_PAIR * stgPostAuthImpl(const char * userName, const char * serviceType); -/*const STG_PAIR * stgPreAcctImpl(const char * userName, const char * serviceType);*/ -const STG_PAIR * stgAccountingImpl(const char * userName, const char * serviceType, const char * statusType, const char * sessionId); - -void deletePairs(const STG_PAIR * pairs); +int stgInstantiateImpl(const char* address); +STG_RESULT stgAuthorizeImpl(const char* userName, const char* password, const STG_PAIR* vps); +STG_RESULT stgAuthenticateImpl(const char* userName, const char* password, const STG_PAIR* vps); +STG_RESULT stgPostAuthImpl(const char* userName, const char* password, const STG_PAIR* vps); +STG_RESULT stgPreAcctImpl(const char* userName, const char* password, const STG_PAIR* vps); +STG_RESULT stgAccountingImpl(const char* userName, const char* password, const STG_PAIR* vps); #ifdef __cplusplus } diff --git a/projects/rlm_stg/radlog.c b/projects/rlm_stg/radlog.c new file mode 100644 index 00000000..523dc1c4 --- /dev/null +++ b/projects/rlm_stg/radlog.c @@ -0,0 +1,23 @@ +#include "radlog.h" + +//#ifndef NDEBUG +//#define NDEBUG +#include +#include +#include +//#undef NDEBUG +//#endif + +#include + +void RadLog(const char* format, ...) +{ + char buf[1024]; + + va_list vl; + va_start(vl, format); + vsnprintf(buf, sizeof(buf), format, vl); + va_end(vl); + + DEBUG("[rlm_stg] *** %s", buf); +} diff --git a/projects/rlm_stg/radlog.h b/projects/rlm_stg/radlog.h new file mode 100644 index 00000000..00a5dcb5 --- /dev/null +++ b/projects/rlm_stg/radlog.h @@ -0,0 +1,14 @@ +#ifndef __STG_RADLOG_H__ +#define __STG_RADLOG_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +void RadLog(const char* format, ...); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/projects/rlm_stg/rlm_stg.c b/projects/rlm_stg/rlm_stg.c index e1caf575..84df4e7b 100644 --- a/projects/rlm_stg/rlm_stg.c +++ b/projects/rlm_stg/rlm_stg.c @@ -26,32 +26,105 @@ * */ -#ifndef NDEBUG -#define NDEBUG +#include "iface.h" +#include "stgpair.h" + #include #include #include -#undef NDEBUG -#endif -#include "stgpair.h" -#include "iface.h" +#include // size_t typedef struct rlm_stg_t { - char * server; - uint16_t port; - char * password; + char* address; } rlm_stg_t; static const CONF_PARSER module_config[] = { - { "server", PW_TYPE_STRING_PTR, offsetof(rlm_stg_t,server), NULL, "localhost"}, - { "port", PW_TYPE_INTEGER, offsetof(rlm_stg_t,port), NULL, "9091" }, - { "password", PW_TYPE_STRING_PTR, offsetof(rlm_stg_t,password), NULL, "123456"}, + { "address", PW_TYPE_STRING_PTR, offsetof(rlm_stg_t, address), NULL, "unix:/var/run/stg.sock"}, { NULL, -1, 0, NULL, NULL } /* end the list */ }; -int emptyPair(const STG_PAIR * pair); +static void deletePairs(STG_PAIR* pairs) +{ + free(pairs); +} + +static size_t toVPS(const STG_PAIR* pairs, VALUE_PAIR** vps) +{ + const STG_PAIR* pair = pairs; + size_t count = 0; + + while (!emptyPair(pair)) { + VALUE_PAIR* vp = pairmake(pair->key, pair->value, T_OP_SET); + if (vp != NULL) { + pairadd(vps, vp); + ++count; + } + ++pair; + } + + return count; +} + +static size_t toReply(STG_RESULT result, REQUEST* request) +{ + size_t count = 0; + + count += toVPS(result.modify, &request->config_items); + pairfree(&request->reply->vps); + count += toVPS(result.reply, &request->reply->vps); + + deletePairs(result.modify); + deletePairs(result.reply); + + return count; +} + +static int countVPS(const VALUE_PAIR* pairs) +{ + unsigned count = 0; + while (pairs != NULL) { + ++count; + pairs = pairs->next; + } + return count; +} + +static STG_PAIR* fromVPS(const VALUE_PAIR* pairs) +{ + unsigned size = countVPS(pairs); + STG_PAIR* res = (STG_PAIR*)malloc(sizeof(STG_PAIR) * (size + 1)); + size_t pos = 0; + while (pairs != NULL) { + bzero(res[pos].key, sizeof(res[0].key)); + bzero(res[pos].value, sizeof(res[0].value)); + strncpy(res[pos].key, pairs->name, sizeof(res[0].key)); + vp_prints_value(res[pos].value, sizeof(res[0].value), (VALUE_PAIR*)pairs, 0); + ++pos; + pairs = pairs->next; + } + bzero(res[pos].key, sizeof(res[0].key)); + bzero(res[pos].value, sizeof(res[0].value)); + return res; +} + +static int toRLMCode(int code) +{ + switch (code) + { + case STG_REJECT: return RLM_MODULE_REJECT; + case STG_FAIL: return RLM_MODULE_FAIL; + case STG_OK: return RLM_MODULE_OK; + case STG_HANDLED: return RLM_MODULE_HANDLED; + case STG_INVALID: return RLM_MODULE_INVALID; + case STG_USERLOCK: return RLM_MODULE_USERLOCK; + case STG_NOTFOUND: return RLM_MODULE_NOTFOUND; + case STG_NOOP: return RLM_MODULE_NOOP; + case STG_UPDATED: return RLM_MODULE_UPDATED; + } + return RLM_MODULE_REJECT; +} /* * Do any per-module initialization that is separate to each @@ -63,17 +136,17 @@ int emptyPair(const STG_PAIR * pair); * that must be referenced in later calls, store a handle to it * in *instance otherwise put a null pointer there. */ -static int stg_instantiate(CONF_SECTION *conf, void **instance) +static int stg_instantiate(CONF_SECTION* conf, void** instance) { - rlm_stg_t *data; + rlm_stg_t* data; /* * Set up a storage area for instance data */ data = rad_malloc(sizeof(*data)); - if (!data) { + if (!data) return -1; - } + memset(data, 0, sizeof(*data)); /* @@ -85,7 +158,7 @@ static int stg_instantiate(CONF_SECTION *conf, void **instance) return -1; } - if (!stgInstantiateImpl(data->server, data->port)) { + if (!stgInstantiateImpl(data->address)) { free(data); return -1; } @@ -101,158 +174,165 @@ static int stg_instantiate(CONF_SECTION *conf, void **instance) * from the database. The authentication code only needs to check * the password, the rest is done here. */ -static int stg_authorize(void *, REQUEST *request) +static int stg_authorize(void* instance, REQUEST* request) { - const STG_PAIR * pairs; - const STG_PAIR * pair; + STG_RESULT result; + STG_PAIR* pairs = fromVPS(request->packet->vps); size_t count = 0; + const char* username = NULL; + const char* password = NULL; instance = instance; DEBUG("rlm_stg: stg_authorize()"); if (request->username) { - DEBUG("rlm_stg: stg_authorize() request username field: '%s'", request->username->vp_strvalue); + username = request->username->data.strvalue; + DEBUG("rlm_stg: stg_authorize() request username field: '%s'", username); } + if (request->password) { - DEBUG("rlm_stg: stg_authorize() request password field: '%s'", request->password->vp_strvalue); - } - // Here we need to define Framed-Protocol - VALUE_PAIR * svc = pairfind(request->packet->vps, PW_SERVICE_TYPE); - if (svc) { - DEBUG("rlm_stg: stg_authorize() Service-Type defined as '%s'", svc->vp_strvalue); - pairs = stgAuthorizeImpl((const char *)request->username->vp_strvalue, (const char *)svc->vp_strvalue); - } else { - DEBUG("rlm_stg: stg_authorize() Service-Type undefined"); - pairs = stgAuthorizeImpl((const char *)request->username->vp_strvalue, ""); + password = request->password->data.strvalue; + DEBUG("rlm_stg: stg_authorize() request password field: '%s'", password); } - if (!pairs) { + + result = stgAuthorizeImpl(username, password, pairs); + deletePairs(pairs); + + if (!result.modify && !result.reply) { DEBUG("rlm_stg: stg_authorize() failed."); return RLM_MODULE_REJECT; } - pair = pairs; - while (!emptyPair(pair)) { - VALUE_PAIR * pwd = pairmake(pair->key, pair->value, T_OP_SET); - pairadd(&request->config_items, pwd); - DEBUG("Adding pair '%s': '%s'", pair->key, pair->value); - ++pair; - ++count; - } - deletePairs(pairs); + count = toReply(result, request); if (count) return RLM_MODULE_UPDATED; - return RLM_MODULE_NOOP; + return toRLMCode(result.returnCode); } /* * Authenticate the user with the given password. */ -static int stg_authenticate(void *, REQUEST *request) +static int stg_authenticate(void* instance, REQUEST* request) { - const STG_PAIR * pairs; - const STG_PAIR * pair; + STG_RESULT result; + STG_PAIR* pairs = fromVPS(request->packet->vps); size_t count = 0; + const char* username = NULL; + const char* password = NULL; instance = instance; DEBUG("rlm_stg: stg_authenticate()"); - VALUE_PAIR * svc = pairfind(request->packet->vps, PW_SERVICE_TYPE); - if (svc) { - DEBUG("rlm_stg: stg_authenticate() Service-Type defined as '%s'", svc->vp_strvalue); - pairs = stgAuthenticateImpl((const char *)request->username->vp_strvalue, (const char *)svc->vp_strvalue); - } else { - DEBUG("rlm_stg: stg_authenticate() Service-Type undefined"); - pairs = stgAuthenticateImpl((const char *)request->username->vp_strvalue, ""); + if (request->username) { + username = request->username->data.strvalue; + DEBUG("rlm_stg: stg_authenticate() request username field: '%s'", username); + } + + if (request->password) { + password = request->password->data.strvalue; + DEBUG("rlm_stg: stg_authenticate() request password field: '%s'", password); } - if (!pairs) { + + result = stgAuthenticateImpl(username, password, pairs); + deletePairs(pairs); + + if (!result.modify && !result.reply) { DEBUG("rlm_stg: stg_authenticate() failed."); return RLM_MODULE_REJECT; } - pair = pairs; - while (!emptyPair(pair)) { - VALUE_PAIR * pwd = pairmake(pair->key, pair->value, T_OP_SET); - pairadd(&request->reply->vps, pwd); - ++pair; - ++count; - } - deletePairs(pairs); + count = toReply(result, request); if (count) return RLM_MODULE_UPDATED; - return RLM_MODULE_NOOP; + return toRLMCode(result.returnCode); } /* * Massage the request before recording it or proxying it */ -static int stg_preacct(void *, REQUEST *) +static int stg_preacct(void* instance, REQUEST* request) { + STG_RESULT result; + STG_PAIR* pairs = fromVPS(request->packet->vps); + size_t count = 0; + const char* username = NULL; + const char* password = NULL; + DEBUG("rlm_stg: stg_preacct()"); instance = instance; - return RLM_MODULE_OK; + if (request->username) { + username = request->username->data.strvalue; + DEBUG("rlm_stg: stg_preacct() request username field: '%s'", username); + } + + if (request->password) { + password = request->password->data.strvalue; + DEBUG("rlm_stg: stg_preacct() request password field: '%s'", password); + } + + result = stgPreAcctImpl(username, password, pairs); + deletePairs(pairs); + + if (!result.modify && !result.reply) { + DEBUG("rlm_stg: stg_preacct() failed."); + return RLM_MODULE_REJECT; + } + + count = toReply(result, request); + + if (count) + return RLM_MODULE_UPDATED; + + return toRLMCode(result.returnCode); } /* * Write accounting information to this modules database. */ -static int stg_accounting(void *, REQUEST * request) +static int stg_accounting(void* instance, REQUEST* request) { - const STG_PAIR * pairs; - const STG_PAIR * pair; + STG_RESULT result; + STG_PAIR* pairs = fromVPS(request->packet->vps); size_t count = 0; - - instance = instance; + const char* username = NULL; + const char* password = NULL; DEBUG("rlm_stg: stg_accounting()"); - VALUE_PAIR * svc = pairfind(request->packet->vps, PW_SERVICE_TYPE); - VALUE_PAIR * sessid = pairfind(request->packet->vps, PW_ACCT_SESSION_ID); - VALUE_PAIR * sttype = pairfind(request->packet->vps, PW_ACCT_STATUS_TYPE); + instance = instance; - if (!sessid) { - DEBUG("rlm_stg: stg_accounting() Acct-Session-ID undefined"); - return RLM_MODULE_FAIL; + if (request->username) { + username = request->username->data.strvalue; + DEBUG("rlm_stg: stg_accounting() request username field: '%s'", username); } - if (sttype) { - DEBUG("Acct-Status-Type := %s", sttype->vp_strvalue); - if (svc) { - DEBUG("rlm_stg: stg_accounting() Service-Type defined as '%s'", svc->vp_strvalue); - pairs = stgAccountingImpl((const char *)request->username->vp_strvalue, (const char *)svc->vp_strvalue, (const char *)sttype->vp_strvalue, (const char *)sessid->vp_strvalue); - } else { - DEBUG("rlm_stg: stg_accounting() Service-Type undefined"); - pairs = stgAccountingImpl((const char *)request->username->vp_strvalue, "", (const char *)sttype->vp_strvalue, (const char *)sessid->vp_strvalue); - } - } else { - DEBUG("rlm_stg: stg_accounting() Acct-Status-Type := NULL"); - return RLM_MODULE_OK; + if (request->password) { + password = request->password->data.strvalue; + DEBUG("rlm_stg: stg_accounting() request password field: '%s'", password); } - if (!pairs) { + + result = stgAccountingImpl(username, password, pairs); + deletePairs(pairs); + + if (!result.modify && !result.reply) { DEBUG("rlm_stg: stg_accounting() failed."); return RLM_MODULE_REJECT; } - pair = pairs; - while (!emptyPair(pair)) { - VALUE_PAIR * pwd = pairmake(pair->key, pair->value, T_OP_SET); - pairadd(&request->reply->vps, pwd); - ++pair; - ++count; - } - deletePairs(pairs); + count = toReply(result, request); if (count) return RLM_MODULE_UPDATED; - return RLM_MODULE_OK; + return toRLMCode(result.returnCode); } /* @@ -265,86 +345,76 @@ static int stg_accounting(void *, REQUEST * request) * max. number of logins, do a second pass and validate all * logins by querying the terminal server (using eg. SNMP). */ -static int stg_checksimul(void *, REQUEST *request) +static int stg_checksimul(void* instance, REQUEST* request) { DEBUG("rlm_stg: stg_checksimul()"); instance = instance; - request->simul_count=0; + request->simul_count = 0; return RLM_MODULE_OK; } -static int stg_postauth(void *, REQUEST *request) +static int stg_postauth(void* instance, REQUEST* request) { - const STG_PAIR * pairs; - const STG_PAIR * pair; + STG_RESULT result; + STG_PAIR* pairs = fromVPS(request->packet->vps); size_t count = 0; - - instance = instance; + const char* username = NULL; + const char* password = NULL; DEBUG("rlm_stg: stg_postauth()"); - VALUE_PAIR * svc = pairfind(request->packet->vps, PW_SERVICE_TYPE); + instance = instance; + + if (request->username) { + username = request->username->data.strvalue; + DEBUG("rlm_stg: stg_postauth() request username field: '%s'", username); + } - if (svc) { - DEBUG("rlm_stg: stg_postauth() Service-Type defined as '%s'", svc->vp_strvalue); - pairs = stgPostAuthImpl((const char *)request->username->vp_strvalue, (const char *)svc->vp_strvalue); - } else { - DEBUG("rlm_stg: stg_postauth() Service-Type undefined"); - pairs = stgPostAuthImpl((const char *)request->username->vp_strvalue, ""); + if (request->password) { + password = request->password->data.strvalue; + DEBUG("rlm_stg: stg_postauth() request password field: '%s'", password); } - if (!pairs) { + + result = stgPostAuthImpl(username, password, pairs); + deletePairs(pairs); + + if (!result.modify && !result.reply) { DEBUG("rlm_stg: stg_postauth() failed."); return RLM_MODULE_REJECT; } - pair = pairs; - while (!emptyPair(pair)) { - VALUE_PAIR * pwd = pairmake(pair->key, pair->value, T_OP_SET); - pairadd(&request->reply->vps, pwd); - ++pair; - ++count; - } - deletePairs(pairs); + count = toReply(result, request); if (count) return RLM_MODULE_UPDATED; - return RLM_MODULE_NOOP; + return toRLMCode(result.returnCode); } -static int stg_detach(void *instance) +static int stg_detach(void* instance) { - free(((struct rlm_stg_t *)instance)->server); + free(((struct rlm_stg_t*)instance)->address); free(instance); return 0; } -/* - * The module name should be the only globally exported symbol. - * That is, everything else should be 'static'. - * - * If the module needs to temporarily modify it's instantiation - * data, the type should be changed to RLM_TYPE_THREAD_UNSAFE. - * The server will then take care of ensuring that the module - * is single-threaded. - */ module_t rlm_stg = { RLM_MODULE_INIT, "stg", - RLM_TYPE_THREAD_SAFE, /* type */ - stg_instantiate, /* instantiation */ - stg_detach, /* detach */ + RLM_TYPE_THREAD_UNSAFE, /* type */ + stg_instantiate, /* instantiation */ + stg_detach, /* detach */ { - stg_authenticate, /* authentication */ + stg_authenticate, /* authentication */ stg_authorize, /* authorization */ - stg_preacct, /* preaccounting */ - stg_accounting, /* accounting */ - stg_checksimul, /* checksimul */ - NULL, /* pre-proxy */ - NULL, /* post-proxy */ - stg_postauth /* post-auth */ + stg_preacct, /* preaccounting */ + stg_accounting, /* accounting */ + stg_checksimul, /* checksimul */ + NULL, /* pre-proxy */ + NULL, /* post-proxy */ + stg_postauth /* post-auth */ }, }; diff --git a/projects/rlm_stg/stg_client.cpp b/projects/rlm_stg/stg_client.cpp index 113e71c9..e34c50cd 100644 --- a/projects/rlm_stg/stg_client.cpp +++ b/projects/rlm_stg/stg_client.cpp @@ -18,278 +18,133 @@ * Author : Maxim Mamontov */ -/* - * Realization of data access via Stargazer for RADIUS - * - * $Revision: 1.8 $ - * $Date: 2010/04/16 12:30:02 $ - * - */ - -#include -#include -#include // close - -#include -#include -#include -#include - -#include - #include "stg_client.h" -typedef std::vector > PAIRS; - -//----------------------------------------------------------------------------- +#include "conn.h" +#include "radlog.h" -STG_CLIENT::STG_CLIENT(const std::string & host, uint16_t port, uint16_t lp, const std::string & pass) - : password(pass), - framedIP(0) -{ -/*sock = socket(AF_INET, SOCK_DGRAM, 0); -if (sock == -1) - { - std::string message = strerror(errno); - message = "Socket create error: '" + message + "'"; - throw std::runtime_error(message); - } +#include "stg/locker.h" +#include "stg/common.h" -struct hostent * he = NULL; -he = gethostbyname(host.c_str()); -if (he == NULL) - { - throw std::runtime_error("gethostbyname error"); - } +#include +#include -outerAddr.sin_family = AF_INET; -outerAddr.sin_port = htons(port); -outerAddr.sin_addr.s_addr = *(uint32_t *)he->h_addr; +using STG::RLM::Client; +using STG::RLM::Conn; +using STG::RLM::RESULT; -InitEncrypt(&ctx, password); +namespace { -PrepareNet();*/ -} +Client* stgClient = NULL; -STG_CLIENT::~STG_CLIENT() -{ -/*close(sock);*/ } -int STG_CLIENT::PrepareNet() +class Client::Impl { -return 0; -} - -int STG_CLIENT::Send(const RAD_PACKET & packet) + public: + explicit Impl(const std::string& address); + ~Impl(); + + bool stop() { return m_conn ? m_conn->stop() : true; } + + RESULT request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs); + + private: + std::string m_address; + boost::scoped_ptr m_conn; + + pthread_mutex_t m_mutex; + pthread_cond_t m_cond; + bool m_done; + RESULT m_result; + + static bool callback(void* data, const RESULT& result) + { + Impl& impl = *static_cast(data); + STG_LOCKER lock(impl.m_mutex); + impl.m_result = result; + impl.m_done = true; + pthread_cond_signal(&impl.m_cond); + return true; + } +}; + +Client::Impl::Impl(const std::string& address) + : m_address(address) { -/*char buf[RAD_MAX_PACKET_LEN]; - -Encrypt(&ctx, buf, (char *)&packet, sizeof(RAD_PACKET) / 8); - -int res = sendto(sock, buf, sizeof(RAD_PACKET), 0, (struct sockaddr *)&outerAddr, sizeof(outerAddr)); - -if (res == -1) - errorStr = "Error sending data"; - -return res;*/ -} - -int STG_CLIENT::RecvData(RAD_PACKET * packet) -{ -/*char buf[RAD_MAX_PACKET_LEN]; -int res; - -struct sockaddr_in addr; -socklen_t len = sizeof(struct sockaddr_in); - -res = recvfrom(sock, buf, RAD_MAX_PACKET_LEN, 0, reinterpret_cast(&addr), &len); -if (res == -1) + try { - errorStr = "Error receiving data"; - return -1; + m_conn.reset(new Conn(m_address, &Impl::callback, this)); } - -Decrypt(&ctx, (char *)packet, buf, res / 8); - -return 0;*/ -} - -int STG_CLIENT::Request(RAD_PACKET * packet, const std::string & login, const std::string & svc, uint8_t packetType) -{ -/*int res; - -memcpy((void *)&packet->magic, (void *)RAD_ID, RAD_MAGIC_LEN); -packet->protoVer[0] = '0'; -packet->protoVer[1] = '1'; -packet->packetType = packetType; -packet->ip = 0; -strncpy((char *)packet->login, login.c_str(), RAD_LOGIN_LEN); -strncpy((char *)packet->service, svc.c_str(), RAD_SERVICE_LEN); - -res = Send(*packet); -if (res == -1) - return -1; - -res = RecvData(packet); -if (res == -1) - return -1; - -if (strncmp((char *)packet->magic, RAD_ID, RAD_MAGIC_LEN)) + catch (const std::runtime_error& ex) { - errorStr = "Magic invalid. Wanted: '"; - errorStr += RAD_ID; - errorStr += "', got: '"; - errorStr += (char *)packet->magic; - errorStr += "'"; - return -1; + RadLog("Connection error: %s.", ex.what()); } - -return 0;*/ + pthread_mutex_init(&m_mutex, NULL); + pthread_cond_init(&m_cond, NULL); + m_done = false; } -//----------------------------------------------------------------------------- - -const STG_PAIRS * STG_CLIENT::Authorize(const std::string & login, const std::string & svc) +Client::Impl::~Impl() { -/*RAD_PACKET packet; - -userPassword = ""; - -if (Request(&packet, login, svc, RAD_AUTZ_PACKET)) - return -1; - -if (packet.packetType != RAD_ACCEPT_PACKET) - return -1; - -userPassword = (char *)packet.password;*/ - -PAIRS pairs; -pairs.push_back(std::make_pair("Cleartext-Password", userPassword)); - -return ToSTGPairs(pairs); + pthread_cond_destroy(&m_cond); + pthread_mutex_destroy(&m_mutex); } -const STG_PAIRS * STG_CLIENT::Authenticate(const std::string & login, const std::string & svc) +RESULT Client::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) { -/*RAD_PACKET packet; - -userPassword = ""; - -if (Request(&packet, login, svc, RAD_AUTH_PACKET)) - return -1; - -if (packet.packetType != RAD_ACCEPT_PACKET) - return -1;*/ - -PAIRS pairs; - -return ToSTGPairs(pairs); + STG_LOCKER lock(m_mutex); + if (!m_conn || !m_conn->connected()) + m_conn.reset(new Conn(m_address, &Impl::callback, this)); + if (!m_conn->connected()) + throw Conn::Error("Failed to create connection to '" + m_address + "'."); + + m_done = false; + m_conn->request(type, userName, password, pairs); + timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + ts.tv_sec += 5; + int res = 0; + while (!m_done && res == 0) + res = pthread_cond_timedwait(&m_cond, &m_mutex, &ts); + if (res != 0) + throw Conn::Error("Request failed."); + return m_result; } -const STG_PAIRS * STG_CLIENT::PostAuth(const std::string & login, const std::string & svc) +Client::Client(const std::string& address) + : m_impl(new Impl(address)) { -/*RAD_PACKET packet; - -userPassword = ""; - -if (Request(&packet, login, svc, RAD_POST_AUTH_PACKET)) - return -1; - -if (packet.packetType != RAD_ACCEPT_PACKET) - return -1; - -if (svc == "Framed-User") - framedIP = packet.ip; -else - framedIP = 0;*/ - -PAIRS pairs; -pairs.push_back(std::make_pair("Framed-IP-Address", inet_ntostring(framedIP))); - -return ToSTGPairs(pairs); } -const STG_PAIRS * STG_CLIENT::PreAcct(const std::string & login, const std::String & service) +Client::~Client() { -PAIRS pairs; - -return ToSTGPairs(pairs); } -const STG_PAIRS * STG_CLIENT::Account(const std::string & type, const std::string & login, const std::string & svc, const std::string & sessid) +bool Client::stop() { -/*RAD_PACKET packet; - -userPassword = ""; -strncpy((char *)packet.sessid, sessid.c_str(), RAD_SESSID_LEN); - -if (type == "Start") - { - if (Request(&packet, login, svc, RAD_ACCT_START_PACKET)) - return -1; - } -else if (type == "Stop") - { - if (Request(&packet, login, svc, RAD_ACCT_STOP_PACKET)) - return -1; - } -else if (type == "Interim-Update") - { - if (Request(&packet, login, svc, RAD_ACCT_UPDATE_PACKET)) - return -1; - } -else - { - if (Request(&packet, login, svc, RAD_ACCT_OTHER_PACKET)) - return -1; - } - -if (packet.packetType != RAD_ACCEPT_PACKET) - return -1;*/ - -PAIRS pairs; - -return ToSTGPairs(pairs); + return m_impl->stop(); } -//----------------------------------------------------------------------------- - -std::string STG_CLIENT_ST::m_host; -uint16_t STG_CLIENT_ST::m_port(6666); -std::string STG_CLIENT_ST::m_password; - -//----------------------------------------------------------------------------- - -STG_CLIENT * STG_CLIENT_ST::Get() +RESULT Client::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs) { - static STG_CLIENT * stgClient = NULL; - if ( stgClient == NULL ) - stgClient = new STG_CLIENT(m_host, m_port, m_password); - return stgClient; + return m_impl->request(type, userName, password, pairs); } -void STG_CLIENT_ST::Configure(const std::string & host, uint16_t port, const std::string & password) +Client* Client::get() { - m_host = host; - m_port = port; - m_password = password; + return stgClient; } -//----------------------------------------------------------------------------- - -const STG_PAIR * ToSTGPairs(const PAIRS & source) +bool Client::configure(const std::string& address) { - STG_PAIR * pairs = new STG_PAIR[source.size() + 1]; - for (size_t pos = 0; pos < source.size(); ++pos) { - bzero(pairs[pos].key, sizeof(STG_PAIR::key)); - bzero(pairs[pos].value, sizeof(STG_PAIR::value)); - strncpy(pairs[pos].key, source[pos].first.c_str(), sizeof(STG_PAIR::key)); - strncpy(pairs[pos].value, source[pos].second.c_str(), sizeof(STG_PAIR::value)); - ++pos; + if ( stgClient != NULL ) + return stgClient->configure(address); + try { + stgClient = new Client(address); + return true; + } catch (const std::exception& ex) { + RadLog("Client configuration error: %s.", ex.what()); } - bzero(pairs[sources.size()].key, sizeof(STG_PAIR::key)); - bzero(pairs[sources.size()].value, sizeof(STG_PAIR::value)); - - return pairs; + return false; } diff --git a/projects/rlm_stg/stg_client.h b/projects/rlm_stg/stg_client.h index 5ee000c7..917d0e51 100644 --- a/projects/rlm_stg/stg_client.h +++ b/projects/rlm_stg/stg_client.h @@ -18,61 +18,41 @@ * Author : Maxim Mamontov */ -/* - * Header file for client part of data access via Stargazer for RADIUS - * - * $Revision: 1.4 $ - * $Date: 2010/04/16 12:30:02 $ - * - */ +#ifndef __STG_RLM_CLIENT_H__ +#define __STG_RLM_CLIENT_H__ -#ifndef STG_CLIENT_H -#define STG_CLIENT_H +#include "types.h" -#include +#include "stg/os_int.h" -#include -#include -#include // socklen_t +#include -#include "stg/blowfish.h" -#include "stg/rad_packets.h" +#include -#include "stgpair.h" +namespace STG +{ +namespace RLM +{ -class STG_CLIENT +class Client { public: - STG_CLIENT(const std::string & host, uint16_t port, const std::string & password); - ~STG_CLIENT(); - - const STG_PAIR * Authorize(const std::string & login, const std::string & service); - const STG_PAIR * Authenticate(const std::string & login, const std::string & service); - const STG_PAIR * PostAuth(const std::string & login, const std::string & service); - const STG_PAIR * PreAcct(const std::string & login, const std::string & service); - const STG_PAIR * Account(const std::string & type, const std::string & login, const std::string & service, const std::string & sessionId); + explicit Client(const std::string& address); + ~Client(); -private: - std::string password; + bool stop(); - int PrepareNet(); + static Client* get(); + static bool configure(const std::string& address); - int Request(RAD_PACKET * packet, const std::string & login, const std::string & svc, uint8_t packetType); + RESULT request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs); - int RecvData(RAD_PACKET * packet); - int Send(const RAD_PACKET & packet); +private: + class Impl; + boost::scoped_ptr m_impl; }; -struct STG_CLIENT_ST -{ - public: - static void Configure(const std::string & host, uint16_t port, const std::string & password); - static STG_CLIENT * Get(); - - private: - static std::string m_host; - static uint16_t m_port; - static std::string m_password; -}; +} // namespace RLM +} // namespace STG #endif diff --git a/projects/rlm_stg/stgpair.h b/projects/rlm_stg/stgpair.h index 19b42bc1..ef7ab4b7 100644 --- a/projects/rlm_stg/stgpair.h +++ b/projects/rlm_stg/stgpair.h @@ -1,12 +1,47 @@ #ifndef __STG_STGPAIR_H__ #define __STG_STGPAIR_H__ +#include + #define STGPAIR_KEYLENGTH 64 #define STGPAIR_VALUELENGTH 256 +#ifdef __cplusplus +extern "C" { +#endif + typedef struct STG_PAIR { char key[STGPAIR_KEYLENGTH]; char value[STGPAIR_VALUELENGTH]; } STG_PAIR; +typedef struct STG_RESULT { + STG_PAIR* modify; + STG_PAIR* reply; + int returnCode; +} STG_RESULT; + +inline +int emptyPair(const STG_PAIR* pair) +{ + return pair == NULL || pair->key[0] == '\0' || pair->value[0] == '\0'; +} + +enum +{ + STG_REJECT, + STG_FAIL, + STG_OK, + STG_HANDLED, + STG_INVALID, + STG_USERLOCK, + STG_NOTFOUND, + STG_NOOP, + STG_UPDATED +}; + +#ifdef __cplusplus +} +#endif + #endif diff --git a/projects/rlm_stg/types.h b/projects/rlm_stg/types.h new file mode 100644 index 00000000..2bc721fc --- /dev/null +++ b/projects/rlm_stg/types.h @@ -0,0 +1,52 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#ifndef __STG_RLM_CLIENT_CONN_H__ +#define __STG_RLM_CLIENT_CONN_H__ + +#include +#include + +namespace STG +{ +namespace RLM +{ + +typedef std::vector > PAIRS; + +struct RESULT +{ + PAIRS modify; + PAIRS reply; + int returnCode; +}; + +enum REQUEST_TYPE { + AUTHORIZE, + AUTHENTICATE, + POST_AUTH, + PRE_ACCT, + ACCOUNT +}; + +} // namespace RLM +} // namespace STG + +#endif diff --git a/projects/stargazer/build b/projects/stargazer/build index 80d17d00..a43e3e13 100755 --- a/projects/stargazer/build +++ b/projects/stargazer/build @@ -105,7 +105,6 @@ PLUGINS="authorization/ao configuration/sgconfig other/ping other/rscript - other/radius other/smux store/files capture/cap_nf" @@ -419,6 +418,55 @@ else fi rm -f fake +printf "Checking for -lyajl... " +pkg-config --version > /dev/null 2> /dev/null +if [ "$?" = "0" ] +then + pkg-config --atleast-version=2.0.0 yajl + if [ "$?" != "0" ] + then + CHECK_YAJL=no + printf "no\n" + else + CHECK_YAJL=yes + printf `pkg-config --modversion yajl`"\n" + fi +else + printf "#include \n" > build_check.c + printf "#include \n" >> build_check.c + printf "int main() { printf(\"%%d\", yajl_version()); return 0; }\n" >> build_check.c + $CC $CFLAGS $LDFLAGS build_check.c -lyajl -o fake > /dev/null 2> /dev/null + if [ $? != 0 ] + then + CHECK_YAJL=no + printf "no\n" + else + YAJL_VERSION=`./fake` + if [ $YAJL_VERSION -ge 20000 ] + then + CHECK_YAJL=yes + printf "${YAJL_VERSION}\n" + else + CHECK_YAJL=no + printf "no. Need at least version 2.0.0, existing version is ${YAJL_VERSION}\n" + fi + fi + rm -f fake +fi + +printf "Checking for boost::scoped_ptr... " +printf "#include \nint main() { boost::scoped_ptr test(new int(1)); return 0; }\n" > build_check.cpp +$CXX $CXXFLAGS $LDFLAGS build_check.cpp -o fake # > /dev/null 2> /dev/null +if [ $? != 0 ] +then + CHECK_BOOST_SCOPED_PTR=no + printf "no\n" +else + CHECK_BOOST_SCOPED_PTR=yes + printf "yes\n" +fi +rm -f fake + if [ "$OS" = "linux" ] then printf "Checking for linux/netfilter_ipv4/ip_queue.h... " @@ -441,6 +489,7 @@ then fi rm -f build_check.c +rm -f build_check.cpp if [ "$CHECK_EXPAT" != "yes" ] then @@ -480,6 +529,14 @@ then capture/nfqueue" fi +if [ "$CHECK_YAJL" = "yes" -a "$CHECK_BOOST_SCOPED_PTR" = "yes" ] +then + PLUGINS="$PLUGINS + other/radius" + STG_LIBS="$STG_LIBS + json.lib" +fi + printf "OS=$OS\n" > $CONFFILE printf "STG_TIME=yes\n" >> $CONFFILE printf "DEBUG=$DEBUG\n" >> $CONFFILE diff --git a/projects/stargazer/plugins/other/radius/Makefile b/projects/stargazer/plugins/other/radius/Makefile index 62a05183..14428914 100644 --- a/projects/stargazer/plugins/other/radius/Makefile +++ b/projects/stargazer/plugins/other/radius/Makefile @@ -4,16 +4,19 @@ include ../../../../../Makefile.conf -LIBS += $(LIB_THREAD) +LIBS += $(LIB_THREAD) -lyajl PROG = mod_radius.so -SRCS = ./radius.cpp +SRCS = radius.cpp \ + config.cpp \ + conn.cpp STGLIBS = common \ crypto \ logger \ - scriptexecuter + scriptexecuter \ + json include ../../Makefile.in diff --git a/projects/stargazer/plugins/other/radius/config.cpp b/projects/stargazer/plugins/other/radius/config.cpp new file mode 100644 index 00000000..6742f811 --- /dev/null +++ b/projects/stargazer/plugins/other/radius/config.cpp @@ -0,0 +1,335 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#include "config.h" + +#include "stg/common.h" + +#include +#include + +#include // strncasecmp + +using STG::Config; + +namespace +{ + +struct ParserError : public std::runtime_error +{ + ParserError(size_t pos, const std::string& message) + : runtime_error("Parsing error at position " + x2str(pos) + ". " + message), + position(pos), + error(message) + {} + virtual ~ParserError() throw() {} + + size_t position; + std::string error; +}; + +size_t skipSpaces(const std::string& value, size_t start) +{ + while (start < value.length() && std::isspace(value[start])) + ++start; + return start; +} + +size_t checkChar(const std::string& value, size_t start, char ch) +{ + if (start >= value.length()) + throw ParserError(start, "Unexpected end of string. Expected '" + std::string(1, ch) + "'."); + if (value[start] != ch) + throw ParserError(start, "Expected '" + std::string(1, ch) + "', got '" + std::string(1, value[start]) + "'."); + return start + 1; +} + +std::pair readString(const std::string& value, size_t start) +{ + std::string dest; + while (start < value.length() && !std::isspace(value[start]) && + value[start] != ',' && value[start] != '(' && value[start] != ')') + dest.push_back(value[start++]); + if (dest.empty()) { + if (start == value.length()) + throw ParserError(start, "Unexpected end of string. Expected string."); + else + throw ParserError(start, "Unexpected whitespace. Expected string."); + } + return std::make_pair(start, dest); +} + +Config::Pairs toPairs(const std::vector& values) +{ + if (values.empty()) + return Config::Pairs(); + std::string value(values[0]); + Config::Pairs res; + size_t start = 0; + while (start < value.size()) { + Config::Pair pair; + start = skipSpaces(value, start); + if (!res.empty()) + { + start = checkChar(value, start, ','); + start = skipSpaces(value, start); + } + size_t pairStart = start; + start = checkChar(value, start, '('); + const std::pair key = readString(value, start); + start = key.first; + pair.first = key.second; + start = skipSpaces(value, start); + start = checkChar(value, start, ','); + start = skipSpaces(value, start); + const std::pair val = readString(value, start); + start = val.first; + pair.second = val.second; + start = skipSpaces(value, start); + start = checkChar(value, start, ')'); + if (res.find(pair.first) != res.end()) + throw ParserError(pairStart, "Duplicate field."); + res.insert(pair); + } + return res; +} + +bool toBool(const std::vector& values) +{ + if (values.empty()) + return false; + std::string value(values[0]); + return strncasecmp(value.c_str(), "yes", 3) == 0; +} + +std::string toString(const std::vector& values) +{ + if (values.empty()) + return ""; + return values[0]; +} + +uid_t toUID(const std::vector& values) +{ + if (values.empty()) + return -1; + uid_t res = str2uid(values[0]); + if (res == static_cast(-1)) + throw ParserError(0, "Invalid user name: '" + values[0] + "'"); + return res; +} + +gid_t toGID(const std::vector& values) +{ + if (values.empty()) + return -1; + gid_t res = str2gid(values[0]); + if (res == static_cast(-1)) + throw ParserError(0, "Invalid group name: '" + values[0] + "'"); + return res; +} + +mode_t toMode(const std::vector& values) +{ + if (values.empty()) + return -1; + mode_t res = str2mode(values[0]); + if (res == static_cast(-1)) + throw ParserError(0, "Invalid mode: '" + values[0] + "'"); + return res; +} + +template +T toInt(const std::vector& values) +{ + if (values.empty()) + return 0; + T res = 0; + if (str2x(values[0], res) == 0) + return res; + return 0; +} + +uint16_t toPort(const std::string& value) +{ + uint16_t res = 0; + if (str2x(value, res) == 0) + return res; + throw ParserError(0, "'" + value + "' is not a valid port number."); +} + +typedef std::map Codes; + +// One-time call to initialize the list of codes. +Codes getCodes() +{ + Codes res; + res["reject"] = Config::REJECT; + res["fail"] = Config::FAIL; + res["ok"] = Config::OK; + res["handled"] = Config::HANDLED; + res["invalid"] = Config::INVALID; + res["userlock"] = Config::USERLOCK; + res["notfound"] = Config::NOTFOUND; + res["noop"] = Config::NOOP; + res["updated"] = Config::UPDATED; + return res; +} + +Config::ReturnCode toReturnCode(const std::vector& values) +{ + static Codes codes(getCodes()); + if (values.empty()) + return Config::REJECT; + std::string code = ToLower(values[0]); + const Codes::const_iterator it = codes.find(code); + if (it == codes.end()) + return Config::REJECT; + return it->second; +} + +Config::Pairs parseVector(const std::string& paramName, const std::vector& params) +{ + for (size_t i = 0; i < params.size(); ++i) + if (params[i].param == paramName) + return toPairs(params[i].value); + return Config::Pairs(); +} + +Config::ReturnCode parseReturnCode(const std::string& paramName, const std::vector& params) +{ + for (size_t i = 0; i < params.size(); ++i) + if (params[i].param == paramName) + return toReturnCode(params[i].value); + return Config::REJECT; +} + +bool parseBool(const std::string& paramName, const std::vector& params) +{ + for (size_t i = 0; i < params.size(); ++i) + if (params[i].param == paramName) + return toBool(params[i].value); + return false; +} + +std::string parseString(const std::string& paramName, const std::vector& params) +{ + for (size_t i = 0; i < params.size(); ++i) + if (params[i].param == paramName) + return toString(params[i].value); + return ""; +} + +std::string parseAddress(Config::Type connectionType, const std::string& value) +{ + size_t pos = value.find_first_of(':'); + if (pos == std::string::npos) + throw ParserError(0, "Connection type is not specified. Should be either 'unix' or 'tcp'."); + if (connectionType == Config::UNIX) + return value.substr(pos + 1); + std::string address(value.substr(pos + 1)); + pos = address.find_first_of(':', pos + 1); + if (pos == std::string::npos) + throw ParserError(0, "Port is not specified."); + return address.substr(0, pos - 1); +} + +std::string parsePort(Config::Type connectionType, const std::string& value) +{ + size_t pos = value.find_first_of(':'); + if (pos == std::string::npos) + throw ParserError(0, "Connection type is not specified. Should be either 'unix' or 'tcp'."); + if (connectionType == Config::UNIX) + return value.substr(pos + 1); + std::string address(value.substr(pos + 1)); + pos = address.find_first_of(':', pos + 1); + if (pos == std::string::npos) + throw ParserError(0, "Port is not specified."); + return address.substr(pos + 1); +} + +Config::Type parseConnectionType(const std::string& address) +{ + size_t pos = address.find_first_of(':'); + if (pos == std::string::npos) + throw ParserError(0, "Connection type is not specified. Should be either 'unix' or 'tcp'."); + std::string type = ToLower(address.substr(0, pos)); + if (type == "unix") + return Config::UNIX; + else if (type == "tcp") + return Config::TCP; + throw ParserError(0, "Invalid connection type. Should be either 'unix' or 'tcp', got '" + type + "'"); +} + +Config::Section parseSection(const std::string& paramName, const std::vector& params) +{ + for (size_t i = 0; i < params.size(); ++i) + if (params[i].param == paramName) + return Config::Section(parseVector("match", params[i].sections), + parseVector("modify", params[i].sections), + parseVector("reply", params[i].sections), + parseReturnCode("no_match", params[i].sections)); + return Config::Section(); +} + +uid_t parseUID(const std::string& paramName, const std::vector& params) +{ + for (size_t i = 0; i < params.size(); ++i) + if (params[i].param == paramName) + return toUID(params[i].value); + return -1; +} + +gid_t parseGID(const std::string& paramName, const std::vector& params) +{ + for (size_t i = 0; i < params.size(); ++i) + if (params[i].param == paramName) + return toGID(params[i].value); + return -1; +} + +mode_t parseMode(const std::string& paramName, const std::vector& params) +{ + for (size_t i = 0; i < params.size(); ++i) + if (params[i].param == paramName) + return toMode(params[i].value); + return -1; +} + +} // namespace anonymous + +Config::Config(const MODULE_SETTINGS& settings) + : autz(parseSection("autz", settings.moduleParams)), + auth(parseSection("auth", settings.moduleParams)), + postauth(parseSection("postauth", settings.moduleParams)), + preacct(parseSection("preacct", settings.moduleParams)), + acct(parseSection("acct", settings.moduleParams)), + verbose(parseBool("verbose", settings.moduleParams)), + address(parseString("bind_address", settings.moduleParams)), + connectionType(parseConnectionType(address)), + bindAddress(parseAddress(connectionType, address)), + portStr(parsePort(connectionType, address)), + port(toPort(portStr)), + key(parseString("key", settings.moduleParams)), + sockUID(parseUID("sock_owner", settings.moduleParams)), + sockGID(parseGID("sock_group", settings.moduleParams)), + sockMode(parseMode("sock_mode", settings.moduleParams)) +{ +} diff --git a/projects/stargazer/plugins/other/radius/config.h b/projects/stargazer/plugins/other/radius/config.h new file mode 100644 index 00000000..2d4f638f --- /dev/null +++ b/projects/stargazer/plugins/other/radius/config.h @@ -0,0 +1,91 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#ifndef __STG_RADIUS_CONFIG_H__ +#define __STG_RADIUS_CONFIG_H__ + +#include "stg/module_settings.h" + +#include "stg/os_int.h" + +#include +#include + +#include // uid_t, gid_t +#include // mode_t + +namespace STG +{ + +struct Config +{ + typedef std::map Pairs; + typedef std::pair Pair; + enum Type { UNIX, TCP }; + enum ReturnCode + { + REJECT, // Reject the request immediately. + FAIL, // Module failed. + OK, // Module is OK, continue. + HANDLED, // The request is handled, no further handling. + INVALID, // The request is invalud. + USERLOCK, // Reject the request, user is locked. + NOTFOUND, // User not found. + NOOP, // Module performed no action. + UPDATED // Module sends some updates. + }; + + struct Section + { + Section() {} + Section(const Pairs& ma, const Pairs& mo, const Pairs& re, ReturnCode code) + : match(ma), modify(mo), reply(re), returnCode(code) {} + Pairs match; + Pairs modify; + Pairs reply; + ReturnCode returnCode; + }; + + Config() {} + Config(const MODULE_SETTINGS& settings); + + Section autz; + Section auth; + Section postauth; + Section preacct; + Section acct; + + bool verbose; + + std::string address; + Type connectionType; + std::string bindAddress; + std::string portStr; + uint16_t port; + std::string key; + + uid_t sockUID; + gid_t sockGID; + mode_t sockMode; +}; + +} // namespace STG + +#endif diff --git a/projects/stargazer/plugins/other/radius/conn.cpp b/projects/stargazer/plugins/other/radius/conn.cpp new file mode 100644 index 00000000..8a416ae0 --- /dev/null +++ b/projects/stargazer/plugins/other/radius/conn.cpp @@ -0,0 +1,513 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#include "conn.h" + +#include "config.h" + +#include "stg/json_parser.h" +#include "stg/json_generator.h" +#include "stg/users.h" +#include "stg/user.h" +#include "stg/logger.h" +#include "stg/common.h" + +#include + +#include +#include +#include +#include + +#include +#include +#include + +using STG::Conn; +using STG::Config; +using STG::JSON::Parser; +using STG::JSON::PairsParser; +using STG::JSON::EnumParser; +using STG::JSON::NodeParser; +using STG::JSON::Gen; +using STG::JSON::MapGen; +using STG::JSON::StringGen; + +namespace +{ + +double CONN_TIMEOUT = 60; +double PING_TIMEOUT = 10; + +enum Packet +{ + PING, + PONG, + DATA +}; + +enum Stage +{ + AUTHORIZE, + AUTHENTICATE, + PREACCT, + ACCOUNTING, + POSTAUTH +}; + +std::map packetCodes; +std::map stageCodes; + +class PacketParser : public EnumParser +{ + public: + PacketParser(NodeParser* next, Packet& packet, std::string& packetStr) + : EnumParser(next, packet, packetStr, packetCodes) + { + if (!packetCodes.empty()) + return; + packetCodes["ping"] = PING; + packetCodes["pong"] = PONG; + packetCodes["data"] = DATA; + } +}; + +class StageParser : public EnumParser +{ + public: + StageParser(NodeParser* next, Stage& stage, std::string& stageStr) + : EnumParser(next, stage, stageStr, stageCodes) + { + if (!stageCodes.empty()) + return; + stageCodes["authorize"] = AUTHORIZE; + stageCodes["authenticate"] = AUTHENTICATE; + stageCodes["preacct"] = PREACCT; + stageCodes["accounting"] = ACCOUNTING; + stageCodes["postauth"] = POSTAUTH; + } +}; + +class TopParser : public NodeParser +{ + public: + typedef void (*Callback) (void* /*data*/); + TopParser(Callback callback, void* data) + : m_packetParser(this, m_packet, m_packetStr), + m_stageParser(this, m_stage, m_stageStr), + m_pairsParser(this, m_data), + m_callback(callback), m_callbackData(data) + {} + + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseMapKey(const std::string& value) + { + std::string key = ToLower(value); + + if (key == "packet") + return &m_packetParser; + else if (key == "stage") + return &m_stageParser; + else if (key == "pairs") + return &m_pairsParser; + + return this; + } + virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; } + + const std::string& packetStr() const { return m_packetStr; } + Packet packet() const { return m_packet; } + const std::string& stageStr() const { return m_stageStr; } + Stage stage() const { return m_stage; } + const Config::Pairs& data() const { return m_data; } + + private: + std::string m_packetStr; + Packet m_packet; + std::string m_stageStr; + Stage m_stage; + Config::Pairs m_data; + + PacketParser m_packetParser; + StageParser m_stageParser; + PairsParser m_pairsParser; + + Callback m_callback; + void* m_callbackData; +}; + +class ProtoParser : public Parser +{ + public: + ProtoParser(TopParser::Callback callback, void* data) + : Parser( &m_topParser ), + m_topParser(callback, data) + {} + + const std::string& packetStr() const { return m_topParser.packetStr(); } + Packet packet() const { return m_topParser.packet(); } + const std::string& stageStr() const { return m_topParser.stageStr(); } + Stage stage() const { return m_topParser.stage(); } + const Config::Pairs& data() const { return m_topParser.data(); } + + private: + TopParser m_topParser; +}; + +class PacketGen : public Gen +{ + public: + PacketGen(const std::string& type) + : m_type(type) + { + m_gen.add("packet", m_type); + } + void run(yajl_gen_t* handle) const + { + m_gen.run(handle); + } + PacketGen& add(const std::string& key, const std::string& value) + { + m_gen.add(key, new StringGen(value)); + return *this; + } + PacketGen& add(const std::string& key, MapGen* map) + { + m_gen.add(key, map); + return *this; + } + PacketGen& add(const std::string& key, MapGen& map) + { + m_gen.add(key, map); + return *this; + } + private: + MapGen m_gen; + StringGen m_type; +}; + +std::string toString(Config::ReturnCode code) +{ + switch (code) + { + case Config::REJECT: return "reject"; + case Config::FAIL: return "fail"; + case Config::OK: return "ok"; + case Config::HANDLED: return "handled"; + case Config::INVALID: return "invalid"; + case Config::USERLOCK: return "userlock"; + case Config::NOTFOUND: return "notfound"; + case Config::NOOP: return "noop"; + case Config::UPDATED: return "noop"; + } + return "reject"; +} + +} + +class Conn::Impl +{ + public: + Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote); + ~Impl(); + + int sock() const { return m_sock; } + + bool read(); + bool tick(); + + bool isOk() const { return m_ok; } + + private: + USERS& m_users; + PLUGIN_LOGGER& m_logger; + const Config& m_config; + int m_sock; + std::string m_remote; + bool m_ok; + time_t m_lastPing; + time_t m_lastActivity; + ProtoParser m_parser; + + template + const T& stageMember(T Config::Section::* member) const + { + switch (m_parser.stage()) + { + case AUTHORIZE: return m_config.autz.*member; + case AUTHENTICATE: return m_config.auth.*member; + case POSTAUTH: return m_config.postauth.*member; + case PREACCT: return m_config.preacct.*member; + case ACCOUNTING: return m_config.acct.*member; + } + throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'."); + } + + const Config::Pairs& match() const { return stageMember(&Config::Section::match); } + const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); } + const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); } + Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); } + + static void process(void* data); + void processPing(); + void processPong(); + void processData(); + bool answer(const USER& user); + bool answerNo(); + bool sendPing(); + bool sendPong(); + + static bool write(void* data, const char* buf, size_t size); +}; + +Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote) + : m_impl(new Impl(users, logger, config, fd, remote)) +{ +} + +Conn::~Conn() +{ +} + +int Conn::sock() const +{ + return m_impl->sock(); +} + +bool Conn::read() +{ + return m_impl->read(); +} + +bool Conn::tick() +{ + return m_impl->tick(); +} + +bool Conn::isOk() const +{ + return m_impl->isOk(); +} + +Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote) + : m_users(users), + m_logger(logger), + m_config(config), + m_sock(fd), + m_remote(remote), + m_ok(true), + m_lastPing(time(NULL)), + m_lastActivity(m_lastPing), + m_parser(&Conn::Impl::process, this) +{ +} + +Conn::Impl::~Impl() +{ + close(m_sock); +} + +bool Conn::Impl::read() +{ + static std::vector buffer(1024); + ssize_t res = ::read(m_sock, buffer.data(), buffer.size()); + if (res < 0) + { + m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno)); + m_ok = false; + return false; + } + printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str()); + m_lastActivity = time(NULL); + if (res == 0) + { + m_ok = false; + return true; + } + return m_parser.append(buffer.data(), res); +} + +bool Conn::Impl::tick() +{ + time_t now = time(NULL); + if (difftime(now, m_lastActivity) > CONN_TIMEOUT) + { + int delta = difftime(now, m_lastActivity); + printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta); + m_logger("Connection to " + m_remote + " timed out."); + m_ok = false; + return false; + } + if (difftime(now, m_lastPing) > PING_TIMEOUT) + { + int delta = difftime(now, m_lastPing); + printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta); + sendPing(); + } + return true; +} + +void Conn::Impl::process(void* data) +{ + Impl& impl = *static_cast(data); + try + { + switch (impl.m_parser.packet()) + { + case PING: + impl.processPing(); + return; + case PONG: + impl.processPong(); + return; + case DATA: + impl.processData(); + return; + } + } + catch (const std::exception& ex) + { + printfd(__FILE__, "Processing error. %s", ex.what()); + impl.m_logger("Processing error. %s", ex.what()); + } + printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str()); + impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr()); +} + +void Conn::Impl::processPing() +{ + printfd(__FILE__, "Got ping. Sending pong...\n"); + sendPong(); +} + +void Conn::Impl::processPong() +{ + printfd(__FILE__, "Got pong.\n"); + m_lastActivity = time(NULL); +} + +void Conn::Impl::processData() +{ + printfd(__FILE__, "Got data.\n"); + int handle = m_users.OpenSearch(); + + USER_PTR user = NULL; + bool matched = false; + while (m_users.SearchNext(handle, &user) == 0) + { + if (user == NULL) + continue; + + matched = true; + for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it) + { + Config::Pairs::const_iterator pos = m_parser.data().find(it->first); + if (pos == m_parser.data().end()) + { + matched = false; + break; + } + if (user->GetParamValue(it->second) != pos->second) + { + matched = false; + break; + } + } + if (!matched) + continue; + answer(*user); + break; + } + + if (!matched) + answerNo(); + + m_users.CloseSearch(handle); +} + +bool Conn::Impl::answer(const USER& user) +{ + printfd(__FILE__, "Got match. Sending answer...\n"); + MapGen replyData; + for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it) + replyData.add(it->first, new StringGen(user.GetParamValue(it->second))); + + MapGen modifyData; + for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it) + modifyData.add(it->first, new StringGen(user.GetParamValue(it->second))); + + PacketGen gen("data"); + gen.add("result", "ok") + .add("reply", replyData) + .add("modify", modifyData); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::answerNo() +{ + printfd(__FILE__, "No match. Sending answer...\n"); + PacketGen gen("data"); + gen.add("result", "no"); + gen.add("return_code", toString(returnCode())); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPing() +{ + PacketGen gen("ping"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::sendPong() +{ + PacketGen gen("pong"); + + m_lastPing = time(NULL); + + return generate(gen, &Conn::Impl::write, this); +} + +bool Conn::Impl::write(void* data, const char* buf, size_t size) +{ + std::string json(buf, size); + printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str()); + Conn::Impl& conn = *static_cast(data); + while (size > 0) + { + ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL); + if (res < 0) + { + conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno)); + conn.m_ok = false; + return false; + } + size -= res; + } + return true; +} diff --git a/projects/stargazer/plugins/other/radius/conn.h b/projects/stargazer/plugins/other/radius/conn.h new file mode 100644 index 00000000..17ebc8cf --- /dev/null +++ b/projects/stargazer/plugins/other/radius/conn.h @@ -0,0 +1,57 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#ifndef __STG_SGCONFIG_CONN_H__ +#define __STG_SGCONFIG_CONN_H__ + +#include + +#include + +class USER; +class USERS; +class PLUGIN_LOGGER; + +namespace STG +{ + +struct Config; + +class Conn +{ + public: + Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote); + ~Conn(); + + int sock() const; + + bool read(); + bool tick(); + + bool isOk() const; + + private: + class Impl; + boost::scoped_ptr m_impl; +}; + +} + +#endif diff --git a/projects/stargazer/plugins/other/radius/radius.cpp b/projects/stargazer/plugins/other/radius/radius.cpp index 8e52cdb4..376e4278 100644 --- a/projects/stargazer/plugins/other/radius/radius.cpp +++ b/projects/stargazer/plugins/other/radius/radius.cpp @@ -18,586 +18,346 @@ * Author : Maxim Mamontov */ -/* - * This file contains a realization of radius data access plugin for Stargazer - * - * $Revision: 1.14 $ - * $Date: 2009/12/13 14:17:13 $ - * - */ +#include "radius.h" + +#include "stg/store.h" +#include "stg/users.h" +#include "stg/plugin_creator.h" +#include "stg/common.h" +#include +#include #include #include -#include +#include -#include "stg/store.h" -#include "stg/common.h" -#include "stg/user_conf.h" -#include "stg/user_property.h" -#include "stg/plugin_creator.h" -#include "radius.h" +#include +#include +#include // UNIX +#include // IP +#include // TCP +#include -extern volatile time_t stgTime; +using STG::Config; +using STG::Conn; -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- namespace { -PLUGIN_CREATOR radc; -void InitEncrypt(BLOWFISH_CTX * ctx, const std::string & password); -void Decrypt(BLOWFISH_CTX * ctx, void * dst, const void * src, unsigned long len8); -void Encrypt(BLOWFISH_CTX * ctx, void * dst, const void * src, unsigned long len8); -} -extern "C" PLUGIN * GetPlugin(); -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- -PLUGIN * GetPlugin() -{ -return radc.GetPlugin(); +PLUGIN_CREATOR creator; + } -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- -int RAD_SETTINGS::ParseServices(const std::vector & str, std::list * lst) + +extern "C" PLUGIN * GetPlugin() { -std::copy(str.begin(), str.end(), std::back_inserter(*lst)); -std::list::iterator it(std::find(lst->begin(), - lst->end(), - "empty")); -if (it != lst->end()) - *it = ""; - -return 0; + return creator.GetPlugin(); } -//----------------------------------------------------------------------------- -int RAD_SETTINGS::ParseSettings(const MODULE_SETTINGS & s) -{ -int p; -PARAM_VALUE pv; -std::vector::const_iterator pvi; -/////////////////////////// -pv.param = "Port"; -pvi = std::find(s.moduleParams.begin(), s.moduleParams.end(), pv); -if (pvi == s.moduleParams.end()) - { - errorStr = "Parameter \'Port\' not found."; - printfd(__FILE__, "Parameter 'Port' not found\n"); - return -1; - } -if (ParseIntInRange(pvi->value[0], 2, 65535, &p)) - { - errorStr = "Cannot parse parameter \'Port\': " + errorStr; - printfd(__FILE__, "Cannot parse parameter 'Port'\n"); - return -1; - } -port = static_cast(p); -/////////////////////////// -pv.param = "Password"; -pvi = std::find(s.moduleParams.begin(), s.moduleParams.end(), pv); -if (pvi == s.moduleParams.end()) - { - errorStr = "Parameter \'Password\' not found."; - printfd(__FILE__, "Parameter 'Password' not found\n"); - return -1; - } -password = pvi->value[0]; -/////////////////////////// -pv.param = "AuthServices"; -pvi = std::find(s.moduleParams.begin(), s.moduleParams.end(), pv); -if (pvi != s.moduleParams.end()) - { - ParseServices(pvi->value, &authServices); - } -/////////////////////////// -pv.param = "AcctServices"; -pvi = std::find(s.moduleParams.begin(), s.moduleParams.end(), pv); -if (pvi != s.moduleParams.end()) - { - ParseServices(pvi->value, &acctServices); - } -return 0; -} -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- -//----------------------------------------------------------------------------- RADIUS::RADIUS() - : ctx(), - errorStr(), - radSettings(), - settings(), - authServices(), - acctServices(), - sessions(), - nonstop(false), - isRunning(false), - users(NULL), - stgSettings(NULL), - store(NULL), - thread(), - mutex(), - sock(-1), - packet(), - logger(GetPluginLogger(GetStgLogger(), "radius")) + : m_config(), + m_running(false), + m_stopped(true), + m_users(NULL), + m_store(NULL), + m_listenSocket(0), + m_logger(GetPluginLogger(GetStgLogger(), "radius")) { -InitEncrypt(&ctx, ""); } -//----------------------------------------------------------------------------- + int RADIUS::ParseSettings() { -int ret = radSettings.ParseSettings(settings); -if (ret) - errorStr = radSettings.GetStrError(); -return ret; -} -//----------------------------------------------------------------------------- -int RADIUS::PrepareNet() -{ -sock = socket(AF_INET, SOCK_DGRAM, 0); - -if (sock < 0) - { - errorStr = "Cannot create socket."; - logger("Cannot create a socket: %s", strerror(errno)); - printfd(__FILE__, "Cannot create socket\n"); - return -1; - } - -struct sockaddr_in inAddr; -inAddr.sin_family = AF_INET; -inAddr.sin_port = htons(radSettings.GetPort()); -inAddr.sin_addr.s_addr = inet_addr("0.0.0.0"); - -if (bind(sock, (struct sockaddr*)&inAddr, sizeof(inAddr)) < 0) - { - errorStr = "RADIUS: Bind failed."; - logger("Cannot bind the socket: %s", strerror(errno)); - printfd(__FILE__, "Cannot bind socket\n"); - return -1; + try { + m_config = STG::Config(m_settings); + return reconnect() ? 0 : -1; + } catch (const std::runtime_error& ex) { + m_logger("Failed to parse settings. %s", ex.what()); + return -1; } - -return 0; } -//----------------------------------------------------------------------------- -int RADIUS::FinalizeNet() -{ -close(sock); -return 0; -} -//----------------------------------------------------------------------------- + int RADIUS::Start() { -std::string password(radSettings.GetPassword()); - -authServices = radSettings.GetAuthServices(); -acctServices = radSettings.GetAcctServices(); - -InitEncrypt(&ctx, password); + if (m_running) + return 0; -nonstop = true; + int res = pthread_create(&m_thread, NULL, run, this); + if (res == 0) + return 0; -if (PrepareNet()) - { + m_error = strerror(res); + m_logger("Failed to create thread: '" + m_error + "'."); return -1; - } - -if (!isRunning) - { - if (pthread_create(&thread, NULL, Run, this)) - { - errorStr = "Cannot create thread."; - logger("Cannot create thread."); - printfd(__FILE__, "Cannot create thread\n"); - return -1; - } - } - -errorStr = ""; -return 0; } -//----------------------------------------------------------------------------- + int RADIUS::Stop() { -if (!IsRunning()) - return 0; + if (m_stopped) + return 0; -nonstop = false; + m_running = false; -std::map::iterator it; -for (it = sessions.begin(); it != sessions.end(); ++it) - { - USER_PTR ui; - if (users->FindByName(it->second.userName, &ui)) - { - users->Unauthorize(ui->GetLogin(), this); - } - } -sessions.erase(sessions.begin(), sessions.end()); - -FinalizeNet(); - -if (isRunning) - { - //5 seconds to thread stops itself - for (int i = 0; i < 25 && isRunning; i++) - { + for (size_t i = 0; i < 25 && !m_stopped; i++) { struct timespec ts = {0, 200000000}; nanosleep(&ts, NULL); - } } -if (isRunning) - return -1; - -return 0; -} -//----------------------------------------------------------------------------- -void * RADIUS::Run(void * d) -{ -sigset_t signalSet; -sigfillset(&signalSet); -pthread_sigmask(SIG_BLOCK, &signalSet, NULL); - -RADIUS * rad = static_cast(d); -RAD_PACKET packet; - -rad->isRunning = true; - -while (rad->nonstop) - { - if (!WaitPackets(rad->sock)) - { - continue; - } - struct sockaddr_in outerAddr; - if (rad->RecvData(&packet, &outerAddr)) - { - printfd(__FILE__, "RADIUS::Run Error on RecvData\n"); - } - else - { - if (rad->ProcessData(&packet)) - { - packet.packetType = RAD_REJECT_PACKET; - } - rad->Send(packet, &outerAddr); - } + if (m_stopped) { + pthread_join(m_thread, NULL); + return 0; } -rad->isRunning = false; - -return NULL; -} -//----------------------------------------------------------------------------- -int RADIUS::RecvData(RAD_PACKET * packet, struct sockaddr_in * outerAddr) -{ - int8_t buf[RAD_MAX_PACKET_LEN]; - socklen_t outerAddrLen = sizeof(struct sockaddr_in); - ssize_t dataLen = recvfrom(sock, buf, RAD_MAX_PACKET_LEN, 0, reinterpret_cast(outerAddr), &outerAddrLen); - if (dataLen < 0) - { - logger("recvfrom error: %s", strerror(errno)); - return -1; - } - if (dataLen == 0) - return -1; - - Decrypt(&ctx, (char *)packet, (const char *)buf, dataLen / 8); - - if (strncmp((char *)packet->magic, RAD_ID, RAD_MAGIC_LEN)) - { - printfd(__FILE__, "RADIUS::RecvData Error magic. Wanted: '%s', got: '%s'\n", RAD_ID, packet->magic); - return -1; - } + if (m_config.connectionType == Config::UNIX) + unlink(m_config.bindAddress.c_str()); - return 0; -} -//----------------------------------------------------------------------------- -ssize_t RADIUS::Send(const RAD_PACKET & packet, struct sockaddr_in * outerAddr) -{ -size_t len = sizeof(RAD_PACKET); -char buf[1032]; - -Encrypt(&ctx, buf, (char *)&packet, len / 8); -ssize_t res = sendto(sock, buf, len, 0, reinterpret_cast(outerAddr), sizeof(struct sockaddr_in)); -if (res < 0) - logger("sendto error: %s", strerror(errno)); -return res; -} -//----------------------------------------------------------------------------- -int RADIUS::ProcessData(RAD_PACKET * packet) -{ -if (strncmp((const char *)packet->protoVer, "01", 2)) - { - printfd(__FILE__, "RADIUS::ProcessData packet.protoVer incorrect\n"); + m_error = "Failed to stop thread."; + m_logger(m_error); return -1; - } -switch (packet->packetType) - { - case RAD_AUTZ_PACKET: - return ProcessAutzPacket(packet); - case RAD_AUTH_PACKET: - return ProcessAuthPacket(packet); - case RAD_POST_AUTH_PACKET: - return ProcessPostAuthPacket(packet); - case RAD_ACCT_START_PACKET: - return ProcessAcctStartPacket(packet); - case RAD_ACCT_STOP_PACKET: - return ProcessAcctStopPacket(packet); - case RAD_ACCT_UPDATE_PACKET: - return ProcessAcctUpdatePacket(packet); - case RAD_ACCT_OTHER_PACKET: - return ProcessAcctOtherPacket(packet); - default: - printfd(__FILE__, "RADIUS::ProcessData Unsupported packet type: %d\n", packet->packetType); - return -1; - }; } //----------------------------------------------------------------------------- -int RADIUS::ProcessAutzPacket(RAD_PACKET * packet) +void* RADIUS::run(void* d) { -USER_CONF conf; + sigset_t signalSet; + sigfillset(&signalSet); + pthread_sigmask(SIG_BLOCK, &signalSet, NULL); -if (!IsAllowedService((char *)packet->service)) - { - printfd(__FILE__, "RADIUS::ProcessAutzPacket service '%s' is not allowed to authorize\n", packet->service); - packet->packetType = RAD_REJECT_PACKET; - return 0; - } - -if (store->RestoreUserConf(&conf, (char *)packet->login)) - { - packet->packetType = RAD_REJECT_PACKET; - printfd(__FILE__, "RADIUS::ProcessAutzPacket cannot restore conf for user '%s'\n", packet->login); - return 0; - } + static_cast(d)->runImpl(); -// At this point service can be authorized at least -// So we send a plain-text password - -packet->packetType = RAD_ACCEPT_PACKET; -strncpy((char *)packet->password, conf.password.c_str(), RAD_PASSWORD_LEN); - -return 0; + return NULL; } -//----------------------------------------------------------------------------- -int RADIUS::ProcessAuthPacket(RAD_PACKET * packet) -{ -USER_PTR ui; - -if (!CanAcctService((char *)packet->service)) - { - - // There are no sense to check for allowed service - // It has allready checked at previous stage (authorization) - printfd(__FILE__, "RADIUS::ProcessAuthPacket service '%s' neednot stargazer authentication\n", (char *)packet->service); - packet->packetType = RAD_ACCEPT_PACKET; - return 0; - } - -// At this point we have an accountable service -// All other services got a password if allowed or rejected - -if (!FindUser(&ui, (char *)packet->login)) +bool RADIUS::reconnect() +{ + if (!m_conns.empty()) { - packet->packetType = RAD_REJECT_PACKET; - printfd(__FILE__, "RADIUS::ProcessAuthPacket user '%s' not found\n", (char *)packet->login); - return 0; + std::deque::const_iterator it; + for (it = m_conns.begin(); it != m_conns.end(); ++it) + delete(*it); + m_conns.clear(); } - -if (ui->IsInetable()) + if (m_listenSocket != 0) { - packet->packetType = RAD_ACCEPT_PACKET; + shutdown(m_listenSocket, SHUT_RDWR); + close(m_listenSocket); } -else + if (m_config.connectionType == Config::UNIX) + m_listenSocket = createUNIX(); + else + m_listenSocket = createTCP(); + if (m_listenSocket == 0) + return false; + if (listen(m_listenSocket, 100) == -1) { - packet->packetType = RAD_REJECT_PACKET; + m_error = std::string("Error starting to listen socket: ") + strerror(errno); + m_logger(m_error); + return false; } - -packet->packetType = RAD_ACCEPT_PACKET; -return 0; + return true; } -//----------------------------------------------------------------------------- -int RADIUS::ProcessPostAuthPacket(RAD_PACKET * packet) -{ -USER_PTR ui; -if (!CanAcctService((char *)packet->service)) +int RADIUS::createUNIX() const +{ + int fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd == -1) { - - // There are no sense to check for allowed service - // It has allready checked at previous stage (authorization) - - packet->packetType = RAD_ACCEPT_PACKET; - return 0; + m_error = std::string("Error creating UNIX socket: ") + strerror(errno); + m_logger(m_error); + return 0; } - -if (!FindUser(&ui, (char *)packet->login)) + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, m_config.bindAddress.c_str(), m_config.bindAddress.length()); + unlink(m_config.bindAddress.c_str()); + if (bind(fd, reinterpret_cast(&addr), sizeof(addr)) == -1) { - packet->packetType = RAD_REJECT_PACKET; - printfd(__FILE__, "RADIUS::ProcessPostAuthPacket user '%s' not found\n", (char *)packet->login); - return 0; + shutdown(fd, SHUT_RDWR); + close(fd); + m_error = std::string("Error binding UNIX socket: ") + strerror(errno); + m_logger(m_error); + return 0; } - -// I think that only Framed-User services has sense to be accountable -// So we have to supply a Framed-IP - -USER_IPS ips = ui->GetProperty().ips; -packet->packetType = RAD_ACCEPT_PACKET; - -// Additional checking for Framed-User service - -if (!strncmp((char *)packet->service, "Framed-User", RAD_SERVICE_LEN)) - packet->ip = ips[0].ip; -else - packet->ip = 0; - -return 0; + chown(m_config.bindAddress.c_str(), m_config.sockUID, m_config.sockGID); + if (m_config.sockMode != static_cast(-1)) + chmod(m_config.bindAddress.c_str(), m_config.sockMode); + return fd; } -//----------------------------------------------------------------------------- -int RADIUS::ProcessAcctStartPacket(RAD_PACKET * packet) -{ -USER_PTR ui; -if (!FindUser(&ui, (char *)packet->login)) +int RADIUS::createTCP() const +{ + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + + hints.ai_family = AF_INET; /* Allow IPv4 */ + hints.ai_socktype = SOCK_STREAM; /* Stream socket */ + hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */ + hints.ai_protocol = 0; /* Any protocol */ + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + + addrinfo* ais = NULL; + int res = getaddrinfo(m_config.bindAddress.c_str(), m_config.portStr.c_str(), &hints, &ais); + if (res != 0) { - packet->packetType = RAD_REJECT_PACKET; - printfd(__FILE__, "RADIUS::ProcessAcctStartPacket user '%s' not found\n", (char *)packet->login); - return 0; + m_error = "Error resolving address '" + m_config.bindAddress + "': " + gai_strerror(res); + m_logger(m_error); + return 0; } -// At this point we have to unauthorize user only if it is an accountable service - -if (CanAcctService((char *)packet->service)) + for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next) { - if (sessions.find((const char *)packet->sessid) != sessions.end()) + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) { - printfd(__FILE__, "RADIUS::ProcessAcctStartPacket session already started!\n"); - packet->packetType = RAD_REJECT_PACKET; - return -1; + m_error = std::string("Error creating TCP socket: ") + strerror(errno); + m_logger(m_error); + freeaddrinfo(ais); + return 0; } - USER_IPS ips = ui->GetProperty().ips; - if (!users->Authorize(ui->GetLogin(), ips[0].ip, 0xffFFffFF, this)) + if (bind(fd, ai->ai_addr, ai->ai_addrlen) == -1) { - printfd(__FILE__, "RADIUS::ProcessAcctStartPacket cannot authorize user '%s'\n", packet->login); - packet->packetType = RAD_REJECT_PACKET; - return -1; + shutdown(fd, SHUT_RDWR); + close(fd); + m_error = std::string("Error binding TCP socket: ") + strerror(errno); + m_logger(m_error); + continue; } - sessions[(const char *)packet->sessid].userName = (const char *)packet->login; - sessions[(const char *)packet->sessid].serviceType = (const char *)packet->service; - for_each(sessions.begin(), sessions.end(), SPrinter()); - } -else - { - printfd(__FILE__, "RADIUS::ProcessAcctStartPacket service '%s' can not be accounted\n", (char *)packet->service); + freeaddrinfo(ais); + return fd; } -packet->packetType = RAD_ACCEPT_PACKET; -return 0; + m_error = "Failed to resolve '" + m_config.bindAddress; + m_logger(m_error); + + freeaddrinfo(ais); + return 0; } -//----------------------------------------------------------------------------- -int RADIUS::ProcessAcctStopPacket(RAD_PACKET * packet) + +void RADIUS::runImpl() { -std::map::iterator sid; + m_running = true; + m_stopped = false; -if ((sid = sessions.find((const char *)packet->sessid)) == sessions.end()) - { - printfd(__FILE__, "RADIUS::ProcessAcctStopPacket session had not started yet\n"); - packet->packetType = RAD_REJECT_PACKET; - return -1; - } + while (m_running) { + fd_set fds; -USER_PTR ui; + buildFDSet(fds); -if (!FindUser(&ui, sid->second.userName)) - { - packet->packetType = RAD_REJECT_PACKET; - printfd(__FILE__, "RADIUS::ProcessPostAuthPacket user '%s' not found\n", sid->second.userName.c_str()); - return 0; - } + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 500000; -sessions.erase(sid); + int res = select(maxFD() + 1, &fds, NULL, NULL, &tv); + if (res < 0) + { + if (errno == EINTR) + continue; + m_error = std::string("'select' is failed: '") + strerror(errno) + "'."; + m_logger(m_error); + break; + } -users->Unauthorize(ui->GetLogin(), this); + if (!m_running) + break; + + if (res > 0) + handleEvents(fds); + else + { + for (std::deque::iterator it = m_conns.begin(); it != m_conns.end(); ++it) + (*it)->tick(); + } + + cleanupConns(); + } -packet->packetType = RAD_ACCEPT_PACKET; -return 0; + m_stopped = true; } -//----------------------------------------------------------------------------- -int RADIUS::ProcessAcctUpdatePacket(RAD_PACKET * packet) + +int RADIUS::maxFD() const { -// Fake. May be use it later -packet->packetType = RAD_ACCEPT_PACKET; -return 0; + int maxFD = m_listenSocket; + std::deque::const_iterator it; + for (it = m_conns.begin(); it != m_conns.end(); ++it) + if (maxFD < (*it)->sock()) + maxFD = (*it)->sock(); + return maxFD; } -//----------------------------------------------------------------------------- -int RADIUS::ProcessAcctOtherPacket(RAD_PACKET * packet) + +void RADIUS::buildFDSet(fd_set & fds) const { -// Fake. May be use it later -packet->packetType = RAD_ACCEPT_PACKET; -return 0; + FD_ZERO(&fds); + FD_SET(m_listenSocket, &fds); + std::deque::const_iterator it; + for (it = m_conns.begin(); it != m_conns.end(); ++it) + FD_SET((*it)->sock(), &fds); } -//----------------------------------------------------------------------------- -bool RADIUS::FindUser(USER_PTR * ui, const std::string & login) const -{ -if (users->FindByName(login, ui)) - { - return false; - } -return true; -} -//----------------------------------------------------------------------------- -bool RADIUS::CanAuthService(const std::string & svc) const + +void RADIUS::cleanupConns() { -return find(authServices.begin(), authServices.end(), svc) != authServices.end(); + std::deque::iterator pos; + for (pos = m_conns.begin(); pos != m_conns.end(); ++pos) + if (!(*pos)->isOk()) { + delete *pos; + *pos = NULL; + } + + pos = std::remove(m_conns.begin(), m_conns.end(), static_cast(NULL)); + m_conns.erase(pos, m_conns.end()); } -//----------------------------------------------------------------------------- -bool RADIUS::CanAcctService(const std::string & svc) const + +void RADIUS::handleEvents(const fd_set & fds) { -return find(acctServices.begin(), acctServices.end(), svc) != acctServices.end(); + if (FD_ISSET(m_listenSocket, &fds)) + acceptConnection(); + else + { + std::deque::iterator it; + for (it = m_conns.begin(); it != m_conns.end(); ++it) + if (FD_ISSET((*it)->sock(), &fds)) + (*it)->read(); + else + (*it)->tick(); + } } -//----------------------------------------------------------------------------- -bool RADIUS::IsAllowedService(const std::string & svc) const + +void RADIUS::acceptConnection() { -return CanAuthService(svc) || CanAcctService(svc); + if (m_config.connectionType == Config::UNIX) + acceptUNIX(); + else + acceptTCP(); } -//----------------------------------------------------------------------------- -namespace -{ -inline -void InitEncrypt(BLOWFISH_CTX * ctx, const std::string & password) +void RADIUS::acceptUNIX() { -unsigned char keyL[RAD_PASSWORD_LEN]; // Пароль для шифровки -memset(keyL, 0, RAD_PASSWORD_LEN); -strncpy((char *)keyL, password.c_str(), RAD_PASSWORD_LEN); -Blowfish_Init(ctx, keyL, RAD_PASSWORD_LEN); + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + socklen_t size = sizeof(addr); + int res = accept(m_listenSocket, reinterpret_cast(&addr), &size); + if (res == -1) + { + m_error = std::string("Failed to accept UNIX connection: ") + strerror(errno); + m_logger(m_error); + return; + } + printfd(__FILE__, "New UNIX connection: '%s'\n", addr.sun_path); + m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, addr.sun_path)); } -//----------------------------------------------------------------------------- -inline -void Encrypt(BLOWFISH_CTX * ctx, void * dst, const void * src, unsigned long len8) -{ -// len8 - длина в 8-ми байтовых блоках -if (dst != src) - memcpy(dst, src, len8 * 8); -for (size_t i = 0; i < len8; i++) - Blowfish_Encrypt(ctx, static_cast(dst) + i * 2, static_cast(dst) + i * 2 + 1); -} -//----------------------------------------------------------------------------- -inline -void Decrypt(BLOWFISH_CTX * ctx, void * dst, const void * src, unsigned long len8) +void RADIUS::acceptTCP() { -// len8 - длина в 8-ми байтовых блоках -if (dst != src) - memcpy(dst, src, len8 * 8); - -for (size_t i = 0; i < len8; i++) - Blowfish_Decrypt(ctx, static_cast(dst) + i * 2, static_cast(dst) + i * 2 + 1); + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + socklen_t size = sizeof(addr); + int res = accept(m_listenSocket, reinterpret_cast(&addr), &size); + if (res == -1) + { + m_error = std::string("Failed to accept TCP connection: ") + strerror(errno); + m_logger(m_error); + return; + } + std::string remote = inet_ntostring(addr.sin_addr.s_addr) + ":" + x2str(ntohs(addr.sin_port)); + printfd(__FILE__, "New TCP connection: '%s'\n", remote.c_str()); + m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, remote)); } - -} // namespace anonymous diff --git a/projects/stargazer/plugins/other/radius/radius.h b/projects/stargazer/plugins/other/radius/radius.h index 8f5ba2a4..2573cf9e 100644 --- a/projects/stargazer/plugins/other/radius/radius.h +++ b/projects/stargazer/plugins/other/radius/radius.h @@ -18,154 +18,88 @@ * Author : Maxim Mamontov */ -/* - * Radius data access plugin for Stargazer - * - * $Revision: 1.10 $ - * $Date: 2009/12/13 14:17:13 $ - * - */ - -#ifndef RADIUS_H -#define RADIUS_H - -#include - -#include -#include -#include -#include -#include -#include +#ifndef __STG_RADIUS_H__ +#define __STG_RADIUS_H__ #include "stg/os_int.h" #include "stg/auth.h" #include "stg/module_settings.h" -#include "stg/notifer.h" -#include "stg/user_ips.h" -#include "stg/user.h" -#include "stg/users.h" -#include "stg/blowfish.h" -#include "stg/rad_packets.h" #include "stg/logger.h" -extern "C" PLUGIN * GetPlugin(); +#include "config.h" +#include "conn.h" + +#include +#include -#define RAD_DEBUG (1) +#include +#include +#include +#include -class RADIUS; -//----------------------------------------------------------------------------- -class RAD_SETTINGS { -public: - RAD_SETTINGS() - : port(0), errorStr(), password(), - authServices(), acctServices() - {} - virtual ~RAD_SETTINGS() {} - const std::string & GetStrError() const { return errorStr; } - int ParseSettings(const MODULE_SETTINGS & s); - uint16_t GetPort() const { return port; } - const std::string & GetPassword() const { return password; } - const std::list & GetAuthServices() const { return authServices; } - const std::list & GetAcctServices() const { return acctServices; } +extern "C" PLUGIN * GetPlugin(); -private: - int ParseServices(const std::vector & str, std::list * lst); +class STORE; +class USERS; - uint16_t port; - std::string errorStr; - std::string password; - std::list authServices; - std::list acctServices; -}; -//----------------------------------------------------------------------------- -struct RAD_SESSION { - RAD_SESSION() : userName(), serviceType() {} - std::string userName; - std::string serviceType; -}; -//----------------------------------------------------------------------------- -class RADIUS :public AUTH { +class RADIUS : public AUTH { public: - RADIUS(); - virtual ~RADIUS() {} + RADIUS(); + virtual ~RADIUS() {} - void SetUsers(USERS * u) { users = u; } - void SetStore(STORE * s) { store = s; } - void SetStgSettings(const SETTINGS *) {} - void SetSettings(const MODULE_SETTINGS & s) { settings = s; } - int ParseSettings(); + void SetUsers(USERS* u) { m_users = u; } + void SetStore(STORE* s) { m_store = s; } + void SetStgSettings(const SETTINGS*) {} + void SetSettings(const MODULE_SETTINGS& s) { m_settings = s; } + int ParseSettings(); - int Start(); - int Stop(); - int Reload() { return 0; } - bool IsRunning() { return isRunning; } + int Start(); + int Stop(); + int Reload() { return 0; } + bool IsRunning() { return m_running; } - const std::string & GetStrError() const { return errorStr; } - std::string GetVersion() const { return "RADIUS data access plugin v 0.6"; } - uint16_t GetStartPosition() const { return 30; } - uint16_t GetStopPosition() const { return 30; } + const std::string& GetStrError() const { return m_error; } + std::string GetVersion() const { return "RADIUS data access plugin v 1.0"; } + uint16_t GetStartPosition() const { return 30; } + uint16_t GetStopPosition() const { return 30; } - int SendMessage(const STG_MSG &, uint32_t) const { return 0; } + int SendMessage(const STG_MSG&, uint32_t) const { return 0; } private: RADIUS(const RADIUS & rvalue); RADIUS & operator=(const RADIUS & rvalue); - static void * Run(void *); - int PrepareNet(); - int FinalizeNet(); - - ssize_t Send(const RAD_PACKET & packet, struct sockaddr_in * outerAddr); - int RecvData(RAD_PACKET * packet, struct sockaddr_in * outerAddr); - int ProcessData(RAD_PACKET * packet); - - int ProcessAutzPacket(RAD_PACKET * packet); - int ProcessAuthPacket(RAD_PACKET * packet); - int ProcessPostAuthPacket(RAD_PACKET * packet); - int ProcessAcctStartPacket(RAD_PACKET * packet); - int ProcessAcctStopPacket(RAD_PACKET * packet); - int ProcessAcctUpdatePacket(RAD_PACKET * packet); - int ProcessAcctOtherPacket(RAD_PACKET * packet); - - bool FindUser(USER_PTR * ui, const std::string & login) const; - bool CanAuthService(const std::string & svc) const; - bool CanAcctService(const std::string & svc) const; - bool IsAllowedService(const std::string & svc) const; - - struct SPrinter : public std::unary_function, void> - { - void operator()(const std::pair & it) - { - printfd("radius.cpp", "%s - ('%s', '%s')\n", it.first.c_str(), it.second.userName.c_str(), it.second.serviceType.c_str()); - } - }; + static void* run(void*); - BLOWFISH_CTX ctx; + bool reconnect(); + int createUNIX() const; + int createTCP() const; + void runImpl(); + int maxFD() const; + void buildFDSet(fd_set & fds) const; + void cleanupConns(); + void handleEvents(const fd_set & fds); + void acceptConnection(); + void acceptUNIX(); + void acceptTCP(); - mutable std::string errorStr; - RAD_SETTINGS radSettings; - MODULE_SETTINGS settings; - std::list authServices; - std::list acctServices; - std::map sessions; + mutable std::string m_error; + STG::Config m_config; - bool nonstop; - bool isRunning; + MODULE_SETTINGS m_settings; - USERS * users; - const SETTINGS * stgSettings; - const STORE * store; + bool m_running; + bool m_stopped; - pthread_t thread; - pthread_mutex_t mutex; + USERS* m_users; + const STORE* m_store; - int sock; + int m_listenSocket; + std::deque m_conns; - RAD_PACKET packet; + pthread_t m_thread; - PLUGIN_LOGGER logger; + PLUGIN_LOGGER m_logger; }; -//----------------------------------------------------------------------------- #endif diff --git a/projects/stargazer/settings_impl.cpp b/projects/stargazer/settings_impl.cpp index d115919d..67b4272f 100644 --- a/projects/stargazer/settings_impl.cpp +++ b/projects/stargazer/settings_impl.cpp @@ -14,27 +14,80 @@ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ -/* - * Date: 27.10.2002 - */ - /* * Author : Boris Mikhailenko */ -/* -$Revision: 1.45 $ -$Date: 2010/08/19 13:42:30 $ -$Author: faust $ -*/ +#include "settings_impl.h" +#include "stg/logger.h" +#include "stg/dotconfpp.h" +#include "stg/common.h" + +#include #include #include -#include -#include "stg/logger.h" -#include "stg/dotconfpp.h" -#include "settings_impl.h" +namespace +{ + +struct Error : public std::runtime_error +{ + Error(const std::string& message) : runtime_error(message) {} +}; + +std::vector toValues(const DOTCONFDocumentNode& node) +{ + std::vector values; + + size_t i = 0; + const char* value = NULL; + while ((value = node.getValue(i++)) != NULL) + values.push_back(value); + + return values; +} + +std::vector toPVS(const DOTCONFDocumentNode& node) +{ + std::vector pvs; + + const DOTCONFDocumentNode* child = node.getChildNode(); + while (child != NULL) + { + if (child->getName() == NULL) + continue; + + if (child->getChildNode() == NULL) + pvs.push_back(PARAM_VALUE(child->getName(), toValues(*child))); + else + pvs.push_back(PARAM_VALUE(child->getName(), toValues(*child), toPVS(*child))); + + child = child->getNextNode(); + } + + return pvs; +} + +unsigned toPeriod(const char* value) +{ + if (value == NULL) + throw Error("No detail stat period value."); + + std::string period(value); + if (period == "1") + return dsPeriod_1; + else if (period == "1/2") + return dsPeriod_1_2; + else if (period == "1/4") + return dsPeriod_1_4; + else if (period == "1/6") + return dsPeriod_1_6; + + throw Error("Invalid detail stat period value: '" + period + "'. Should be one of '1', '1/2', '1/4' or '1/6'."); +} + +} //----------------------------------------------------------------------------- SETTINGS_IMPL::SETTINGS_IMPL(const std::string & cd) @@ -66,78 +119,6 @@ SETTINGS_IMPL::SETTINGS_IMPL(const std::string & cd) { } //----------------------------------------------------------------------------- -SETTINGS_IMPL::SETTINGS_IMPL(const SETTINGS_IMPL & rval) - : SETTINGS(), - strError(), - modulesPath(rval.modulesPath), - dirName(rval.dirName), - confDir(rval.confDir), - scriptsDir(rval.scriptsDir), - rules(rval.rules), - logFile(rval.logFile), - pidFile(rval.pidFile), - monitorDir(rval.monitorDir), - monitoring(rval.monitoring), - detailStatWritePeriod(rval.detailStatWritePeriod), - statWritePeriod(rval.statWritePeriod), - stgExecMsgKey(rval.stgExecMsgKey), - executersNum(rval.executersNum), - fullFee(rval.fullFee), - dayFee(rval.dayFee), - dayResetTraff(rval.dayResetTraff), - spreadFee(rval.spreadFee), - freeMbAllowInet(rval.freeMbAllowInet), - dayFeeIsLastDay(rval.dayFeeIsLastDay), - writeFreeMbTraffCost(rval.writeFreeMbTraffCost), - showFeeInCash(rval.showFeeInCash), - messageTimeout(rval.messageTimeout), - feeChargeType(rval.feeChargeType), - reconnectOnTariffChange(rval.reconnectOnTariffChange), - modulesSettings(rval.modulesSettings), - storeModuleSettings(rval.storeModuleSettings), - logger(GetStgLogger()) -{ -} -//----------------------------------------------------------------------------- -int SETTINGS_IMPL::ParseModuleSettings(const DOTCONFDocumentNode * node, std::vector * params) -{ -const DOTCONFDocumentNode * childNode; -PARAM_VALUE pv; -const char * value; - -pv.param = node->getName(); - -if (node->getValue(1)) - { - strError = "Unexpected value \'" + std::string(node->getValue(1)) + "\'."; - return -1; - } - -value = node->getValue(0); - -if (!value) - { - strError = "Module name expected."; - return -1; - } - -childNode = node->getChildNode(); -while (childNode) - { - pv.param = childNode->getName(); - int i = 0; - while ((value = childNode->getValue(i++)) != NULL) - { - pv.value.push_back(value); - } - params->push_back(pv); - pv.value.clear(); - childNode = childNode->getNextNode(); - } - -return 0; -} -//----------------------------------------------------------------------------- void SETTINGS_IMPL::ErrorCallback(void * data, const char * buf) { printfd(__FILE__, "SETTINGS_IMPL::ErrorCallback() - %s\n", buf); @@ -207,11 +188,15 @@ while (node) if (strcasecmp(node->getName(), "DetailStatWritePeriod") == 0) { - if (ParseDetailStatWritePeriod(node->getValue(0)) != 0) - { - strError = "Incorrect DetailStatWritePeriod value: \'" + std::string(node->getValue(0)) + "\'"; + try + { + detailStatWritePeriod = toPeriod(node->getValue(0)); + } + catch (const Error& error) + { + strError = error.what(); return -1; - } + } } if (strcasecmp(node->getName(), "StatWritePeriod") == 0) @@ -387,8 +372,13 @@ while (node) } storeModulesCount++; + if (node->getValue(0) == NULL) + { + strError = "No module name in the StoreModule section."; + return -1; + } storeModuleSettings.moduleName = node->getValue(0); - ParseModuleSettings(node, &storeModuleSettings.moduleParams); + storeModuleSettings.moduleParams = toPVS(*node); } if (strcasecmp(node->getName(), "Modules") == 0) @@ -406,13 +396,14 @@ while (node) child = child->getNextNode(); continue; } - MODULE_SETTINGS modSettings; - modSettings.moduleParams.clear(); - modSettings.moduleName = child->getValue(); - ParseModuleSettings(child, &modSettings.moduleParams); + if (child->getValue(0) == NULL) + { + strError = "No module name in the Module section."; + return -1; + } - modulesSettings.push_back(modSettings); + modulesSettings.push_back(MODULE_SETTINGS(child->getValue(0), toPVS(*child))); child = child->getNextNode(); } @@ -431,29 +422,3 @@ while (node) return 0; } //----------------------------------------------------------------------------- -int SETTINGS_IMPL::ParseDetailStatWritePeriod(const std::string & detailStatPeriodStr) -{ -if (detailStatPeriodStr == "1") - { - detailStatWritePeriod = dsPeriod_1; - return 0; - } -else if (detailStatPeriodStr == "1/2") - { - detailStatWritePeriod = dsPeriod_1_2; - return 0; - } -else if (detailStatPeriodStr == "1/4") - { - detailStatWritePeriod = dsPeriod_1_4; - return 0; - } -else if (detailStatPeriodStr == "1/6") - { - detailStatWritePeriod = dsPeriod_1_6; - return 0; - } - -return -1; -} -//----------------------------------------------------------------------------- diff --git a/projects/stargazer/settings_impl.h b/projects/stargazer/settings_impl.h index 69013764..bab818c6 100644 --- a/projects/stargazer/settings_impl.h +++ b/projects/stargazer/settings_impl.h @@ -1,9 +1,3 @@ - /* - $Revision: 1.27 $ - $Date: 2010/08/19 13:42:30 $ - $Author: faust $ - */ - /* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -20,20 +14,10 @@ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ -/* - * Date: 27.10.2002 - */ - /* * Author : Boris Mikhailenko */ - /* - $Revision: 1.27 $ - $Date: 2010/08/19 13:42:30 $ - */ - - #ifndef SETTINGS_IMPL_H #define SETTINGS_IMPL_H @@ -41,16 +25,15 @@ #include #include "stg/settings.h" -#include "stg/common.h" -#include "stg/logger.h" #include "stg/module_settings.h" +#include "stg/ref.h" //----------------------------------------------------------------------------- enum DETAIL_STAT_PERIOD { -dsPeriod_1, -dsPeriod_1_2, -dsPeriod_1_4, -dsPeriod_1_6 + dsPeriod_1, + dsPeriod_1_2, + dsPeriod_1_4, + dsPeriod_1_6 }; //----------------------------------------------------------------------------- class STG_LOGGER; @@ -59,7 +42,6 @@ class DOTCONFDocumentNode; class SETTINGS_IMPL : public SETTINGS { public: SETTINGS_IMPL(const std::string &); - SETTINGS_IMPL(const SETTINGS_IMPL &); virtual ~SETTINGS_IMPL() {} int Reload() { return ReadSettings(); } int ReadSettings(); @@ -74,7 +56,7 @@ public: const std::string & GetRulesFileName() const { return rules; } const std::string & GetLogFileName() const { return logFile; } const std::string & GetPIDFileName() const { return pidFile; } - unsigned GetDetailStatWritePeriod() const + unsigned GetDetailStatWritePeriod() const { return detailStatWritePeriod; } unsigned GetStatWritePeriod() const { return statWritePeriod * 60; } unsigned GetDayFee() const { return dayFee; } @@ -101,9 +83,6 @@ public: private: - int ParseDetailStatWritePeriod(const std::string & str); - int ParseModuleSettings(const DOTCONFDocumentNode * dirNameNode, std::vector * params); - static void ErrorCallback(void * data, const char * buf); std::string strError; @@ -112,7 +91,7 @@ private: std::string modulesPath; std::vector dirName; std::string confDir; - std::string scriptsDir; + std::string scriptsDir; std::string rules; std::string logFile; std::string pidFile; @@ -137,7 +116,7 @@ private: std::vector modulesSettings; MODULE_SETTINGS storeModuleSettings; - STG_LOGGER & logger; + STG::RefWrapper logger; }; //----------------------------------------------------------------------------- diff --git a/stglibs/common.lib/Makefile b/stglibs/common.lib/Makefile index 1f662c25..932e2709 100644 --- a/stglibs/common.lib/Makefile +++ b/stglibs/common.lib/Makefile @@ -7,9 +7,11 @@ include ../../Makefile.conf LIB_NAME = stgcommon SRCS = common.cpp \ - strptime.cpp + strptime.cpp \ + blockio.cpp -INCS = common.h +INCS = common.h \ + blockio.h LIBS += $(LIBICONV) diff --git a/stglibs/common.lib/blockio.cpp b/stglibs/common.lib/blockio.cpp new file mode 100644 index 00000000..04fd1d81 --- /dev/null +++ b/stglibs/common.lib/blockio.cpp @@ -0,0 +1,102 @@ +#include "stg/blockio.h" + +namespace +{ + +void* adjust(void* base, size_t shift) +{ + char* ptr = static_cast(base); + return ptr + shift; +} + +} // namspace anonymous + +using STG::BlockReader; +using STG::BlockWriter; + +BlockReader::BlockReader(const IOVec& ioVec) + : m_dest(ioVec), + m_remainder(0) +{ + for (size_t i = 0; i < m_dest.size(); ++i) + m_remainder += m_dest[i].iov_len; +} + +bool BlockReader::read(int socket) +{ + if (m_remainder == 0) + return true; + + size_t offset = m_dest.size() - 1; + size_t toRead = m_remainder; + while (offset > 0) { + if (toRead < m_dest[offset].iov_len) + break; + toRead -= m_dest[offset].iov_len; + --offset; + } + + IOVec dest(m_dest.size() - offset); + for (size_t i = 0; i < dest.size(); ++i) { + if (i == 0) { + dest[0].iov_len = toRead; + dest[0].iov_base = adjust(m_dest[offset].iov_base, m_dest[offset].iov_len - toRead); + } else { + dest[i] = m_dest[offset + i]; + } + } + + ssize_t res = readv(socket, dest.data(), dest.size()); + if (res < 0) + return false; + if (res == 0) + return m_remainder == 0; + if (res < static_cast(m_remainder)) + m_remainder -= res; + else + m_remainder = 0; + return true; +} + +BlockWriter::BlockWriter(const IOVec& ioVec) + : m_source(ioVec), + m_remainder(0) +{ + for (size_t i = 0; i < m_source.size(); ++i) + m_remainder += m_source[i].iov_len; +} + +bool BlockWriter::write(int socket) +{ + if (m_remainder == 0) + return true; + + size_t offset = m_source.size() - 1; + size_t toWrite = m_remainder; + while (offset > 0) { + if (toWrite < m_source[offset].iov_len) + break; + toWrite -= m_source[offset].iov_len; + --offset; + } + + IOVec source(m_source.size() - offset); + for (size_t i = 0; i < source.size(); ++i) { + if (i == 0) { + source[0].iov_len = toWrite; + source[0].iov_base = adjust(m_source[offset].iov_base, m_source[offset].iov_len - toWrite); + } else { + source[i] = m_source[offset + i]; + } + } + ssize_t res = writev(socket, source.data(), source.size()); + if (res < 0) + return false; + if (res == 0) + return m_remainder == 0; + if (res < static_cast(m_remainder)) + m_remainder -= res; + else + m_remainder = 0; + return true; +} diff --git a/stglibs/common.lib/common.cpp b/stglibs/common.lib/common.cpp index 22ec354b..90493e9f 100644 --- a/stglibs/common.lib/common.cpp +++ b/stglibs/common.lib/common.cpp @@ -32,7 +32,8 @@ // Like FreeBSD4 #include #include -#include +#include +#include #include @@ -1074,3 +1075,37 @@ std::string ToPrintable(const std::string & src) return dest; } + +uid_t str2uid(const std::string& name) +{ + const passwd* res = getpwnam(name.c_str()); + if (res == NULL) + return -1; + return res->pw_uid; +} + +gid_t str2gid(const std::string& name) +{ + const group* res = getgrnam(name.c_str()); + if (res == NULL) + return -1; + return res->gr_gid; +} + +mode_t str2mode(const std::string& name) +{ + if (name.length() < 3 || name.length() > 4) + return -1; + + if (name.length() == 4 && name[0] != '0') + return -1; + + mode_t res = 0; + for (size_t i = 0; i < name.length(); ++i) + { + if (name[i] > '7' || name[i] < '0') + return -1; + res = (res << 3) + (name[i] - '0'); + } + return res; +} diff --git a/stglibs/common.lib/include/stg/blockio.h b/stglibs/common.lib/include/stg/blockio.h new file mode 100644 index 00000000..3879e39e --- /dev/null +++ b/stglibs/common.lib/include/stg/blockio.h @@ -0,0 +1,43 @@ +#ifndef __STG_STGLIBS_BLOCK_IO_H__ +#define __STG_STGLIBS_BLOCK_IO_H__ + +#include + +#include + +namespace STG +{ + +typedef std::vector IOVec; + +class BlockReader +{ + public: + BlockReader(const IOVec& ioVec); + + bool read(int socket); + bool done() const { return m_remainder == 0; } + size_t remainder() const { return m_remainder; } + + private: + IOVec m_dest; + size_t m_remainder; +}; + +class BlockWriter +{ + public: + BlockWriter(const IOVec& ioVec); + + bool write(int socket); + bool done() const { return m_remainder == 0; } + size_t remainder() const { return m_remainder; } + + private: + IOVec m_source; + size_t m_remainder; +}; + +} // namespace STG + +#endif diff --git a/stglibs/common.lib/include/stg/common.h b/stglibs/common.lib/include/stg/common.h index f46c922a..e90ad75a 100644 --- a/stglibs/common.lib/include/stg/common.h +++ b/stglibs/common.lib/include/stg/common.h @@ -27,16 +27,15 @@ #ifndef common_h #define common_h -#ifdef __BORLANDC__ -#include -#else -#include -#endif +#include "stg/os_int.h" +#include "stg/const.h" + #include #include +#include -#include "stg/os_int.h" -#include "stg/const.h" +#include // uid_t, gid_t +#include // mode_t #define STAT_TIME_3 (1) #define STAT_TIME_2 (2) @@ -299,4 +298,8 @@ const std::string & unsigned2str(varT x, std::string & s) char * stg_strptime(const char *, const char *, struct tm *); time_t stg_timegm(struct tm *); +uid_t str2uid(const std::string& name); +gid_t str2gid(const std::string& name); +mode_t str2mode(const std::string& mode); + #endif diff --git a/stglibs/json.lib/Makefile b/stglibs/json.lib/Makefile new file mode 100644 index 00000000..947fce40 --- /dev/null +++ b/stglibs/json.lib/Makefile @@ -0,0 +1,18 @@ +############################################################################### +# $Id: Makefile,v 1.9 2010/08/18 07:47:03 faust Exp $ +############################################################################### + +LIB_NAME = stgjson + +STGLIBS = -lstgcommon +LIBS = + +SRCS = parser.cpp \ + generator.cpp + +INCS = json_parser.h \ + json_generator.h + +LIB_INCS = -I ../common.lib/include + +include ../Makefile.in diff --git a/stglibs/json.lib/generator.cpp b/stglibs/json.lib/generator.cpp new file mode 100644 index 00000000..d18ef04a --- /dev/null +++ b/stglibs/json.lib/generator.cpp @@ -0,0 +1,82 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#include "stg/json_generator.h" + +#include + +using STG::JSON::NullGen; +using STG::JSON::BoolGen; +using STG::JSON::StringGen; +using STG::JSON::NumberGen; +using STG::JSON::MapGen; +using STG::JSON::ArrayGen; +using STG::JSON::Callback; + +namespace +{ + +void genString(yajl_gen_t* handle, const std::string& value) +{ + yajl_gen_string(handle, reinterpret_cast(value.c_str()), value.length()); +} + +} + +void NullGen::run(yajl_gen_t* handle) const { yajl_gen_null(handle); } +void BoolGen::run(yajl_gen_t* handle) const { yajl_gen_bool(handle, m_value); } +void StringGen::run(yajl_gen_t* handle) const { genString(handle, m_value); } +void NumberGen::run(yajl_gen_t* handle) const { yajl_gen_number(handle, m_value.c_str(), m_value.length()); } + +void MapGen::run(yajl_gen_t* handle) const +{ + yajl_gen_map_open(handle); + for (Value::const_iterator it = m_value.begin(); it != m_value.end(); ++it) + { + genString(handle, it->first); + it->second.first->run(handle); + } + yajl_gen_map_close(handle); +} + +void ArrayGen::run(yajl_gen_t* handle) const +{ + yajl_gen_array_open(handle); + for (Value::const_iterator it = m_value.begin(); it != m_value.end(); ++it) + it->first->run(handle); + yajl_gen_array_close(handle); +} + +bool STG::JSON::generate(Gen& gen, Callback callback, void* data) +{ + yajl_gen handle = yajl_gen_alloc(NULL); + + gen.run(handle); + + const unsigned char* buf = NULL; + size_t size = 0; + yajl_gen_get_buf(handle, &buf, &size); + + bool res = callback(data, reinterpret_cast(buf), size); + + yajl_gen_free(handle); + + return res; +} diff --git a/stglibs/json.lib/include/stg/json_generator.h b/stglibs/json.lib/include/stg/json_generator.h new file mode 100644 index 00000000..4f1523fc --- /dev/null +++ b/stglibs/json.lib/include/stg/json_generator.h @@ -0,0 +1,122 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#ifndef __STG_STGLIBS_JSON_GENERATOR_H__ +#define __STG_STGLIBS_JSON_GENERATOR_H__ + +#include +#include +#include +#include + +#include + +struct yajl_gen_t; + +namespace STG +{ +namespace JSON +{ + +struct Gen +{ + virtual ~Gen() {} + virtual void run(yajl_gen_t* handle) const = 0; +}; + +struct NullGen : public Gen +{ + virtual void run(yajl_gen_t* handle) const; +}; + +class BoolGen : public Gen +{ + public: + explicit BoolGen(bool value) : m_value(value) {} + virtual void run(yajl_gen_t* handle) const; + private: + bool m_value; +}; + +class StringGen : public Gen +{ + public: + explicit StringGen(const std::string& value) : m_value(value) {} + virtual void run(yajl_gen_t* handle) const; + private: + std::string m_value; +}; + +class NumberGen : public Gen +{ + public: + explicit NumberGen(const std::string& value) : m_value(value) {} + template + explicit NumberGen(const T& value) : m_value(x2str(value)) {} + virtual void run(yajl_gen_t* handle) const; + private: + std::string m_value; +}; + +class MapGen : public Gen +{ + public: + MapGen() {} + virtual ~MapGen() + { + for (Value::iterator it = m_value.begin(); it != m_value.end(); ++it) + if (it->second.second) + delete it->second.first; + } + MapGen& add(const std::string& key, Gen* value) { m_value[key] = std::make_pair(value, true); return *this; } + MapGen& add(const std::string& key, Gen& value) { m_value[key] = std::make_pair(&value, false); return *this; } + virtual void run(yajl_gen_t* handle) const; + private: + typedef std::pair SmartGen; + typedef std::map Value; + Value m_value; +}; + +class ArrayGen : public Gen +{ + public: + ArrayGen() {} + virtual ~ArrayGen() + { + for (Value::iterator it = m_value.begin(); it != m_value.end(); ++it) + if (it->second) + delete it->first; + } + void add(Gen* value) { m_value.push_back(std::make_pair(value, true)); } + void add(Gen& value) { m_value.push_back(std::make_pair(&value, false)); } + virtual void run(yajl_gen_t* handle) const; + private: + typedef std::pair SmartGen; + typedef std::vector Value; + Value m_value; +}; + +typedef bool (*Callback)(void* /*data*/, const char* /*buf*/, size_t /*size*/); +bool generate(Gen& gen, Callback callback, void* data); + +} +} + +#endif diff --git a/stglibs/json.lib/include/stg/json_parser.h b/stglibs/json.lib/include/stg/json_parser.h new file mode 100644 index 00000000..a6142579 --- /dev/null +++ b/stglibs/json.lib/include/stg/json_parser.h @@ -0,0 +1,107 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#ifndef __STG_STGLIBS_JSON_PARSER_H__ +#define __STG_STGLIBS_JSON_PARSER_H__ + +#include "stg/common.h" + +#include +#include + +#include + +namespace STG +{ +namespace JSON +{ + +struct NodeParser +{ + virtual ~NodeParser() {} + + virtual NodeParser* parseNull() { return this; } + virtual NodeParser* parseBoolean(const bool& /*value*/) { return this; } + virtual NodeParser* parseNumber(const std::string& /*value*/) { return this; } + virtual NodeParser* parseString(const std::string& /*value*/) { return this; } + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseMapKey(const std::string& /*value*/) { return this; } + virtual NodeParser* parseEndMap() { return this; } + virtual NodeParser* parseStartArray() { return this; } + virtual NodeParser* parseEndArray() { return this; } +}; + +class Parser +{ + public: + explicit Parser(NodeParser* topParser); + virtual ~Parser(); + + bool append(const char* data, size_t size); + bool last(); + + private: + class Impl; + boost::scoped_ptr m_impl; +}; + +template +class EnumParser : public NodeParser +{ + public: + typedef std::map Codes; + EnumParser(NodeParser* next, T& data, std::string& dataStr, const Codes& codes) + : m_next(next), m_data(data), m_dataStr(dataStr), m_codes(codes) {} + virtual NodeParser* parseString(const std::string& value) + { + m_dataStr = value; + const typename Codes::const_iterator it = m_codes.find(ToLower(value)); + if (it != m_codes.end()) + m_data = it->second; + return m_next; + } + private: + NodeParser* m_next; + T& m_data; + std::string& m_dataStr; + const Codes& m_codes; +}; + +class PairsParser : public NodeParser +{ + public: + typedef std::map Pairs; + + PairsParser(NodeParser* next, Pairs& pairs) : m_next(next), m_pairs(pairs) {} + + virtual NodeParser* parseStartMap() { return this; } + virtual NodeParser* parseString(const std::string& value) { m_pairs[m_key] = value; return this; } + virtual NodeParser* parseMapKey(const std::string& value) { m_key = value; return this; } + virtual NodeParser* parseEndMap() { return m_next; } + private: + NodeParser* m_next; + Pairs& m_pairs; + std::string m_key; +}; + +} +} + +#endif diff --git a/stglibs/json.lib/parser.cpp b/stglibs/json.lib/parser.cpp new file mode 100644 index 00000000..57115048 --- /dev/null +++ b/stglibs/json.lib/parser.cpp @@ -0,0 +1,123 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +/* + * Author : Maxim Mamontov + */ + +#include "stg/json_parser.h" + +#include + +using STG::JSON::Parser; +using STG::JSON::NodeParser; + +class Parser::Impl +{ + public: + Impl(NodeParser* topParser); + ~Impl() + { + yajl_free(m_handle); + } + + bool append(const char* data, size_t size) { return yajl_parse(m_handle, reinterpret_cast(data), size) == yajl_status_ok; } + bool last() { return yajl_complete_parse(m_handle) == yajl_status_ok; } + + static int parseNull(void* ctx) + { return runParser(ctx, &NodeParser::parseNull); } + static int parseBoolean(void* ctx, int value) + { return runParser(ctx, &NodeParser::parseBoolean, value != 0); } + static int parseNumber(void* ctx, const char* value, size_t size) + { return runParser(ctx, &NodeParser::parseNumber, std::string(value, size)); } + static int parseString(void* ctx, const unsigned char* value, size_t size) + { return runParser(ctx, &NodeParser::parseString, std::string(reinterpret_cast(value), size)); } + static int parseStartMap(void* ctx) + { return runParser(ctx, &NodeParser::parseStartMap); } + static int parseMapKey(void* ctx, const unsigned char* value, size_t size) + { return runParser(ctx, &NodeParser::parseMapKey, std::string(reinterpret_cast(value), size)); } + static int parseEndMap(void* ctx) + { return runParser(ctx, &NodeParser::parseEndMap); } + static int parseStartArray(void* ctx) + { return runParser(ctx, &NodeParser::parseStartArray); } + static int parseEndArray(void* ctx) + { return runParser(ctx, &NodeParser::parseEndArray); } + + private: + yajl_handle m_handle; + NodeParser* m_parser; + + static yajl_callbacks callbacks; + + static NodeParser& getParser(void* ctx) { return *static_cast(ctx)->m_parser; } + static bool runParser(void* ctx, NodeParser* (NodeParser::*func)()) + { + Impl& p = *static_cast(ctx); + NodeParser* next = (p.m_parser->*func)(); + if (next != NULL) + p.m_parser = next; + return next != NULL; + } + template + static bool runParser(void* ctx, NodeParser* (NodeParser::*func)(const T&), const T& value) + { + Impl& p = *static_cast(ctx); + NodeParser* next = (p.m_parser->*func)(value); + if (next != NULL) + p.m_parser = next; + return next != NULL; + } +}; + +yajl_callbacks Parser::Impl::callbacks = { + Parser::Impl::parseNull, + Parser::Impl::parseBoolean, + NULL, // parsing of integer is done using parseNumber + NULL, // parsing of double is done using parseNumber + Parser::Impl::parseNumber, + Parser::Impl::parseString, + Parser::Impl::parseStartMap, + Parser::Impl::parseMapKey, + Parser::Impl::parseEndMap, + Parser::Impl::parseStartArray, + Parser::Impl::parseEndArray +}; + +Parser::Impl::Impl(NodeParser* topParser) + : m_handle(yajl_alloc(&callbacks, NULL, this)), + m_parser(topParser) +{ + yajl_config(m_handle, yajl_allow_multiple_values, 1); +} + +Parser::Parser(NodeParser* topParser) + : m_impl(new Impl(topParser)) +{ +} + +Parser::~Parser() +{ +} + +bool Parser::append(const char* data, size_t size) +{ + return m_impl->append(data, size); +} + +bool Parser::last() +{ + return m_impl->last(); +}