]> git.stg.codes - stg.git/blobdiff - projects/rlm_stg/rlm_stg.c
Final icing on rlm_stg.
[stg.git] / projects / rlm_stg / rlm_stg.c
index 3c6e9c9b3b16210c34c10d1661815b529c45ba01..0333fadc55304795cbf2b3475895d6e5979a88ae 100644 (file)
@@ -26,6 +26,9 @@
  *
  */
 
  *
  */
 
+#include "iface.h"
+#include "stgpair.h"
+
 #ifndef NDEBUG
 #define NDEBUG
 #include <freeradius/ident.h>
 #ifndef NDEBUG
 #define NDEBUG
 #include <freeradius/ident.h>
 #undef NDEBUG
 #endif
 
 #undef NDEBUG
 #endif
 
-#include "iface.h"
-#include "stgpair.h"
+#include <stddef.h> // size_t
 
 typedef struct rlm_stg_t {
 
 typedef struct rlm_stg_t {
-    char* server;
-    uint16_t port;
-    char* password;
+    char* address;
 } rlm_stg_t;
 
 static const CONF_PARSER module_config[] = {
 } rlm_stg_t;
 
 static const CONF_PARSER module_config[] = {
-  { "server",  PW_TYPE_STRING_PTR, offsetof(rlm_stg_t,server), NULL,  "localhost"},
-  { "port",  PW_TYPE_INTEGER,     offsetof(rlm_stg_t,port), NULL,  "9091" },
-  { "password",  PW_TYPE_STRING_PTR, offsetof(rlm_stg_t,password), NULL,  "123456"},
+  { "address",  PW_TYPE_STRING_PTR, offsetof(rlm_stg_t, address), NULL,  "unix:/var/run/stg.sock"},
 
   { NULL, -1, 0, NULL, NULL }        /* end the list */
 };
 
 
   { NULL, -1, 0, NULL, NULL }        /* end the list */
 };
 
+static void deletePairs(STG_PAIR* pairs)
+{
+    free(pairs);
+}
+
+static size_t toVPS(const STG_PAIR* pairs, VALUE_PAIR* vps)
+{
+    const STG_PAIR* pair = pairs;
+    size_t count = 0;
+
+    while (!emptyPair(pair)) {
+        VALUE_PAIR* vp = pairmake(pair->key, pair->value, T_OP_SET);
+        pairadd(&vps, vp);
+        DEBUG("Adding pair '%s': '%s'", pair->key, pair->value);
+        ++pair;
+        ++count;
+    }
+
+    return count;
+}
+
+static size_t toReply(STG_RESULT result, REQUEST* request)
+{
+    size_t count = 0;
+
+    count += toVPS(result.modify, request->config_items);
+    count += toVPS(result.reply, request->reply->vps);
+
+    deletePairs(result.modify);
+    deletePairs(result.reply);
+
+    return count;
+}
+
+static int countVPS(const VALUE_PAIR* pairs)
+{
+    unsigned count = 0;
+    while (pairs != NULL) {
+        ++count;
+        pairs = pairs->next;
+    }
+    return count;
+}
+
+static STG_PAIR* fromVPS(const VALUE_PAIR* pairs)
+{
+    unsigned size = countVPS(pairs);
+    STG_PAIR* res = (STG_PAIR*)malloc(sizeof(STG_PAIR) * (size + 1));
+    size_t pos = 0;
+    while (pairs != NULL) {
+        bzero(res[pos].key, sizeof(res[0].key));
+        bzero(res[pos].value, sizeof(res[0].value));
+        strncpy(res[pos].key, pairs->name, sizeof(res[0].key));
+        strncpy(res[pos].value, pairs->data.strvalue, sizeof(res[0].value));
+        ++pos;
+        pairs = pairs->next;
+    }
+    bzero(res[pos].key, sizeof(res[0].key));
+    bzero(res[pos].value, sizeof(res[0].value));
+    return res;
+}
+
 /*
  *    Do any per-module initialization that is separate to each
  *    configured instance of the module.  e.g. set up connections
 /*
  *    Do any per-module initialization that is separate to each
  *    configured instance of the module.  e.g. set up connections
@@ -83,7 +143,7 @@ static int stg_instantiate(CONF_SECTION* conf, void** instance)
         return -1;
     }
 
         return -1;
     }
 
-    if (!stgInstantiateImpl(data->server, data->port)) {
+    if (!stgInstantiateImpl(data->address)) {
         free(data);
         return -1;
     }
         free(data);
         return -1;
     }
@@ -99,10 +159,10 @@ static int stg_instantiate(CONF_SECTION* conf, void** instance)
  *    from the database. The authentication code only needs to check
  *    the password, the rest is done here.
  */
  *    from the database. The authentication code only needs to check
  *    the password, the rest is done here.
  */
