X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/1538d6276533140505fddb71c99a0bafe6ca9182..d6e4a058a37bdaea7df8c8d360978c0dc8848fff:/projects/stargazer/plugins/other/radius/radius.cpp diff --git a/projects/stargazer/plugins/other/radius/radius.cpp b/projects/stargazer/plugins/other/radius/radius.cpp index fe989b28..376e4278 100644 --- a/projects/stargazer/plugins/other/radius/radius.cpp +++ b/projects/stargazer/plugins/other/radius/radius.cpp @@ -23,11 +23,24 @@ #include "stg/store.h" #include "stg/users.h" #include "stg/plugin_creator.h" +#include "stg/common.h" +#include +#include #include -#include +#include #include +#include +#include +#include // UNIX +#include // IP +#include // TCP +#include + +using STG::Config; +using STG::Conn; + namespace { @@ -41,10 +54,12 @@ extern "C" PLUGIN * GetPlugin() } RADIUS::RADIUS() - : m_running(false), + : m_config(), + m_running(false), m_stopped(true), m_users(NULL), m_store(NULL), + m_listenSocket(0), m_logger(GetPluginLogger(GetStgLogger(), "radius")) { } @@ -53,7 +68,7 @@ int RADIUS::ParseSettings() { try { m_config = STG::Config(m_settings); - return 0; + return reconnect() ? 0 : -1; } catch (const std::runtime_error& ex) { m_logger("Failed to parse settings. %s", ex.what()); return -1; @@ -91,6 +106,9 @@ int RADIUS::Stop() return 0; } + if (m_config.connectionType == Config::UNIX) + unlink(m_config.bindAddress.c_str()); + m_error = "Failed to stop thread."; m_logger(m_error); return -1; @@ -107,9 +125,118 @@ void* RADIUS::run(void* d) return NULL; } +bool RADIUS::reconnect() +{ + if (!m_conns.empty()) + { + std::deque::const_iterator it; + for (it = m_conns.begin(); it != m_conns.end(); ++it) + delete(*it); + m_conns.clear(); + } + if (m_listenSocket != 0) + { + shutdown(m_listenSocket, SHUT_RDWR); + close(m_listenSocket); + } + if (m_config.connectionType == Config::UNIX) + m_listenSocket = createUNIX(); + else + m_listenSocket = createTCP(); + if (m_listenSocket == 0) + return false; + if (listen(m_listenSocket, 100) == -1) + { + m_error = std::string("Error starting to listen socket: ") + strerror(errno); + m_logger(m_error); + return false; + } + return true; +} + +int RADIUS::createUNIX() const +{ + int fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd == -1) + { + m_error = std::string("Error creating UNIX socket: ") + strerror(errno); + m_logger(m_error); + return 0; + } + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, m_config.bindAddress.c_str(), m_config.bindAddress.length()); + unlink(m_config.bindAddress.c_str()); + if (bind(fd, reinterpret_cast(&addr), sizeof(addr)) == -1) + { + shutdown(fd, SHUT_RDWR); + close(fd); + m_error = std::string("Error binding UNIX socket: ") + strerror(errno); + m_logger(m_error); + return 0; + } + chown(m_config.bindAddress.c_str(), m_config.sockUID, m_config.sockGID); + if (m_config.sockMode != static_cast(-1)) + chmod(m_config.bindAddress.c_str(), m_config.sockMode); + return fd; +} + +int RADIUS::createTCP() const +{ + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + + hints.ai_family = AF_INET; /* Allow IPv4 */ + hints.ai_socktype = SOCK_STREAM; /* Stream socket */ + hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */ + hints.ai_protocol = 0; /* Any protocol */ + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + + addrinfo* ais = NULL; + int res = getaddrinfo(m_config.bindAddress.c_str(), m_config.portStr.c_str(), &hints, &ais); + if (res != 0) + { + m_error = "Error resolving address '" + m_config.bindAddress + "': " + gai_strerror(res); + m_logger(m_error); + return 0; + } + + for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next) + { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) + { + m_error = std::string("Error creating TCP socket: ") + strerror(errno); + m_logger(m_error); + freeaddrinfo(ais); + return 0; + } + if (bind(fd, ai->ai_addr, ai->ai_addrlen) == -1) + { + shutdown(fd, SHUT_RDWR); + close(fd); + m_error = std::string("Error binding TCP socket: ") + strerror(errno); + m_logger(m_error); + continue; + } + freeaddrinfo(ais); + return fd; + } + + m_error = "Failed to resolve '" + m_config.bindAddress; + m_logger(m_error); + + freeaddrinfo(ais); + return 0; +} + void RADIUS::runImpl() { m_running = true; + m_stopped = false; while (m_running) { fd_set fds; @@ -123,6 +250,8 @@ void RADIUS::runImpl() int res = select(maxFD() + 1, &fds, NULL, NULL, &tv); if (res < 0) { + if (errno == EINTR) + continue; m_error = std::string("'select' is failed: '") + strerror(errno) + "'."; m_logger(m_error); break; @@ -133,6 +262,11 @@ void RADIUS::runImpl() if (res > 0) handleEvents(fds); + else + { + for (std::deque::iterator it = m_conns.begin(); it != m_conns.end(); ++it) + (*it)->tick(); + } cleanupConns(); } @@ -163,7 +297,7 @@ void RADIUS::cleanupConns() { std::deque::iterator pos; for (pos = m_conns.begin(); pos != m_conns.end(); ++pos) - if (((*pos)->isDone() && !(*pos)->isKeepAlive()) || !(*pos)->isOk()) { + if (!(*pos)->isOk()) { delete *pos; *pos = NULL; } @@ -182,5 +316,48 @@ void RADIUS::handleEvents(const fd_set & fds) for (it = m_conns.begin(); it != m_conns.end(); ++it) if (FD_ISSET((*it)->sock(), &fds)) (*it)->read(); + else + (*it)->tick(); + } +} + +void RADIUS::acceptConnection() +{ + if (m_config.connectionType == Config::UNIX) + acceptUNIX(); + else + acceptTCP(); +} + +void RADIUS::acceptUNIX() +{ + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + socklen_t size = sizeof(addr); + int res = accept(m_listenSocket, reinterpret_cast(&addr), &size); + if (res == -1) + { + m_error = std::string("Failed to accept UNIX connection: ") + strerror(errno); + m_logger(m_error); + return; + } + printfd(__FILE__, "New UNIX connection: '%s'\n", addr.sun_path); + m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, addr.sun_path)); +} + +void RADIUS::acceptTCP() +{ + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + socklen_t size = sizeof(addr); + int res = accept(m_listenSocket, reinterpret_cast(&addr), &size); + if (res == -1) + { + m_error = std::string("Failed to accept TCP connection: ") + strerror(errno); + m_logger(m_error); + return; } + std::string remote = inet_ntostring(addr.sin_addr.s_addr) + ":" + x2str(ntohs(addr.sin_port)); + printfd(__FILE__, "New TCP connection: '%s'\n", remote.c_str()); + m_conns.push_back(new Conn(*m_users, m_logger, m_config, res, remote)); }