]> git.stg.codes - stg.git/blob - stargazer/plugins/other/radius/conn.cpp
Fix build on OSX.
[stg.git] / stargazer / plugins / other / radius / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "radius.h"
24 #include "config.h"
25
26 #include "stg/json_parser.h"
27 #include "stg/json_generator.h"
28 #include "stg/users.h"
29 #include "stg/user.h"
30 #include "stg/logger.h"
31 #include "stg/common.h"
32
33 #include <yajl/yajl_gen.h>
34
35 #include <map>
36 #include <stdexcept>
37 #include <cstring>
38 #include <cerrno>
39
40 #include <unistd.h>
41 #include <sys/types.h>
42 #include <sys/socket.h>
43
44 #ifndef MSG_NOSIGNAL
45 // On OSX this flag does not exist.
46 #define MSG_NOSIGNAL 0
47 #endif
48
49 using STG::Conn;
50 using STG::Config;
51 using STG::JSON::Parser;
52 using STG::JSON::PairsParser;
53 using STG::JSON::EnumParser;
54 using STG::JSON::NodeParser;
55 using STG::JSON::Gen;
56 using STG::JSON::MapGen;
57 using STG::JSON::StringGen;
58
59 namespace
60 {
61
62 double CONN_TIMEOUT = 60;
63 double PING_TIMEOUT = 10;
64
65 enum Packet
66 {
67     PING,
68     PONG,
69     DATA
70 };
71
72 enum Stage
73 {
74     AUTHORIZE,
75     AUTHENTICATE,
76     PREACCT,
77     ACCOUNTING,
78     POSTAUTH
79 };
80
81 std::map<std::string, Packet> packetCodes;
82 std::map<std::string, Stage> stageCodes;
83
84 class PacketParser : public EnumParser<Packet>
85 {
86     public:
87         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
88             : EnumParser(next, packet, packetStr, packetCodes)
89         {
90             if (!packetCodes.empty())
91                 return;
92             packetCodes["ping"] = PING;
93             packetCodes["pong"] = PONG;
94             packetCodes["data"] = DATA;
95         }
96 };
97
98 class StageParser : public EnumParser<Stage>
99 {
100     public:
101         StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
102             : EnumParser(next, stage, stageStr, stageCodes)
103         {
104             if (!stageCodes.empty())
105                 return;
106             stageCodes["authorize"] = AUTHORIZE;
107             stageCodes["authenticate"] = AUTHENTICATE;
108             stageCodes["preacct"] = PREACCT;
109             stageCodes["accounting"] = ACCOUNTING;
110             stageCodes["postauth"] = POSTAUTH;
111         }
112 };
113
114 class TopParser : public NodeParser
115 {
116     public:
117         typedef void (*Callback) (void* /*data*/);
118         TopParser(Callback callback, void* data)
119             : m_packetParser(this, m_packet, m_packetStr),
120               m_stageParser(this, m_stage, m_stageStr),
121               m_pairsParser(this, m_data),
122               m_callback(callback), m_callbackData(data)
123         {}
124
125         virtual NodeParser* parseStartMap() { return this; }
126         virtual NodeParser* parseMapKey(const std::string& value)
127         {
128             std::string key = ToLower(value);
129
130             if (key == "packet")
131                 return &m_packetParser;
132             else if (key == "stage")
133                 return &m_stageParser;
134             else if (key == "pairs")
135                 return &m_pairsParser;
136
137             return this;
138         }
139         virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; }
140
141         const std::string& packetStr() const { return m_packetStr; }
142         Packet packet() const { return m_packet; }
143         const std::string& stageStr() const { return m_stageStr; }
144         Stage stage() const { return m_stage; }
145         const Config::Pairs& data() const { return m_data; }
146
147     private:
148         std::string m_packetStr;
149         Packet m_packet;
150         std::string m_stageStr;
151         Stage m_stage;
152         Config::Pairs m_data;
153
154         PacketParser m_packetParser;
155         StageParser m_stageParser;
156         PairsParser m_pairsParser;
157
158         Callback m_callback;
159         void* m_callbackData;
160 };
161
162 class ProtoParser : public Parser
163 {
164     public:
165         ProtoParser(TopParser::Callback callback, void* data)
166             : Parser( &m_topParser ),
167               m_topParser(callback, data)
168         {}
169
170         const std::string& packetStr() const { return m_topParser.packetStr(); }
171         Packet packet() const { return m_topParser.packet(); }
172         const std::string& stageStr() const { return m_topParser.stageStr(); }
173         Stage stage() const { return m_topParser.stage(); }
174         const Config::Pairs& data() const { return m_topParser.data(); }
175
176     private:
177         TopParser m_topParser;
178 };
179
180 class PacketGen : public Gen
181 {
182     public:
183         PacketGen(const std::string& type)
184             : m_type(type)
185         {
186             m_gen.add("packet", m_type);
187         }
188         void run(yajl_gen_t* handle) const
189         {
190             m_gen.run(handle);
191         }
192         PacketGen& add(const std::string& key, const std::string& value)
193         {
194             m_gen.add(key, new StringGen(value));
195             return *this;
196         }
197         PacketGen& add(const std::string& key, MapGen* map)
198         {
199             m_gen.add(key, map);
200             return *this;
201         }
202         PacketGen& add(const std::string& key, MapGen& map)
203         {
204             m_gen.add(key, map);
205             return *this;
206         }
207     private:
208         MapGen m_gen;
209         StringGen m_type;
210 };
211
212 std::string toString(Config::ReturnCode code)
213 {
214     switch (code)
215     {
216         case Config::REJECT:   return "reject";
217         case Config::FAIL:     return "fail";
218         case Config::OK:       return "ok";
219         case Config::HANDLED:  return "handled";
220         case Config::INVALID:  return "invalid";
221         case Config::USERLOCK: return "userlock";
222         case Config::NOTFOUND: return "notfound";
223         case Config::NOOP:     return "noop";
224         case Config::UPDATED:  return "noop";
225     }
226     return "reject";
227 }
228
229 }
230
231 class Conn::Impl
232 {
233     public:
234         Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
235         ~Impl();
236
237         int sock() const { return m_sock; }
238
239         bool read();
240         bool tick();
241
242         bool isOk() const { return m_ok; }
243
244     private:
245         USERS& m_users;
246         PLUGIN_LOGGER& m_logger;
247         RADIUS& m_plugin;
248         const Config& m_config;
249         int m_sock;
250         std::string m_remote;
251         bool m_ok;
252         time_t m_lastPing;
253         time_t m_lastActivity;
254         ProtoParser m_parser;
255         std::set<std::string> m_authorized;
256
257         template <typename T>
258         const T& stageMember(T Config::Section::* member) const
259         {
260             switch (m_parser.stage())
261             {
262                 case AUTHORIZE: return m_config.autz.*member;
263                 case AUTHENTICATE: return m_config.auth.*member;
264                 case POSTAUTH: return m_config.postauth.*member;
265                 case PREACCT: return m_config.preacct.*member;
266                 case ACCOUNTING: return m_config.acct.*member;
267             }
268             throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
269         }
270
271         const Config::Pairs& match() const { return stageMember(&Config::Section::match); }
272         const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
273         const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
274         Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
275         const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); }
276
277         static void process(void* data);
278         void processPing();
279         void processPong();
280         void processData();
281         bool answer(const USER& user);
282         bool answerNo();
283         bool sendPing();
284         bool sendPong();
285
286         static bool write(void* data, const char* buf, size_t size);
287 };
288
289 Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
290     : m_impl(new Impl(users, logger, plugin, config, fd, remote))
291 {
292 }
293
294 Conn::~Conn()
295 {
296 }
297
298 int Conn::sock() const
299 {
300     return m_impl->sock();
301 }
302
303 bool Conn::read()
304 {
305     return m_impl->read();
306 }
307
308 bool Conn::tick()
309 {
310     return m_impl->tick();
311 }
312
313 bool Conn::isOk() const
314 {
315     return m_impl->isOk();
316 }
317
318 Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
319     : m_users(users),
320       m_logger(logger),
321       m_plugin(plugin),
322       m_config(config),
323       m_sock(fd),
324       m_remote(remote),
325       m_ok(true),
326       m_lastPing(time(NULL)),
327       m_lastActivity(m_lastPing),
328       m_parser(&Conn::Impl::process, this)
329 {
330 }
331
332 Conn::Impl::~Impl()
333 {
334     close(m_sock);
335
336     std::set<std::string>::const_iterator it = m_authorized.begin();
337     for (; it != m_authorized.end(); ++it)
338         m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + ".");
339 }
340
341 bool Conn::Impl::read()
342 {
343     static std::vector<char> buffer(1024);
344     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
345     if (res < 0)
346     {
347         m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
348         m_ok = false;
349         return false;
350     }
351     printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
352     m_lastActivity = time(NULL);
353     if (res == 0)
354     {
355         m_ok = false;
356         return true;
357     }
358     return m_parser.append(buffer.data(), res);
359 }
360
361 bool Conn::Impl::tick()
362 {
363     time_t now = time(NULL);
364     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
365     {
366         int delta = difftime(now, m_lastActivity);
367         printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta);
368         m_logger("Connection to " + m_remote + " timed out.");
369         m_ok = false;
370         return false;
371     }
372     if (difftime(now, m_lastPing) > PING_TIMEOUT)
373     {
374         int delta = difftime(now, m_lastPing);
375         printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta);
376         sendPing();
377     }
378     return true;
379 }
380
381 void Conn::Impl::process(void* data)
382 {
383     Impl& impl = *static_cast<Impl*>(data);
384     try
385     {
386         switch (impl.m_parser.packet())
387         {
388             case PING:
389                 impl.processPing();
390                 return;
391             case PONG:
392                 impl.processPong();
393                 return;
394             case DATA:
395                 impl.processData();
396                 return;
397         }
398     }
399     catch (const std::exception& ex)
400     {
401         printfd(__FILE__, "Processing error. %s", ex.what());
402         impl.m_logger("Processing error. %s", ex.what());
403     }
404     printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
405     impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
406 }
407
408 void Conn::Impl::processPing()
409 {
410     printfd(__FILE__, "Got ping. Sending pong...\n");
411     sendPong();
412 }
413
414 void Conn::Impl::processPong()
415 {
416     printfd(__FILE__, "Got pong.\n");
417     m_lastActivity = time(NULL);
418 }
419
420 void Conn::Impl::processData()
421 {
422     printfd(__FILE__, "Got data.\n");
423     int handle = m_users.OpenSearch();
424
425     USER_PTR user = NULL;
426     bool matched = false;
427     while (m_users.SearchNext(handle, &user) == 0)
428     {
429         if (user == NULL)
430             continue;
431
432         matched = true;
433         for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
434         {
435             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
436             if (pos == m_parser.data().end())
437             {
438                 matched = false;
439                 break;
440             }
441             if (user->GetParamValue(it->second) != pos->second)
442             {
443                 matched = false;
444                 break;
445             }
446         }
447         if (!matched)
448             continue;
449         answer(*user);
450         if (authorize().check(*user, m_parser.data()))
451         {
452             m_plugin.authorize(*user);
453             m_authorized.insert(user->GetLogin());
454         }
455         break;
456     }
457
458     if (!matched)
459         answerNo();
460
461     m_users.CloseSearch(handle);
462 }
463
464 bool Conn::Impl::answer(const USER& user)
465 {
466     printfd(__FILE__, "Got match. Sending answer...\n");
467     MapGen replyData;
468     for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
469         replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
470
471     MapGen modifyData;
472     for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
473         modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
474
475     PacketGen gen("data");
476     gen.add("result", "ok")
477        .add("reply", replyData)
478        .add("modify", modifyData);
479
480     m_lastPing = time(NULL);
481
482     return generate(gen, &Conn::Impl::write, this);
483 }
484
485 bool Conn::Impl::answerNo()
486 {
487     printfd(__FILE__, "No match. Sending answer...\n");
488     PacketGen gen("data");
489     gen.add("result", "no");
490     gen.add("return_code", toString(returnCode()));
491
492     m_lastPing = time(NULL);
493
494     return generate(gen, &Conn::Impl::write, this);
495 }
496
497 bool Conn::Impl::sendPing()
498 {
499     PacketGen gen("ping");
500
501     m_lastPing = time(NULL);
502
503     return generate(gen, &Conn::Impl::write, this);
504 }
505
506 bool Conn::Impl::sendPong()
507 {
508     PacketGen gen("pong");
509
510     m_lastPing = time(NULL);
511
512     return generate(gen, &Conn::Impl::write, this);
513 }
514
515 bool Conn::Impl::write(void* data, const char* buf, size_t size)
516 {
517     std::string json(buf, size);
518     printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str());
519     Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
520     while (size > 0)
521     {
522         ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL);
523         if (res < 0)
524         {
525             conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
526             conn.m_ok = false;
527             return false;
528         }
529         size -= res;
530     }
531     return true;
532 }