]> git.stg.codes - stg.git/blob - projects/stargazer/plugins/other/radius/conn.cpp
Merge branch 'stg-2.409' into stg-2.409-radius
[stg.git] / projects / stargazer / plugins / other / radius / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "radius.h"
24 #include "config.h"
25
26 #include "stg/json_parser.h"
27 #include "stg/json_generator.h"
28 #include "stg/users.h"
29 #include "stg/user.h"
30 #include "stg/logger.h"
31 #include "stg/common.h"
32
33 #include <yajl/yajl_gen.h>
34
35 #include <map>
36 #include <stdexcept>
37 #include <cstring>
38 #include <cerrno>
39
40 #include <unistd.h>
41 #include <sys/types.h>
42 #include <sys/socket.h>
43
44 using STG::Conn;
45 using STG::Config;
46 using STG::JSON::Parser;
47 using STG::JSON::PairsParser;
48 using STG::JSON::EnumParser;
49 using STG::JSON::NodeParser;
50 using STG::JSON::Gen;
51 using STG::JSON::MapGen;
52 using STG::JSON::StringGen;
53
54 namespace
55 {
56
57 double CONN_TIMEOUT = 60;
58 double PING_TIMEOUT = 10;
59
60 enum Packet
61 {
62     PING,
63     PONG,
64     DATA
65 };
66
67 enum Stage
68 {
69     AUTHORIZE,
70     AUTHENTICATE,
71     PREACCT,
72     ACCOUNTING,
73     POSTAUTH
74 };
75
76 std::map<std::string, Packet> packetCodes;
77 std::map<std::string, Stage> stageCodes;
78
79 class PacketParser : public EnumParser<Packet>
80 {
81     public:
82         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
83             : EnumParser(next, packet, packetStr, packetCodes)
84         {
85             if (!packetCodes.empty())
86                 return;
87             packetCodes["ping"] = PING;
88             packetCodes["pong"] = PONG;
89             packetCodes["data"] = DATA;
90         }
91 };
92
93 class StageParser : public EnumParser<Stage>
94 {
95     public:
96         StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
97             : EnumParser(next, stage, stageStr, stageCodes)
98         {
99             if (!stageCodes.empty())
100                 return;
101             stageCodes["authorize"] = AUTHORIZE;
102             stageCodes["authenticate"] = AUTHENTICATE;
103             stageCodes["preacct"] = PREACCT;
104             stageCodes["accounting"] = ACCOUNTING;
105             stageCodes["postauth"] = POSTAUTH;
106         }
107 };
108
109 class TopParser : public NodeParser
110 {
111     public:
112         typedef void (*Callback) (void* /*data*/);
113         TopParser(Callback callback, void* data)
114             : m_packetParser(this, m_packet, m_packetStr),
115               m_stageParser(this, m_stage, m_stageStr),
116               m_pairsParser(this, m_data),
117               m_callback(callback), m_callbackData(data)
118         {}
119
120         virtual NodeParser* parseStartMap() { return this; }
121         virtual NodeParser* parseMapKey(const std::string& value)
122         {
123             std::string key = ToLower(value);
124
125             if (key == "packet")
126                 return &m_packetParser;
127             else if (key == "stage")
128                 return &m_stageParser;
129             else if (key == "pairs")
130                 return &m_pairsParser;
131
132             return this;
133         }
134         virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; }
135
136         const std::string& packetStr() const { return m_packetStr; }
137         Packet packet() const { return m_packet; }
138         const std::string& stageStr() const { return m_stageStr; }
139         Stage stage() const { return m_stage; }
140         const Config::Pairs& data() const { return m_data; }
141
142     private:
143         std::string m_packetStr;
144         Packet m_packet;
145         std::string m_stageStr;
146         Stage m_stage;
147         Config::Pairs m_data;
148
149         PacketParser m_packetParser;
150         StageParser m_stageParser;
151         PairsParser m_pairsParser;
152
153         Callback m_callback;
154         void* m_callbackData;
155 };
156
157 class ProtoParser : public Parser
158 {
159     public:
160         ProtoParser(TopParser::Callback callback, void* data)
161             : Parser( &m_topParser ),
162               m_topParser(callback, data)
163         {}
164
165         const std::string& packetStr() const { return m_topParser.packetStr(); }
166         Packet packet() const { return m_topParser.packet(); }
167         const std::string& stageStr() const { return m_topParser.stageStr(); }
168         Stage stage() const { return m_topParser.stage(); }
169         const Config::Pairs& data() const { return m_topParser.data(); }
170
171     private:
172         TopParser m_topParser;
173 };
174
175 class PacketGen : public Gen
176 {
177     public:
178         PacketGen(const std::string& type)
179             : m_type(type)
180         {
181             m_gen.add("packet", m_type);
182         }
183         void run(yajl_gen_t* handle) const
184         {
185             m_gen.run(handle);
186         }
187         PacketGen& add(const std::string& key, const std::string& value)
188         {
189             m_gen.add(key, new StringGen(value));
190             return *this;
191         }
192         PacketGen& add(const std::string& key, MapGen* map)
193         {
194             m_gen.add(key, map);
195             return *this;
196         }
197         PacketGen& add(const std::string& key, MapGen& map)
198         {
199             m_gen.add(key, map);
200             return *this;
201         }
202     private:
203         MapGen m_gen;
204         StringGen m_type;
205 };
206
207 std::string toString(Config::ReturnCode code)
208 {
209     switch (code)
210     {
211         case Config::REJECT:   return "reject";
212         case Config::FAIL:     return "fail";
213         case Config::OK:       return "ok";
214         case Config::HANDLED:  return "handled";
215         case Config::INVALID:  return "invalid";
216         case Config::USERLOCK: return "userlock";
217         case Config::NOTFOUND: return "notfound";
218         case Config::NOOP:     return "noop";
219         case Config::UPDATED:  return "noop";
220     }
221     return "reject";
222 }
223
224 }
225
226 class Conn::Impl
227 {
228     public:
229         Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
230         ~Impl();
231
232         int sock() const { return m_sock; }
233
234         bool read();
235         bool tick();
236
237         bool isOk() const { return m_ok; }
238
239     private:
240         USERS& m_users;
241         PLUGIN_LOGGER& m_logger;
242         RADIUS& m_plugin;
243         const Config& m_config;
244         int m_sock;
245         std::string m_remote;
246         bool m_ok;
247         time_t m_lastPing;
248         time_t m_lastActivity;
249         ProtoParser m_parser;
250         std::set<std::string> m_authorized;
251
252         template <typename T>
253         const T& stageMember(T Config::Section::* member) const
254         {
255             switch (m_parser.stage())
256             {
257                 case AUTHORIZE: return m_config.autz.*member;
258                 case AUTHENTICATE: return m_config.auth.*member;
259                 case POSTAUTH: return m_config.postauth.*member;
260                 case PREACCT: return m_config.preacct.*member;
261                 case ACCOUNTING: return m_config.acct.*member;
262             }
263             throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
264         }
265
266         const Config::Pairs& match() const { return stageMember(&Config::Section::match); }
267         const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
268         const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
269         Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
270         const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); }
271
272         static void process(void* data);
273         void processPing();
274         void processPong();
275         void processData();
276         bool answer(const USER& user);
277         bool answerNo();
278         bool sendPing();
279         bool sendPong();
280
281         static bool write(void* data, const char* buf, size_t size);
282 };
283
284 Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
285     : m_impl(new Impl(users, logger, plugin, config, fd, remote))
286 {
287 }
288
289 Conn::~Conn()
290 {
291 }
292
293 int Conn::sock() const
294 {
295     return m_impl->sock();
296 }
297
298 bool Conn::read()
299 {
300     return m_impl->read();
301 }
302
303 bool Conn::tick()
304 {
305     return m_impl->tick();
306 }
307
308 bool Conn::isOk() const
309 {
310     return m_impl->isOk();
311 }
312
313 Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
314     : m_users(users),
315       m_logger(logger),
316       m_plugin(plugin),
317       m_config(config),
318       m_sock(fd),
319       m_remote(remote),
320       m_ok(true),
321       m_lastPing(time(NULL)),
322       m_lastActivity(m_lastPing),
323       m_parser(&Conn::Impl::process, this)
324 {
325 }
326
327 Conn::Impl::~Impl()
328 {
329     close(m_sock);
330
331     std::set<std::string>::const_iterator it = m_authorized.begin();
332     for (; it != m_authorized.end(); ++it)
333         m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + ".");
334 }
335
336 bool Conn::Impl::read()
337 {
338     static std::vector<char> buffer(1024);
339     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
340     if (res < 0)
341     {
342         m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
343         m_ok = false;
344         return false;
345     }
346     printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
347     m_lastActivity = time(NULL);
348     if (res == 0)
349     {
350         m_ok = false;
351         return true;
352     }
353     return m_parser.append(buffer.data(), res);
354 }
355
356 bool Conn::Impl::tick()
357 {
358     time_t now = time(NULL);
359     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
360     {
361         int delta = difftime(now, m_lastActivity);
362         printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta);
363         m_logger("Connection to " + m_remote + " timed out.");
364         m_ok = false;
365         return false;
366     }
367     if (difftime(now, m_lastPing) > PING_TIMEOUT)
368     {
369         int delta = difftime(now, m_lastPing);
370         printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta);
371         sendPing();
372     }
373     return true;
374 }
375
376 void Conn::Impl::process(void* data)
377 {
378     Impl& impl = *static_cast<Impl*>(data);
379     try
380     {
381         switch (impl.m_parser.packet())
382         {
383             case PING:
384                 impl.processPing();
385                 return;
386             case PONG:
387                 impl.processPong();
388                 return;
389             case DATA:
390                 impl.processData();
391                 return;
392         }
393     }
394     catch (const std::exception& ex)
395     {
396         printfd(__FILE__, "Processing error. %s", ex.what());
397         impl.m_logger("Processing error. %s", ex.what());
398     }
399     printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
400     impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
401 }
402
403 void Conn::Impl::processPing()
404 {
405     printfd(__FILE__, "Got ping. Sending pong...\n");
406     sendPong();
407 }
408
409 void Conn::Impl::processPong()
410 {
411     printfd(__FILE__, "Got pong.\n");
412     m_lastActivity = time(NULL);
413 }
414
415 void Conn::Impl::processData()
416 {
417     printfd(__FILE__, "Got data.\n");
418     int handle = m_users.OpenSearch();
419
420     USER_PTR user = NULL;
421     bool matched = false;
422     while (m_users.SearchNext(handle, &user) == 0)
423     {
424         if (user == NULL)
425             continue;
426
427         matched = true;
428         for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
429         {
430             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
431             if (pos == m_parser.data().end())
432             {
433                 matched = false;
434                 break;
435             }
436             if (user->GetParamValue(it->second) != pos->second)
437             {
438                 matched = false;
439                 break;
440             }
441         }
442         if (!matched)
443             continue;
444         answer(*user);
445         if (authorize().check(*user, m_parser.data()))
446         {
447             m_plugin.authorize(*user);
448             m_authorized.insert(user->GetLogin());
449         }
450         break;
451     }
452
453     if (!matched)
454         answerNo();
455
456     m_users.CloseSearch(handle);
457 }
458
459 bool Conn::Impl::answer(const USER& user)
460 {
461     printfd(__FILE__, "Got match. Sending answer...\n");
462     MapGen replyData;
463     for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
464         replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
465
466     MapGen modifyData;
467     for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
468         modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
469
470     PacketGen gen("data");
471     gen.add("result", "ok")
472        .add("reply", replyData)
473        .add("modify", modifyData);
474
475     m_lastPing = time(NULL);
476
477     return generate(gen, &Conn::Impl::write, this);
478 }
479
480 bool Conn::Impl::answerNo()
481 {
482     printfd(__FILE__, "No match. Sending answer...\n");
483     PacketGen gen("data");
484     gen.add("result", "no");
485     gen.add("return_code", toString(returnCode()));
486
487     m_lastPing = time(NULL);
488
489     return generate(gen, &Conn::Impl::write, this);
490 }
491
492 bool Conn::Impl::sendPing()
493 {
494     PacketGen gen("ping");
495
496     m_lastPing = time(NULL);
497
498     return generate(gen, &Conn::Impl::write, this);
499 }
500
501 bool Conn::Impl::sendPong()
502 {
503     PacketGen gen("pong");
504
505     m_lastPing = time(NULL);
506
507     return generate(gen, &Conn::Impl::write, this);
508 }
509
510 bool Conn::Impl::write(void* data, const char* buf, size_t size)
511 {
512     std::string json(buf, size);
513     printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str());
514     Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
515     while (size > 0)
516     {
517         ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL);
518         if (res < 0)
519         {
520             conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
521             conn.m_ok = false;
522             return false;
523         }
524         size -= res;
525     }
526     return true;
527 }