position(pos),
error(message)
{}
+ virtual ~ParserError() throw() {}
size_t position;
std::string error;
return start;
}
-size_t checkChar(const std:string& value, size_t start, char ch)
+// Checks that value[start] is the expected character ch and returns the index just past it.
+// Throws ParserError (carrying the offending position) if the string ends early or holds another character.
+size_t checkChar(const std::string& value, size_t start, char ch)
{
if (start >= value.length())
- throw ParserError(start, "Unexpected end of string. Expected '" + std::string(ch) + "'.");
+ // std::string(1, ch): std::string has no constructor taking a single char.
+ throw ParserError(start, "Unexpected end of string. Expected '" + std::string(1, ch) + "'.");
if (value[start] != ch)
- throw ParserError(start, "Expected '" + std::string(ch) + "', got '" + std::string(value[start]) + "'.");
+ throw ParserError(start, "Expected '" + std::string(1, ch) + "', got '" + std::string(1, value[start]) + "'.");
return start + 1;
}
else
throw ParserError(start, "Unexpected whitespace. Expected string.");
}
- return dest;
+ return std::make_pair(start, dest);
}
Config::Pairs toPairs(const std::vector<std::string>& values)
start = key.first;
pair.first = key.second;
start = skipSpaces(value, start);
- start = checkChar(value, start, ',')
+ start = checkChar(value, start, ',');
start = skipSpaces(value, start);
- std::pair<size_t, std::string> value = readString(value, start);
+ std::pair<size_t, std::string> val = readString(value, start);
start = key.first;
- pair.second = value.second;
+ pair.second = val.second;
start = skipSpaces(value, start);
start = checkChar(value, start, ')');
if (res.find(pair.first) != res.end())
if (values.empty())
return 0;
T res = 0;
- if (srt2x(values[0], res) == 0)
+ if (str2x(values[0], res) == 0)
return res;
return 0;
}
+// Returns toPairs() of the values of the first module parameter named paramName,
+// or an empty Config::Pairs when no such parameter is present.
Config::Pairs parseVector(const std::string& paramName, const MODULE_SETTINGS& params)
{
for (size_t i = 0; i < params.moduleParams.size(); ++i)
- if (params.moduleParams[i].first == paramName)
- return toPairs(params.moduleParams[i].second);
+ if (params.moduleParams[i].param == paramName)
+ return toPairs(params.moduleParams[i].value);
return Config::Pairs();
}
+// Returns toBool() of the values of the first module parameter named paramName,
+// or false when no such parameter is present.
bool parseBool(const std::string& paramName, const MODULE_SETTINGS& params)
{
for (size_t i = 0; i < params.moduleParams.size(); ++i)
- if (params.moduleParams[i].first == paramName)
- return toBool(params.moduleParams[i].second);
+ if (params.moduleParams[i].param == paramName)
+ return toBool(params.moduleParams[i].value);
return false;
}
+// Returns toString() of the values of the first module parameter named paramName,
+// or an empty string when no such parameter is present.
std::string parseString(const std::string& paramName, const MODULE_SETTINGS& params)
{
for (size_t i = 0; i < params.moduleParams.size(); ++i)
- if (params.moduleParams[i].first == paramName)
- return toString(params.moduleParams[i].second);
+ if (params.moduleParams[i].param == paramName)
+ return toString(params.moduleParams[i].value);
return "";
}
+// Returns toInt<T>() of the values of the first module parameter named paramName,
+// or 0 when no such parameter is present.
T parseInt(const std::string& paramName, const MODULE_SETTINGS& params)
{
for (size_t i = 0; i < params.moduleParams.size(); ++i)
- if (params.moduleParams[i].first == paramName)
- return toInt<T>(params.moduleParams[i].second);
+ if (params.moduleParams[i].param == paramName)
+ return toInt<T>(params.moduleParams[i].value);
return 0;
}
reply(parseVector("reply", settings)),
verbose(parseBool("verbose", settings)),
bindAddress(parseString("bind_address", settings)),
+ portStr(parseString("port", settings)),
port(parseInt<uint16_t>("port", settings)),
key(parseString("key", settings))
{
struct Config
{
+ // Key/value mapping built from module settings (see toPairs()).
typedef std::map<std::string, std::string> Pairs;
- typedef Pairs::value_type Pair;
+ // Plain pair with a mutable key: Pairs::value_type is std::pair<const std::string, std::string>,
+ // whose key cannot be assigned after construction, while toPairs() assigns pair.first.
+ typedef std::pair<std::string, std::string> Pair;
+ // Transport of the listening socket: UNIX domain socket or TCP.
+ enum Type { UNIX, TCP };
Config(const MODULE_SETTINGS& settings);
bool verbose;
+ // Selected transport; chooses between createUNIX()/createTCP() and acceptUNIX()/acceptTCP().
+ Type connectionType;
std::string bindAddress;
+ // Raw "port" setting, passed to getaddrinfo() as the service name when binding over TCP.
+ std::string portStr;
uint16_t port;
std::string key;
};
#include "config.h"
+#include "stg/json_parser.h"
+#include "stg/json_generator.h"
#include "stg/users.h"
#include "stg/user.h"
#include "stg/logger.h"
#include "stg/common.h"
+#include <yajl/yajl_gen.h>
+
+#include <map>
#include <cstring>
#include <cerrno>
+#include <unistd.h>
+
using STG::Conn;
+using STG::Config;
+using STG::JSON::Parser;
+using STG::JSON::PairsParser;
+using STG::JSON::EnumParser;
+using STG::JSON::NodeParser;
+using STG::JSON::Gen;
+using STG::JSON::MapGen;
+using STG::JSON::StringGen;
-Conn::Conn(USERS& users, PLUGIN_LOGGER & logger, const Config& config)
- : m_users(users),
- m_logger(logger),
- m_config(config)
+namespace
+{
+
+// Timeouts in seconds, compared against difftime() results in Conn::Impl::tick():
+// a connection silent for longer than CONN_TIMEOUT is dropped, and a ping is sent
+// once PING_TIMEOUT seconds have passed since the last packet we sent.
+const double CONN_TIMEOUT = 5;
+const double PING_TIMEOUT = 1;
+
+// Value of the "packet" field of an incoming message.
+enum Packet
+{
+    PING,
+    PONG,
+    DATA
+};
+
+// Value of the "stage" field of an incoming message.
+enum Stage
+{
+    AUTHORIZE,
+    AUTHENTICATE,
+    PREACCT,
+    ACCOUNTING,
+    POSTAUTH
+};
+
+// Wire names of the enums above; filled lazily by the first parser of each kind.
+std::map<std::string, Packet> packetCodes;
+std::map<std::string, Stage> stageCodes;
+
+// Parses the "packet" value into a Packet using packetCodes.
+class PacketParser : public EnumParser<Packet>
+{
+    public:
+        PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
+            : EnumParser(next, packet, packetStr, packetCodes)
+        {
+            if (!packetCodes.empty())
+                return;
+            packetCodes["ping"] = PING;
+            packetCodes["pong"] = PONG;
+            packetCodes["data"] = DATA;
+        }
+};
+
+// Parses the "stage" value into a Stage using stageCodes.
+class StageParser : public EnumParser<Stage>
+{
+    public:
+        StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
+            : EnumParser(next, stage, stageStr, stageCodes)
+        {
+            if (!stageCodes.empty())
+                return;
+            stageCodes["authorize"] = AUTHORIZE;
+            stageCodes["authenticate"] = AUTHENTICATE;
+            stageCodes["preacct"] = PREACCT;
+            stageCodes["accounting"] = ACCOUNTING;
+            stageCodes["postauth"] = POSTAUTH;
+        }
+};
+
+// Top-level object handler: routes the "packet", "stage" and "pairs" keys (compared
+// case-insensitively) to their sub-parsers; other keys stay with this parser.
+class TopParser : public NodeParser
+{
+    public:
+        TopParser()
+            // Give the enum fields defined values: if a message lacks "packet" or "stage",
+            // packet()/stage() would otherwise return indeterminate values (undefined behavior).
+            : m_packet(PING),
+              m_stage(AUTHORIZE),
+              m_packetParser(this, m_packet, m_packetStr),
+              m_stageParser(this, m_stage, m_stageStr),
+              m_pairsParser(this, m_data)
+        {}
+
+        virtual NodeParser* parseStartMap() { return this; }
+        virtual NodeParser* parseMapKey(const std::string& value)
+        {
+            std::string key = ToLower(value);
+
+            if (key == "packet")
+                return &m_packetParser;
+            else if (key == "stage")
+                return &m_stageParser;
+            else if (key == "pairs")
+                return &m_pairsParser;
+
+            return this;
+        }
+        virtual NodeParser* parseEndMap() { return this; }
+
+        const std::string& packetStr() const { return m_packetStr; }
+        Packet packet() const { return m_packet; }
+        const std::string& stageStr() const { return m_stageStr; }
+        Stage stage() const { return m_stage; }
+        const Config::Pairs& data() const { return m_data; }
+
+    private:
+        // The values must be declared before the sub-parsers that bind references to them.
+        std::string m_packetStr;
+        Packet m_packet;
+        std::string m_stageStr;
+        Stage m_stage;
+        Config::Pairs m_data;
+
+        PacketParser m_packetParser;
+        StageParser m_stageParser;
+        PairsParser m_pairsParser;
+};
+
+// Parser for one incoming message, exposing the fields collected by TopParser.
+class ProtoParser : public Parser
+{
+    public:
+        ProtoParser() : Parser( &m_topParser ) {}
+
+        const std::string& packetStr() const { return m_topParser.packetStr(); }
+        Packet packet() const { return m_topParser.packet(); }
+        const std::string& stageStr() const { return m_topParser.stageStr(); }
+        Stage stage() const { return m_topParser.stage(); }
+        const Config::Pairs& data() const { return m_topParser.data(); }
+
+    private:
+        TopParser m_topParser;
+};
+
+// Builds an outgoing message: an object whose "packet" field is the given type,
+// plus whatever fields are add()ed.
+class PacketGen : public Gen
+{
+    public:
+        explicit PacketGen(const std::string& type)
+            : m_type(type)
+        {
+            m_gen.add("packet", m_type);
+        }
+        void run(yajl_gen_t* handle) const
+        {
+            m_gen.run(handle);
+        }
+        PacketGen& add(const std::string& key, const std::string& value)
+        {
+            m_gen.add(key, new StringGen(value));
+            return *this;
+        }
+        // NOTE(review): map goes to the same MapGen::add(key, Gen*) overload that receives
+        // freshly new'ed StringGens above, so it presumably takes ownership; callers should not
+        // delete map themselves -- confirm against stg/json_generator.h.
+        PacketGen& add(const std::string& key, MapGen* map)
+        {
+            m_gen.add(key, map);
+            return *this;
+        }
+    private:
+        MapGen m_gen;
+        StringGen m_type;
+};
+
+}
+
+// Private state of a Conn: one accepted client socket plus its parser and timers.
+class Conn::Impl
+{
+    public:
+        // Takes ownership of fd (closed by the destructor); remote names the peer in log messages.
+        Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
+        ~Impl();
+
+        int sock() const { return m_sock; }
+
+        bool read();
+        bool tick();
+
+        // false once the connection has failed and should be dropped.
+        bool isOk() const { return m_ok; }
+
+    private:
+        USERS& m_users;
+        PLUGIN_LOGGER& m_logger;
+        const Config& m_config;
+        int m_sock;
+        std::string m_remote;
+        bool m_ok;
+        time_t m_lastPing;      // time of the last packet we sent
+        time_t m_lastActivity;  // time of the last data/pong received from the peer
+        ProtoParser m_parser;
+
+        bool process();
+        bool processPing();
+        bool processPong();
+        bool processData();
+        bool answer(const USER& user);
+        bool answerNo();
+        bool sendPing();
+        bool sendPong();
+
+        static bool write(void* data, const char* buf, size_t size);
+
+        // Non-copyable: the destructor closes m_sock, so a copy would close the descriptor twice.
+        Impl(const Impl&);
+        Impl& operator=(const Impl&);
+};
+
+// Takes ownership of the connected socket fd: it is handed to Conn::Impl, whose destructor closes it.
+Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
+ : m_impl(new Impl(users, logger, config, fd, remote))
{
}
{
}
+// The public Conn interface only forwards to the private implementation (Conn::Impl).
+// sock(): the connected descriptor, checked by the owner's select() loop.
+int Conn::sock() const
+{
+ return m_impl->sock();
+}
+
+// read(): consume data available on the socket.
bool Conn::read()
{
- ssize_t res = read(m_sock, m_buffer, m_bufferSize);
+ return m_impl->read();
+}
+
+// tick(): housekeeping for a connection with no pending input (timeouts, keep-alive pings).
+bool Conn::tick()
+{
+ return m_impl->tick();
+}
+
+// isOk(): false once the connection has failed and should be dropped.
+bool Conn::isOk() const
+{
+ return m_impl->isOk();
+}
+
+// Starts both timers at "now": the connection counts as freshly active and freshly pinged.
+Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
+    : m_users(users)
+    , m_logger(logger)
+    , m_config(config)
+    , m_sock(fd)
+    , m_remote(remote)
+    , m_ok(true)
+    , m_lastPing(time(NULL))
+    , m_lastActivity(m_lastPing) // relies on m_lastPing being declared (and initialized) first
+{
+}
+
+// The connection owns its socket.
+Conn::Impl::~Impl()
+{
+    close(m_sock);
+}
+
+// Reads whatever is available on the socket and feeds it to the parser. A complete packet is
+// processed once the peer closes its side of the connection (read() returns 0).
+// Returns false on errors; m_ok is cleared whenever the connection must be dropped.
+bool Conn::Impl::read()
+{
+    // Per-call buffer instead of a function-level static shared by every connection.
+    char buffer[1024];
+    ssize_t res = ::read(m_sock, buffer, sizeof(buffer));
+    if (res < 0 && errno == EINTR)
+        return true; // Interrupted by a signal before any data arrived; try again on the next event.
if (res < 0)
{
- m_state = ERROR;
- Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Reason: '" + strerror(errno) + "'");
+    m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
+    m_ok = false;
return false;
}
- if (res == 0 && m_state != DATA) // EOF is ok for data.
+    m_lastActivity = time(NULL);
+    if (res == 0)
+    {
+        if (!m_parser.done())
+        {
+            // EOF is not an error of read(2): errno is not set, so do not report strerror(errno).
+            m_ok = false;
+            m_logger("Unexpected EOF from '" + m_remote + "' before a complete packet was received.");
+            return false;
+        }
+        // The peer closed its side after a complete packet: answer it and retire the connection.
+        // Keeping it open would make select() report EOF again on every loop iteration, and the
+        // same packet would be processed (and answered) over and over.
+        const bool processed = process();
+        m_ok = false;
+        return processed;
+    }
+    return m_parser.append(buffer, res);
+}
+
+// Housekeeping for a connection with no pending input: drops it (m_ok = false, returns false)
+// after CONN_TIMEOUT seconds without activity, otherwise sends a keep-alive ping once
+// PING_TIMEOUT seconds have passed since the last packet we sent. Returns true if still alive.
+bool Conn::Impl::tick()
+{
+ time_t now = time(NULL);
+ if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
{
- m_state = ERROR;
- Log(__FILE__, "Failed to read data from " + inet_ntostring(IP()) + ":" + x2str(port()) + ". Unexpected EOF.");
+ m_logger("Connection to " + m_remote + " timed out.");
+ m_ok = false;
return false;
}
- m_bufferSize -= res;
- return HandleBuffer(res);
+ // sendPing()'s result is not checked here; a failed write clears m_ok inside write().
+ if (difftime(now, m_lastPing) > PING_TIMEOUT)
+ sendPing();
+ return true;
+}
+
+// Dispatches a fully parsed packet to its handler; unknown packet types are logged and rejected.
+bool Conn::Impl::process()
+{
+    const Packet type = m_parser.packet();
+    if (type == PING)
+        return processPing();
+    if (type == PONG)
+        return processPong();
+    if (type == DATA)
+        return processData();
+    m_logger("Received invalid packet type: " + m_parser.packetStr());
+    return false;
+}
+
+// A ping from the peer is answered with a pong right away; returns whether it was sent.
+bool Conn::Impl::processPing()
+{
+    const bool sent = sendPong();
+    return sent;
+}
+
+// A pong proves the peer is alive: refresh the activity timestamp.
+bool Conn::Impl::processPong()
+{
+    time(&m_lastActivity);
+    return true;
+}
+
+// Finds the first user for which every configured "match" rule holds (request field it->first
+// equals the user's parameter it->second) and answers with that user's data, or sends a
+// negative answer when no user matches. Returns whether the answer was written.
+bool Conn::Impl::processData()
+{
+    const Config::Pairs& request = m_parser.data();
+    int handle = m_users.OpenSearch();
+
+    USER_PTR user = NULL;
+    // Start pessimistic: with no (non-NULL) users the loop body never completes a match,
+    // and the peer must still get an answer.
+    bool match = false;
+    // NOTE(review): several STG USERS implementations return 0 from SearchNext() on success;
+    // if that holds here, this loop condition is inverted -- confirm against stg/users.h.
+    while (m_users.SearchNext(handle, &user))
+    {
+        if (user == NULL)
+            continue;
+
+        match = true;
+        for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it)
+        {
+            Config::Pairs::const_iterator pos = request.find(it->first);
+            if (pos == request.end() || user->GetParamValue(it->second) != pos->second)
+            {
+                match = false;
+                break;
+            }
+        }
+        if (match)
+            break; // user now points at the matching user
+    }
+
+    // Answer while the search is still open, as before.
+    const bool sent = match ? answer(*user) : answerNo();
+
+    m_users.CloseSearch(handle);
+
+    return sent;
+}
+
+// Sends a positive "data" answer for user: the "reply" and "modify" maps contain, for every
+// configured entry, the user's parameter named by the entry's value. Restarts the ping timer.
+bool Conn::Impl::answer(const USER& user)
+{
+    PacketGen gen("data");
+    gen.add("result", "ok");
+
+    // MapGen::add(key, Gen*) is fed freshly new'ed generators (the StringGens below and in
+    // PacketGen::add), i.e. it presumably takes ownership. Hand each map over to gen right away
+    // instead of also keeping it in a scoped_ptr, which would delete it a second time.
+    MapGen* reply = new MapGen;
+    gen.add("reply", reply);
+    for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it)
+        reply->add(it->first, new StringGen(user.GetParamValue(it->second)));
+
+    MapGen* modify = new MapGen;
+    gen.add("modify", modify);
+    for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it)
+        modify->add(it->first, new StringGen(user.GetParamValue(it->second)));
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+// Sends a negative "data" answer (no user matched the request). Restarts the ping timer.
+bool Conn::Impl::answerNo()
+{
+    PacketGen gen("data");
+    // "no", not "ok": the peer must be able to tell a failed lookup from a successful one.
+    gen.add("result", "no");
+
+    m_lastPing = time(NULL);
+
+    return generate(gen, &Conn::Impl::write, this);
+}
+
+// Sends a keep-alive "ping"; like every outgoing packet it restarts the ping timer.
+bool Conn::Impl::sendPing()
+{
+    PacketGen ping("ping");
+    time(&m_lastPing);
+    return generate(ping, &Conn::Impl::write, this);
+}
+
+// Answers the peer's ping with a "pong"; like every outgoing packet it restarts the ping timer.
+bool Conn::Impl::sendPong()
+{
+    PacketGen pong("pong");
+    time(&m_lastPing);
+    return generate(pong, &Conn::Impl::write, this);
+}
+
+// Output callback handed to generate(); data is the Conn::Impl. Writes the whole buffer to the
+// socket, coping with partial writes and EINTR. On failure logs, clears m_ok and returns false.
+// NOTE(review): writing to a peer that already closed the socket raises SIGPIPE unless it is
+// ignored/blocked elsewhere -- confirm the plugin's signal setup.
+bool Conn::Impl::write(void* data, const char* buf, size_t size)
+{
+    Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
+    while (size > 0)
+    {
+        ssize_t res = ::write(conn.m_sock, buf, size);
+        if (res < 0)
+        {
+            if (errno == EINTR)
+                continue;
+            conn.m_logger("Failed to write data to '" + conn.m_remote + "': " + strerror(errno));
+            conn.m_ok = false;
+            return false;
+        }
+        // write(2) may accept only part of the buffer: continue after what was sent.
+        buf += res;
+        size -= res;
+    }
+    return true;
}
#ifndef __STG_SGCONFIG_CONN_H__
#define __STG_SGCONFIG_CONN_H__
-#include "stg/os_int.h"
+#include <boost/scoped_ptr.hpp>
-#include <stdexcept>
#include <string>
+class USER;
class USERS;
+class PLUGIN_LOGGER;
namespace STG
{
+class Config;
+
+// One accepted client connection of the RADIUS plugin. All state, including the socket
+// (closed on destruction), lives in the private Conn::Impl.
class Conn
{
public:
- struct Error : public std::runtime_error
- {
- Error(const std::string& message) : runtime_error(message.c_str()) {}
- };
-
- Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config);
+ // Takes ownership of the connected socket fd; remote is a printable peer name used in log messages.
+ Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
~Conn();
- int sock() const { return m_sock; }
+ int sock() const;
bool read();
+ // Housekeeping for a connection with no pending input: timeout detection and keep-alive pings.
+ bool tick();
+
+ // false once the connection has failed; the owner deletes such connections.
+ bool isOk() const;
private:
- USERS& m_users;
- PLUGIN_LOGGER& m_logger;
- const Config& m_config;
+ class Impl;
+ boost::scoped_ptr<Impl> m_impl;
};
}
#include "stg/users.h"
#include "stg/plugin_creator.h"
+#include <algorithm>
+#include <stdexcept>
#include <csignal>
-#include <cerror>
+#include <cerrno>
#include <cstring>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/un.h> // UNIX
+#include <netinet/in.h> // IP
+#include <netinet/tcp.h> // TCP
+#include <netdb.h>
+
+using STG::Config;
+using STG::Conn;
+
namespace
{
}
+// NOTE(review): m_config is built from m_settings; that is only well-defined if m_settings is
+// declared before m_config in the class (declaration not visible here) -- confirm.
RADIUS::RADIUS()
- : m_running(false),
+ : m_config(m_settings),
+ m_running(false),
m_stopped(true),
m_users(NULL),
m_store(NULL),
+ // 0 means "no listening socket" (sentinel used by reconnect(), createUNIX() and createTCP()).
+ m_listenSocket(0),
m_logger(GetPluginLogger(GetStgLogger(), "radius"))
{
}
{
try {
m_config = STG::Config(m_settings);
- return 0;
+ return reconnect() ? 0 : -1;
} catch (const std::runtime_error& ex) {
m_logger("Failed to parse settings. %s", ex.what());
return -1;
return NULL;
}
+// Drops all client connections and (re)creates the listening socket for the current
+// m_config. Returns false (with m_error set and logged) if the socket cannot be set up;
+// in that case m_listenSocket is left at 0.
+bool RADIUS::reconnect()
+{
+    for (std::deque<STG::Conn*>::const_iterator it = m_conns.begin(); it != m_conns.end(); ++it)
+        delete *it;
+    m_conns.clear();
+
+    if (m_listenSocket != 0)
+    {
+        shutdown(m_listenSocket, SHUT_RDWR);
+        close(m_listenSocket);
+        // NOTE(review): when called from ParseSettings(), m_config already holds the new settings,
+        // so this unlinks the path of the new configuration, not necessarily the old one.
+        if (m_config.connectionType == Config::UNIX)
+            unlink(m_config.bindAddress.c_str());
+        m_listenSocket = 0;
+    }
+
+    m_listenSocket = (m_config.connectionType == Config::UNIX) ? createUNIX() : createTCP();
+    if (m_listenSocket == 0)
+        return false;
+
+    if (listen(m_listenSocket, 100) == -1)
+    {
+        m_error = std::string("Error starting to listen socket: ") + strerror(errno);
+        m_logger(m_error);
+        // Do not keep a bound-but-not-listening socket (or its socket file) around.
+        close(m_listenSocket);
+        if (m_config.connectionType == Config::UNIX)
+            unlink(m_config.bindAddress.c_str());
+        m_listenSocket = 0;
+        return false;
+    }
+    return true;
+}
+
+// Creates a UNIX stream socket bound to m_config.bindAddress. Returns the descriptor,
+// or 0 (with m_error set and logged) on failure.
+int RADIUS::createUNIX() const
+{
+    struct sockaddr_un addr;
+    // sun_path is a fixed-size array: reject paths that would not fit with the terminating NUL
+    // (the previous strncpy() wrote past the end of the array for such paths).
+    if (m_config.bindAddress.length() >= sizeof(addr.sun_path))
+    {
+        m_error = "UNIX socket path '" + m_config.bindAddress + "' is too long.";
+        m_logger(m_error);
+        return 0;
+    }
+
+    int fd = socket(AF_UNIX, SOCK_STREAM, 0);
+    if (fd == -1)
+    {
+        m_error = std::string("Error creating UNIX socket: ") + strerror(errno);
+        m_logger(m_error);
+        return 0;
+    }
+
+    memset(&addr, 0, sizeof(addr));
+    addr.sun_family = AF_UNIX;
+    memcpy(addr.sun_path, m_config.bindAddress.c_str(), m_config.bindAddress.length());
+    if (bind(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
+    {
+        // Format the error before shutdown()/close() get a chance to overwrite errno.
+        m_error = std::string("Error binding UNIX socket: ") + strerror(errno);
+        shutdown(fd, SHUT_RDWR);
+        close(fd);
+        m_logger(m_error);
+        return 0;
+    }
+    return fd;
+}
+
+// Creates an IPv4 TCP socket bound to m_config.bindAddress:m_config.portStr, trying every
+// resolved address in turn. An empty bind address means "all interfaces". Returns the
+// descriptor, or 0 (with m_error set and logged) on failure.
+int RADIUS::createTCP() const
+{
+    addrinfo hints;
+    memset(&hints, 0, sizeof(addrinfo));
+
+    hints.ai_family = AF_INET;       /* Allow IPv4 */
+    hints.ai_socktype = SOCK_STREAM; /* Stream socket */
+    hints.ai_flags = AI_PASSIVE;     /* For wildcard IP address */
+    hints.ai_protocol = 0;           /* Any protocol */
+    hints.ai_canonname = NULL;
+    hints.ai_addr = NULL;
+    hints.ai_next = NULL;
+
+    // With AI_PASSIVE a NULL node selects the wildcard address; "" would fail to resolve.
+    const char* node = m_config.bindAddress.empty() ? NULL : m_config.bindAddress.c_str();
+
+    addrinfo* ais = NULL;
+    int res = getaddrinfo(node, m_config.portStr.c_str(), &hints, &ais);
+    if (res != 0)
+    {
+        m_error = "Error resolving address '" + m_config.bindAddress + "': " + gai_strerror(res);
+        m_logger(m_error);
+        return 0;
+    }
+
+    for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
+    {
+        int fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
+        if (fd == -1)
+        {
+            m_error = std::string("Error creating TCP socket: ") + strerror(errno);
+            m_logger(m_error);
+            continue;
+        }
+        // Let reconnect() rebind the port right after closing the previous listener.
+        int on = 1;
+        setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
+        if (bind(fd, ai->ai_addr, ai->ai_addrlen) == -1)
+        {
+            // Format the error before shutdown()/close() get a chance to overwrite errno.
+            m_error = std::string("Error binding TCP socket: ") + strerror(errno);
+            shutdown(fd, SHUT_RDWR);
+            close(fd);
+            m_logger(m_error);
+            continue;
+        }
+        freeaddrinfo(ais);
+        return fd;
+    }
+
+    m_error = "Failed to bind TCP socket to '" + m_config.bindAddress + ":" + m_config.portStr + "'.";
+    m_logger(m_error);
+
+    freeaddrinfo(ais);
+    return 0;
+}
+
void RADIUS::runImpl()
{
m_running = true;
if (res > 0)
handleEvents(fds);
+ else
+ {
+ for (std::deque<Conn*>::iterator it = m_conns.begin(); it != m_conns.end(); ++it)
+ (*it)->tick();
+ }
cleanupConns();
}
{
std::deque<STG::Conn *>::iterator pos;
for (pos = m_conns.begin(); pos != m_conns.end(); ++pos)
- if (((*pos)->isDone() && !(*pos)->isKeepAlive()) || !(*pos)->isOk()) {
+ if (!(*pos)->isOk()) {
delete *pos;
*pos = NULL;
}
for (it = m_conns.begin(); it != m_conns.end(); ++it)
if (FD_ISSET((*it)->sock(), &fds))
(*it)->read();
+ else
+ (*it)->tick();
+ }
+}
+
+// Accepts one pending client on the listening socket using the configured transport.
+void RADIUS::acceptConnection()
+{
+    switch (m_config.connectionType)
+    {
+        case Config::UNIX:
+            acceptUNIX();
+            break;
+        default:
+            acceptTCP();
+            break;
+    }
+}
+
+// Accepts one pending UNIX-socket client and registers a Conn for it.
+// On failure m_error is set and logged and nothing is registered.
+void RADIUS::acceptUNIX()
+{
+    struct sockaddr_un addr;
+    memset(&addr, 0, sizeof(addr));
+    socklen_t size = sizeof(addr);
+    int res = accept(m_listenSocket, reinterpret_cast<sockaddr*>(&addr), &size);
+    if (res == -1)
+    {
+        m_error = std::string("Failed to accept UNIX connection: ") + strerror(errno);
+        m_logger(m_error);
+        return;
+    }
+    // sun_path is not guaranteed to be NUL-terminated when a path fills the whole array,
+    // so bound the copy instead of treating it as a C string.
+    const char* pathEnd = std::find(addr.sun_path, addr.sun_path + sizeof(addr.sun_path), '\0');
+    std::string remote(addr.sun_path, pathEnd);
+    // Client sockets are usually unbound (unnamed); name them after the listening path
+    // so that log messages still identify the endpoint.
+    if (remote.empty())
+        remote = m_config.bindAddress;
+    m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, remote));
+}
+
+// Accepts one pending TCP client and registers a Conn for it.
+// On failure m_error is set and logged and nothing is registered.
+void RADIUS::acceptTCP()
+{
+ struct sockaddr_in addr;
+ memset(&addr, 0, sizeof(addr));
+ socklen_t size = sizeof(addr);
+ int res = accept(m_listenSocket, reinterpret_cast<sockaddr*>(&addr), &size);
+ if (res == -1)
+ {
+ m_error = std::string("Failed to accept TCP connection: ") + strerror(errno);
+ m_logger(m_error);
+ return;
}
+ // "ip:port" of the peer; Conn uses it to identify the connection in log messages.
+ std::string remote = inet_ntostring(addr.sin_addr.s_addr) + ":" + x2str(ntohs(addr.sin_port));
+ m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, remote));
}
#include "stg/module_settings.h"
#include "stg/logger.h"
+#include "config.h"
+#include "conn.h"
+
#include <string>
+#include <deque>
#include <pthread.h>
#include <unistd.h>
RADIUS();
virtual ~RADIUS() {}
- void SetUsers(USERS* u) { users = u; }
- void SetStore(STORE* s) { store = s; }
+ void SetUsers(USERS* u) { m_users = u; }
+ void SetStore(STORE* s) { m_store = s; }
void SetStgSettings(const SETTINGS*) {}
- void SetSettings(const MODULE_SETTINGS& s) { settings = s; }
+ void SetSettings(const MODULE_SETTINGS& s) { m_settings = s; }
int ParseSettings();
int Start();
static void* run(void*);
- void rumImpl();
+ bool reconnect();
+ int createUNIX() const;
+ int createTCP() const;
+ void runImpl();
int maxFD() const;
void buildFDSet(fd_set & fds) const;
void cleanupConns();
void handleEvents(const fd_set & fds);
void acceptConnection();
+ void acceptUNIX();
+ void acceptTCP();
mutable std::string m_error;
STG::Config m_config;
USERS* m_users;
const STORE* m_store;
+ int m_listenSocket;
+ std::deque<STG::Conn*> m_conns;
+
pthread_t m_thread;
- pthread_mutex_t m_mutex;
PLUGIN_LOGGER m_logger;
};