-static int stg_authorize(void*, REQUEST* request)
+static int stg_authorize(void* instance, REQUEST* request)
 {
 {
-    const STG_PAIR* pairs;
-    const STG_PAIR* pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
     const char* username = NULL;
     const char* password = NULL;
     size_t count = 0;
     const char* username = NULL;
     const char* password = NULL;
@@ -121,22 +181,15 @@ static int stg_authorize(void*, REQUEST* request)
         DEBUG("rlm_stg: stg_authorize() request password field: '%s'", password);
     }
 
         DEBUG("rlm_stg: stg_authorize() request password field: '%s'", password);
     }
 
-    pairs = stgAuthorizeImpl(username, password, request->packet->vps);
+    result = stgAuthorizeImpl(username, password, pairs);
+    deletePairs(pairs);
 
 
-    if (!pairs) {
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_authorize() failed.");
         return RLM_MODULE_REJECT;
     }
 
         DEBUG("rlm_stg: stg_authorize() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR* vp = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->config_items, vp);
-        DEBUG("Adding pair '%s': '%s'", pair->key, pair->value);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
     if (count)
         return RLM_MODULE_UPDATED;
@@ -147,10 +200,10 @@ static int stg_authorize(void*, REQUEST* request)
 /*
  *    Authenticate the user with the given password.
  */
 /*
  *    Authenticate the user with the given password.
  */
-static int stg_authenticate(void*, REQUEST* request)
+static int stg_authenticate(void* instance, REQUEST* request)
 {
 {
-    const STG_PAIR* pairs;
-    const STG_PAIR* pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
     const char* username = NULL;
     const char* password = NULL;
     size_t count = 0;
     const char* username = NULL;
     const char* password = NULL;
@@ -169,21 +222,15 @@ static int stg_authenticate(void*, REQUEST* request)
         DEBUG("rlm_stg: stg_authenticate() request password field: '%s'", password);
     }
 
         DEBUG("rlm_stg: stg_authenticate() request password field: '%s'", password);
     }
 
-    pairs = stgAuthenticateImpl(username, password, request->packet->vps);
+    result = stgAuthenticateImpl(username, password, pairs);
+    deletePairs(pairs);
 
 
-    if (!pairs) {
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_authenticate() failed.");
         return RLM_MODULE_REJECT;
     }
 
         DEBUG("rlm_stg: stg_authenticate() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR* vp = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->reply->vps, vp);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
     if (count)
         return RLM_MODULE_UPDATED;
@@ -194,10 +241,10 @@ static int stg_authenticate(void*, REQUEST* request)
 /*
  *    Massage the request before recording it or proxying it
  */
 /*
  *    Massage the request before recording it or proxying it
  */
-static int stg_preacct(void*, REQUEST*)
+static int stg_preacct(void* instance, REQUEST* request)
 {
 {
-    const STG_PAIR* pairs;
-    const STG_PAIR* pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
     const char* username = NULL;
     const char* password = NULL;
     size_t count = 0;
     const char* username = NULL;
     const char* password = NULL;
@@ -216,21 +263,15 @@ static int stg_preacct(void*, REQUEST*)
         DEBUG("rlm_stg: stg_preacct() request password field: '%s'", password);
     }
 
         DEBUG("rlm_stg: stg_preacct() request password field: '%s'", password);
     }
 
-    pairs = stgPreAcctImpl(username, password, request->packet->vps);
+    result = stgPreAcctImpl(username, password, pairs);
+    deletePairs(pairs);
 
 
-    if (!pairs) {
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_preacct() failed.");
         return RLM_MODULE_REJECT;
     }
 
         DEBUG("rlm_stg: stg_preacct() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR* vp = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->reply->vps, vp);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
     if (count)
         return RLM_MODULE_UPDATED;
@@ -241,51 +282,37 @@ static int stg_preacct(void*, REQUEST*)
 /*
  *    Write accounting information to this modules database.
  */
 /*
  *    Write accounting information to this modules database.
  */
-static int stg_accounting(void*, REQUEST* request)
+static int stg_accounting(void* instance, REQUEST* request)
 {
 {
-    const STG_PAIR* pairs;
-    const STG_PAIR* pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
     size_t count = 0;
-
-    instance = instance;
+    const char* username = NULL;
+    const char* password = NULL;
 
     DEBUG("rlm_stg: stg_accounting()");
 
 
     DEBUG("rlm_stg: stg_accounting()");
 
-    VALUE_PAIR* svc = pairfind(request->packet->vps, PW_SERVICE_TYPE);
-    VALUE_PAIR* sessid = pairfind(request->packet->vps, PW_ACCT_SESSION_ID);
-    VALUE_PAIR* sttype = pairfind(request->packet->vps, PW_ACCT_STATUS_TYPE);
+    instance = instance;
 
 
-    if (!sessid) {
-        DEBUG("rlm_stg: stg_accounting() Acct-Session-ID undefined");
-        return RLM_MODULE_FAIL;
+    if (request->username) {
+        username = request->username->data.strvalue;
+        DEBUG("rlm_stg: stg_accounting() request username field: '%s'", username);
     }
 
     }
 
-    if (sttype) {
-        DEBUG("Acct-Status-Type := %s", sttype->data.strvalue);
-        if (svc) {
-            DEBUG("rlm_stg: stg_accounting() Service-Type defined as '%s'", svc->data.strvalue);
-            pairs = stgAccountingImpl((const char*)request->username->data.strvalue, (const char*)svc->data.strvalue, (const char*)sttype->data.strvalue, (const char*)sessid->data.strvalue);
-        } else {
-            DEBUG("rlm_stg: stg_accounting() Service-Type undefined");
-            pairs = stgAccountingImpl((const char*)request->username->data.strvalue, "", (const char*)sttype->data.strvalue, (const char*)sessid->data.strvalue);
-        }
-    } else {
-        DEBUG("rlm_stg: stg_accounting() Acct-Status-Type := NULL");
-        return RLM_MODULE_OK;
+    if (request->password) {
+        password = request->password->data.strvalue;
+        DEBUG("rlm_stg: stg_accounting() request password field: '%s'", password);
     }
     }
-    if (!pairs) {
+
+    result = stgAccountingImpl(username, password, pairs);
+    deletePairs(pairs);
+
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_accounting() failed.");
         return RLM_MODULE_REJECT;
     }
 
         DEBUG("rlm_stg: stg_accounting() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR* pwd = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->reply->vps, pwd);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
     if (count)
         return RLM_MODULE_UPDATED;
@@ -303,7 +330,7 @@ static int stg_accounting(void*, REQUEST* request)
  *    max. number of logins, do a second pass and validate all
  *    logins by querying the terminal server (using eg. SNMP).
  */
  *    max. number of logins, do a second pass and validate all
  *    logins by querying the terminal server (using eg. SNMP).
  */
-static int stg_checksimul(void*, REQUEST* request)
+static int stg_checksimul(void* instance, REQUEST* request)
 {
     DEBUG("rlm_stg: stg_checksimul()");
 
 {
     DEBUG("rlm_stg: stg_checksimul()");
 
@@ -314,38 +341,37 @@ static int stg_checksimul(void*, REQUEST* request)
     return RLM_MODULE_OK;
 }
 
     return RLM_MODULE_OK;
 }
 
-static int stg_postauth(void*, REQUEST* request)
+static int stg_postauth(void* instance, REQUEST* request)
 {
 {
-    const STG_PAIR* pairs;
-    const STG_PAIR* pair;
+    STG_RESULT result;
+    STG_PAIR* pairs = fromVPS(request->packet->vps);
     size_t count = 0;
     size_t count = 0;
-
-    instance = instance;
+    const char* username = NULL;
+    const char* password = NULL;
 
     DEBUG("rlm_stg: stg_postauth()");
 
 
     DEBUG("rlm_stg: stg_postauth()");
 
-    VALUE_PAIR* svc = pairfind(request->packet->vps, PW_SERVICE_TYPE);
+    instance = instance;
 
 
-    if (svc) {
-        DEBUG("rlm_stg: stg_postauth() Service-Type defined as '%s'", svc->data.strvalue);
-        pairs = stgPostAuthImpl((const char*)request->username->data.strvalue, (const char*)svc->data.strvalue);
-    } else {
-        DEBUG("rlm_stg: stg_postauth() Service-Type undefined");
-        pairs = stgPostAuthImpl((const char*)request->username->data.strvalue, "");
+    if (request->username) {
+        username = request->username->data.strvalue;
+        DEBUG("rlm_stg: stg_postauth() request username field: '%s'", username);
     }
     }
-    if (!pairs) {
+
+    if (request->password) {
+        password = request->password->data.strvalue;
+        DEBUG("rlm_stg: stg_postauth() request password field: '%s'", password);
+    }
+
+    result = stgPostAuthImpl(username, password, pairs);
+    deletePairs(pairs);
+
+    if (!result.modify && !result.reply) {
         DEBUG("rlm_stg: stg_postauth() failed.");
         return RLM_MODULE_REJECT;
     }
 
         DEBUG("rlm_stg: stg_postauth() failed.");
         return RLM_MODULE_REJECT;
     }
 
-    pair = pairs;
-    while (!emptyPair(pair)) {
-        VALUE_PAIR* pwd = pairmake(pair->key, pair->value, T_OP_SET);
-        pairadd(&request->reply->vps, pwd);
-        ++pair;
-        ++count;
-    }
-    deletePairs(pairs);
+    count = toReply(result, request);
 
     if (count)
         return RLM_MODULE_UPDATED;
 
     if (count)
         return RLM_MODULE_UPDATED;
@@ -355,7 +381,7 @@ static int stg_postauth(void*, REQUEST* request)
 
 static int stg_detach(void* instance)
 {
 
 static int stg_detach(void* instance)
 {
-    free(((struct rlm_stg_t*)instance)->server);
+    free(((struct rlm_stg_t*)instance)->address);
     free(instance);
     return 0;
 }
     free(instance);
     return 0;
 }
@@ -381,8 +407,8 @@ module_t rlm_stg = {
         stg_preacct,      /* preaccounting */
         stg_accounting,   /* accounting */
         stg_checksimul,   /* checksimul */
         stg_preacct,      /* preaccounting */
         stg_accounting,   /* accounting */
         stg_checksimul,   /* checksimul */
-        stg_pre_proxy,    /* pre-proxy */
-        stg_post_proxy,   /* post-proxy */
+        NULL,    /* pre-proxy */
+        NULL,   /* post-proxy */
         stg_postauth      /* post-auth */
     },
 };
         stg_postauth      /* post-auth */
     },
 };