]> git.stg.codes - stg.git/blob - projects/stargazer/plugins/other/radius/conn.cpp
Check NULL value pairs.
[stg.git] / projects / stargazer / plugins / other / radius / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "config.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/users.h"
28 #include "stg/user.h"
29 #include "stg/logger.h"
30 #include "stg/common.h"
31
32 #include <yajl/yajl_gen.h>
33
34 #include <map>
35 #include <stdexcept>
36 #include <cstring>
37 #include <cerrno>
38
39 #include <unistd.h>
40 #include <sys/types.h>
41 #include <sys/socket.h>
42
43 using STG::Conn;
44 using STG::Config;
45 using STG::JSON::Parser;
46 using STG::JSON::PairsParser;
47 using STG::JSON::EnumParser;
48 using STG::JSON::NodeParser;
49 using STG::JSON::Gen;
50 using STG::JSON::MapGen;
51 using STG::JSON::StringGen;
52
53 namespace
54 {
55
56 double CONN_TIMEOUT = 60;
57 double PING_TIMEOUT = 10;
58
59 enum Packet
60 {
61     PING,
62     PONG,
63     DATA
64 };
65
66 enum Stage
67 {
68     AUTHORIZE,
69     AUTHENTICATE,
70     PREACCT,
71     ACCOUNTING,
72     POSTAUTH
73 };
74
75 std::map<std::string, Packet> packetCodes;
76 std::map<std::string, Stage> stageCodes;
77
78 class PacketParser : public EnumParser<Packet>
79 {
80     public:
81         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
82             : EnumParser(next, packet, packetStr, packetCodes)
83         {
84             if (!packetCodes.empty())
85                 return;
86             packetCodes["ping"] = PING;
87             packetCodes["pong"] = PONG;
88             packetCodes["data"] = DATA;
89         }
90 };
91
92 class StageParser : public EnumParser<Stage>
93 {
94     public:
95         StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
96             : EnumParser(next, stage, stageStr, stageCodes)
97         {
98             if (!stageCodes.empty())
99                 return;
100             stageCodes["authorize"] = AUTHORIZE;
101             stageCodes["authenticate"] = AUTHENTICATE;
102             stageCodes["preacct"] = PREACCT;
103             stageCodes["accounting"] = ACCOUNTING;
104             stageCodes["postauth"] = POSTAUTH;
105         }
106 };
107
108 class TopParser : public NodeParser
109 {
110     public:
111         typedef void (*Callback) (void* /*data*/);
112         TopParser(Callback callback, void* data)
113             : m_packetParser(this, m_packet, m_packetStr),
114               m_stageParser(this, m_stage, m_stageStr),
115               m_pairsParser(this, m_data),
116               m_callback(callback), m_callbackData(data)
117         {}
118
119         virtual NodeParser* parseStartMap() { return this; }
120         virtual NodeParser* parseMapKey(const std::string& value)
121         {
122             std::string key = ToLower(value);
123
124             if (key == "packet")
125                 return &m_packetParser;
126             else if (key == "stage")
127                 return &m_stageParser;
128             else if (key == "pairs")
129                 return &m_pairsParser;
130
131             return this;
132         }
133         virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; }
134
135         const std::string& packetStr() const { return m_packetStr; }
136         Packet packet() const { return m_packet; }
137         const std::string& stageStr() const { return m_stageStr; }
138         Stage stage() const { return m_stage; }
139         const Config::Pairs& data() const { return m_data; }
140
141     private:
142         std::string m_packetStr;
143         Packet m_packet;
144         std::string m_stageStr;
145         Stage m_stage;
146         Config::Pairs m_data;
147
148         PacketParser m_packetParser;
149         StageParser m_stageParser;
150         PairsParser m_pairsParser;
151
152         Callback m_callback;
153         void* m_callbackData;
154 };
155
156 class ProtoParser : public Parser
157 {
158     public:
159         ProtoParser(TopParser::Callback callback, void* data)
160             : Parser( &m_topParser ),
161               m_topParser(callback, data)
162         {}
163
164         const std::string& packetStr() const { return m_topParser.packetStr(); }
165         Packet packet() const { return m_topParser.packet(); }
166         const std::string& stageStr() const { return m_topParser.stageStr(); }
167         Stage stage() const { return m_topParser.stage(); }
168         const Config::Pairs& data() const { return m_topParser.data(); }
169
170     private:
171         TopParser m_topParser;
172 };
173
174 class PacketGen : public Gen
175 {
176     public:
177         PacketGen(const std::string& type)
178             : m_type(type)
179         {
180             m_gen.add("packet", m_type);
181         }
182         void run(yajl_gen_t* handle) const
183         {
184             m_gen.run(handle);
185         }
186         PacketGen& add(const std::string& key, const std::string& value)
187         {
188             m_gen.add(key, new StringGen(value));
189             return *this;
190         }
191         PacketGen& add(const std::string& key, MapGen* map)
192         {
193             m_gen.add(key, map);
194             return *this;
195         }
196         PacketGen& add(const std::string& key, MapGen& map)
197         {
198             m_gen.add(key, map);
199             return *this;
200         }
201     private:
202         MapGen m_gen;
203         StringGen m_type;
204 };
205
206 std::string toString(Config::ReturnCode code)
207 {
208     switch (code)
209     {
210         case Config::REJECT:   return "reject";
211         case Config::FAIL:     return "fail";
212         case Config::OK:       return "ok";
213         case Config::HANDLED:  return "handled";
214         case Config::INVALID:  return "invalid";
215         case Config::USERLOCK: return "userlock";
216         case Config::NOTFOUND: return "notfound";
217         case Config::NOOP:     return "noop";
218         case Config::UPDATED:  return "noop";
219     }
220     return "reject";
221 }
222
223 }
224
225 class Conn::Impl
226 {
227     public:
228         Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
229         ~Impl();
230
231         int sock() const { return m_sock; }
232
233         bool read();
234         bool tick();
235
236         bool isOk() const { return m_ok; }
237
238     private:
239         USERS& m_users;
240         PLUGIN_LOGGER& m_logger;
241         const Config& m_config;
242         int m_sock;
243         std::string m_remote;
244         bool m_ok;
245         time_t m_lastPing;
246         time_t m_lastActivity;
247         ProtoParser m_parser;
248
249         template <typename T>
250         const T& stageMember(T Config::Section::* member) const
251         {
252             switch (m_parser.stage())
253             {
254                 case AUTHORIZE: return m_config.autz.*member;
255                 case AUTHENTICATE: return m_config.auth.*member;
256                 case POSTAUTH: return m_config.postauth.*member;
257                 case PREACCT: return m_config.preacct.*member;
258                 case ACCOUNTING: return m_config.acct.*member;
259             }
260             throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
261         }
262
263         const Config::Pairs& match() const { return stageMember(&Config::Section::match); }
264         const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
265         const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
266         Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
267
268         static void process(void* data);
269         void processPing();
270         void processPong();
271         void processData();
272         bool answer(const USER& user);
273         bool answerNo();
274         bool sendPing();
275         bool sendPong();
276
277         static bool write(void* data, const char* buf, size_t size);
278 };
279
280 Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
281     : m_impl(new Impl(users, logger, config, fd, remote))
282 {
283 }
284
285 Conn::~Conn()
286 {
287 }
288
289 int Conn::sock() const
290 {
291     return m_impl->sock();
292 }
293
294 bool Conn::read()
295 {
296     return m_impl->read();
297 }
298
299 bool Conn::tick()
300 {
301     return m_impl->tick();
302 }
303
304 bool Conn::isOk() const
305 {
306     return m_impl->isOk();
307 }
308
309 Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
310     : m_users(users),
311       m_logger(logger),
312       m_config(config),
313       m_sock(fd),
314       m_remote(remote),
315       m_ok(true),
316       m_lastPing(time(NULL)),
317       m_lastActivity(m_lastPing),
318       m_parser(&Conn::Impl::process, this)
319 {
320 }
321
322 Conn::Impl::~Impl()
323 {
324     close(m_sock);
325 }
326
327 bool Conn::Impl::read()
328 {
329     static std::vector<char> buffer(1024);
330     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
331     if (res < 0)
332     {
333         m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
334         m_ok = false;
335         return false;
336     }
337     printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
338     m_lastActivity = time(NULL);
339     if (res == 0)
340     {
341         m_ok = false;
342         return true;
343     }
344     return m_parser.append(buffer.data(), res);
345 }
346
347 bool Conn::Impl::tick()
348 {
349     time_t now = time(NULL);
350     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
351     {
352         int delta = difftime(now, m_lastActivity);
353         printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta);
354         m_logger("Connection to " + m_remote + " timed out.");
355         m_ok = false;
356         return false;
357     }
358     if (difftime(now, m_lastPing) > PING_TIMEOUT)
359     {
360         int delta = difftime(now, m_lastPing);
361         printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta);
362         sendPing();
363     }
364     return true;
365 }
366
367 void Conn::Impl::process(void* data)
368 {
369     Impl& impl = *static_cast<Impl*>(data);
370     try
371     {
372         switch (impl.m_parser.packet())
373         {
374             case PING:
375                 impl.processPing();
376                 return;
377             case PONG:
378                 impl.processPong();
379                 return;
380             case DATA:
381                 impl.processData();
382                 return;
383         }
384     }
385     catch (const std::exception& ex)
386     {
387         printfd(__FILE__, "Processing error. %s", ex.what());
388         impl.m_logger("Processing error. %s", ex.what());
389     }
390     printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
391     impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
392 }
393
394 void Conn::Impl::processPing()
395 {
396     printfd(__FILE__, "Got ping. Sending pong...\n");
397     sendPong();
398 }
399
400 void Conn::Impl::processPong()
401 {
402     printfd(__FILE__, "Got pong.\n");
403     m_lastActivity = time(NULL);
404 }
405
406 void Conn::Impl::processData()
407 {
408     printfd(__FILE__, "Got data.\n");
409     int handle = m_users.OpenSearch();
410
411     USER_PTR user = NULL;
412     bool matched = false;
413     while (m_users.SearchNext(handle, &user) == 0)
414     {
415         if (user == NULL)
416             continue;
417
418         matched = true;
419         for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
420         {
421             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
422             if (pos == m_parser.data().end())
423             {
424                 matched = false;
425                 break;
426             }
427             if (user->GetParamValue(it->second) != pos->second)
428             {
429                 matched = false;
430                 break;
431             }
432         }
433         if (!matched)
434             continue;
435         answer(*user);
436         break;
437     }
438
439     if (!matched)
440         answerNo();
441
442     m_users.CloseSearch(handle);
443 }
444
445 bool Conn::Impl::answer(const USER& user)
446 {
447     printfd(__FILE__, "Got match. Sending answer...\n");
448     MapGen replyData;
449     for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
450         replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
451
452     MapGen modifyData;
453     for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
454         modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
455
456     PacketGen gen("data");
457     gen.add("result", "ok")
458        .add("reply", replyData)
459        .add("modify", modifyData);
460
461     m_lastPing = time(NULL);
462
463     return generate(gen, &Conn::Impl::write, this);
464 }
465
466 bool Conn::Impl::answerNo()
467 {
468     printfd(__FILE__, "No match. Sending answer...\n");
469     PacketGen gen("data");
470     gen.add("result", "no");
471     gen.add("return_code", toString(returnCode()));
472
473     m_lastPing = time(NULL);
474
475     return generate(gen, &Conn::Impl::write, this);
476 }
477
478 bool Conn::Impl::sendPing()
479 {
480     PacketGen gen("ping");
481
482     m_lastPing = time(NULL);
483
484     return generate(gen, &Conn::Impl::write, this);
485 }
486
487 bool Conn::Impl::sendPong()
488 {
489     PacketGen gen("pong");
490
491     m_lastPing = time(NULL);
492
493     return generate(gen, &Conn::Impl::write, this);
494 }
495
496 bool Conn::Impl::write(void* data, const char* buf, size_t size)
497 {
498     std::string json(buf, size);
499     printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str());
500     Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
501     while (size > 0)
502     {
503         ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL);
504         if (res < 0)
505         {
506             conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
507             conn.m_ok = false;
508             return false;
509         }
510         size -= res;
511     }
512     return true;
513 }