X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/37324ea9b8c06d96b9383be993da02a01f103253..21ba4dfad49d2d489a9399d36d078eab8c44e0d6:/projects/stargazer/plugins/other/radius/server.cpp diff --git a/projects/stargazer/plugins/other/radius/server.cpp b/projects/stargazer/plugins/other/radius/server.cpp index 4d16b6cb..9e6766d5 100644 --- a/projects/stargazer/plugins/other/radius/server.cpp +++ b/projects/stargazer/plugins/other/radius/server.cpp @@ -1,18 +1,37 @@ #include "server.h" #include "radproto/packet_codes.h" +#include "radproto/attribute_types.h" +#include "stg/user.h" +#include "stg/users.h" +#include "stg/common.h" +#include #include -#include +#include //uint8_t, uint32_t using STG::Server; using boost::system::error_code; -Server::Server(boost::asio::io_service& io_service, const std::string& secret, uint16_t port, const std::string& filePath) - : m_radius(io_service, secret, port), - m_dictionaries(filePath) +Server::Server(boost::asio::io_context& io_context, const std::string& secret, uint16_t port, const std::string& filePath, std::stop_token token, PluginLogger& logger, Users* users) + : m_radius(io_context, secret, port), + m_dictionaries(filePath), + m_users(users), + m_token(std::move(token)), + m_logger(logger) +{ + start(); +} + +void Server::start() { startReceive(); } +void Server::stop() +{ + error_code ec; + m_radius.close(ec); +} + void Server::startReceive() { m_radius.asyncReceive([this](const auto& error, const auto& packet, const boost::asio::ip::udp::endpoint& source){ handleReceive(error, packet, source); }); @@ -34,35 +53,78 @@ RadProto::Packet Server::makeResponse(const RadProto::Packet& request) std::vector vendorValue {0, 0, 0, 3}; vendorSpecific.push_back(RadProto::VendorSpecific(m_dictionaries.vendorCode("Dlink"), m_dictionaries.vendorAttributeCode("Dlink", "Dlink-User-Level"), vendorValue)); - if (request.type() == RadProto::ACCESS_REQUEST) + if (request.type() != RadProto::ACCESS_REQUEST) + return RadProto::Packet(RadProto::ACCESS_REJECT, request.id(), request.auth(), {}, {}); + + if (findUser(request)) return RadProto::Packet(RadProto::ACCESS_ACCEPT, request.id(), request.auth(), attributes, vendorSpecific); - return RadProto::Packet(RadProto::ACCESS_REJECT, request.id(), request.auth(), attributes, vendorSpecific); + printfd(__FILE__, "Error findUser\n"); + return RadProto::Packet(RadProto::ACCESS_REJECT, request.id(), request.auth(), {}, {}); } void Server::handleSend(const error_code& ec) { - if (ec) - std::cout << "Error asyncSend: " << ec.message() << "\n"; + if (m_token.stop_requested()) + return; + if (ec) + { + m_logger("Error asyncSend: %s", ec.message().c_str()); + printfd(__FILE__, "Error asyncSend: '%s'\n", ec.message().c_str()); + } startReceive(); } void Server::handleReceive(const error_code& error, const std::optional& packet, const boost::asio::ip::udp::endpoint& source) { + if (m_token.stop_requested()) + return; + if (error) { - std::cout << "Error asyncReceive: " << error.message() << "\n"; - return; + m_logger("Error asyncReceive: %s", error.message().c_str()); + printfd(__FILE__, "Error asyncReceive: '%s'\n", error.message().c_str()); } if (packet == std::nullopt) { - std::cout << "Error asyncReceive: the request packet is missing\n"; + m_logger("Error asyncReceive: the request packet is missing\n"); + printfd(__FILE__, "Error asyncReceive: the request packet is missing\n"); return; } - else + + m_radius.asyncSend(makeResponse(*packet), source, [this](const auto& ec){ handleSend(ec); }); +} + +bool Server::findUser(const RadProto::Packet& packet) +{ + std::string login; + std::string password; + for (const auto& attribute : packet.attributes()) + { + if (attribute->type() == RadProto::USER_NAME) + login = attribute->toString(); + + if (attribute->type() == RadProto::USER_PASSWORD) + password = attribute->toString(); + } + + User* user = nullptr; + if (m_users->FindByName(login, &user)) + { + m_logger("User '%s' not found.", login.c_str()); + printfd(__FILE__, "User '%s' NOT found!\n", login.c_str()); + return false; + } + + printfd(__FILE__, "User '%s' FOUND!\n", user->GetLogin().c_str()); + + if (password != user->GetProperties().password.Get()) { - m_radius.asyncSend(makeResponse(*packet), source, [this](const auto& ec){ handleSend(ec); }); + m_logger("User's password is incorrect. %s", password.c_str()); + printfd(__FILE__, "User's password is incorrect.\n", password.c_str()); + return false; } + return true; }