]> git.stg.codes - stg.git/blobdiff - projects/stargazer/plugins/other/radius/server.cpp
Sending attributes from <auth>/send section. (#13)
[stg.git] / projects / stargazer / plugins / other / radius / server.cpp
index 375fcd6950e0c8549d8f439935a0d1f7be8f435e..11aa9a79344ee506c28e2b4f7d1fa1efe3af4ccd 100644 (file)
@@ -1,20 +1,26 @@
 #include "server.h"
+#include "radproto/attribute.h"
 #include "radproto/packet_codes.h"
-#include "radproto/attribute_types.h"
+#include "radproto/attribute_codes.h"
 #include "stg/user.h"
 #include "stg/users.h"
 #include "stg/common.h"
+#include <vector>
+#include <string>
+#include <sstream>
 #include <cstring>
 #include <functional>
 #include <cstdint> //uint8_t, uint32_t
 
 using STG::Server;
+using STG::User;
 using boost::system::error_code;
 
-Server::Server(boost::asio::io_service& io_service, const std::string& secret, uint16_t port, const std::string& filePath, std::stop_token token, PluginLogger& logger, Users* users)
-    : m_radius(io_service, secret, port),
+Server::Server(boost::asio::io_context& io_context, const std::string& secret, uint16_t port, const std::string& filePath, std::stop_token token, PluginLogger& logger, Users* users, const Config& config)
+    : m_radius(io_context, secret, port),
       m_dictionaries(filePath),
       m_users(users),
+      m_config(config),
       m_token(std::move(token)),
       m_logger(logger)
 {
@@ -37,27 +43,42 @@ void Server::startReceive()
     m_radius.asyncReceive([this](const auto& error, const auto& packet, const boost::asio::ip::udp::endpoint& source){ handleReceive(error, packet, source); });
 }
 
-RadProto::Packet Server::makeResponse(const RadProto::Packet& request)
+std::vector<RadProto::Attribute*> Server::makeAttributes(const User* user)
 {
     std::vector<RadProto::Attribute*> attributes;
-    attributes.push_back(new RadProto::String(m_dictionaries.attributeCode("User-Name"), "test"));
-    attributes.push_back(new RadProto::Integer(m_dictionaries.attributeCode("NAS-Port"), 20));
-    std::array<uint8_t, 4> address {127, 104, 22, 17};
-    attributes.push_back(new RadProto::IpAddress(m_dictionaries.attributeCode("NAS-IP-Address"), address));
-    std::vector<uint8_t> bytes {'1', '2', '3', 'a', 'b', 'c'};
-    attributes.push_back(new RadProto::Bytes(m_dictionaries.attributeCode("Callback-Number"), bytes));
-    std::vector<uint8_t> chapPassword {'1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g' };
-    attributes.push_back(new RadProto::ChapPassword(m_dictionaries.attributeCode("CHAP-Password"), 1, chapPassword));
-
-    std::vector<RadProto::VendorSpecific> vendorSpecific;
-    std::vector<uint8_t> vendorValue {0, 0, 0, 3};
-    vendorSpecific.push_back(RadProto::VendorSpecific(m_dictionaries.vendorCode("Dlink"), m_dictionaries.vendorAttributeCode("Dlink", "Dlink-User-Level"), vendorValue));
-
-    if (request.type() != RadProto::ACCESS_REQUEST)
+
+    for (const auto& at : m_config.getAuth().send)
+    {
+        std::string attrValue;
+
+        if (at.second.type == Config::AttrValue::Type::PARAM_NAME)
+            attrValue = user->GetParamValue(at.second.value);
+        else
+            attrValue = at.second.value;
+
+        const auto attrName = at.first;
+        const auto attrCode = m_dictionaries.attributeCode(attrName);
+        const auto attrType = m_dictionaries.attributeType(attrCode);
+
+        if ((attrType == "integer") && (m_dictionaries.attributeValueFindByName(attrName, attrValue)))
+            attributes.push_back(RadProto::Attribute::make(attrCode, attrType, std::to_string(m_dictionaries.attributeValueCode(attrName, attrValue))));
+        else
+            attributes.push_back(RadProto::Attribute::make(attrCode, attrType, attrValue));
+    }
+    return attributes;
+}
+
+RadProto::Packet Server::makeResponse(const RadProto::Packet& request)
+{
+    if (request.code() != RadProto::ACCESS_REQUEST)
         return RadProto::Packet(RadProto::ACCESS_REJECT, request.id(), request.auth(), {}, {});
 
-    if (findUser(request))
-        return RadProto::Packet(RadProto::ACCESS_ACCEPT, request.id(), request.auth(), attributes, vendorSpecific);
+    const User* user;
+
+    user = findUser(request);
+
+    if (user != nullptr)
+        return RadProto::Packet(RadProto::ACCESS_ACCEPT, request.id(), request.auth(), makeAttributes(user), {});
 
     printfd(__FILE__, "Error findUser\n");
     return RadProto::Packet(RadProto::ACCESS_REJECT, request.id(), request.auth(), {}, {});
@@ -85,6 +106,7 @@ void Server::handleReceive(const error_code& error, const std::optional<RadProto
     {
         m_logger("Error asyncReceive: %s", error.message().c_str());
         printfd(__FILE__, "Error asyncReceive: '%s'\n", error.message().c_str());
+        return;
     }
 
     if (packet == std::nullopt)
@@ -97,16 +119,16 @@ void Server::handleReceive(const error_code& error, const std::optional<RadProto
     m_radius.asyncSend(makeResponse(*packet), source, [this](const auto& ec){ handleSend(ec); });
 }
 
-bool Server::findUser(const RadProto::Packet& packet)
+const User* Server::findUser(const RadProto::Packet& packet)
 {
     std::string login;
     std::string password;
     for (const auto& attribute : packet.attributes())
     {
-        if (attribute->type() == RadProto::USER_NAME)
+        if (attribute->code() == RadProto::USER_NAME)
             login = attribute->toString();
 
-        if (attribute->type() == RadProto::USER_PASSWORD)
+        if (attribute->code() == RadProto::USER_PASSWORD)
             password = attribute->toString();
     }
 
@@ -115,16 +137,16 @@ bool Server::findUser(const RadProto::Packet& packet)
     {
         m_logger("User '%s' not found.", login.c_str());
         printfd(__FILE__, "User '%s' NOT found!\n", login.c_str());
-        return false;
+        return nullptr;
     }
 
     printfd(__FILE__, "User '%s' FOUND!\n", user->GetLogin().c_str());
 
     if (password != user->GetProperties().password.Get())
     {
-        m_logger("User's password is incorrect. %s", password.c_str());
-        printfd(__FILE__, "User's password is incorrect.\n", password.c_str());
-        return false;
+        m_logger("User's password is incorrect.");
+        printfd(__FILE__, "User's password is incorrect.\n");
+        return nullptr;
     }
-    return true;
+    return user;
 }