]> git.stg.codes - stg.git/blob - projects/stargazer/plugins/other/radius/conn.cpp
Handle EINTR in mod_radius correctly.
[stg.git] / projects / stargazer / plugins / other / radius / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "config.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/users.h"
28 #include "stg/user.h"
29 #include "stg/logger.h"
30 #include "stg/common.h"
31
32 #include <yajl/yajl_gen.h>
33
34 #include <map>
35 #include <cstring>
36 #include <cerrno>
37
38 #include <unistd.h>
39 #include <sys/types.h>
40 #include <sys/socket.h>
41
42 using STG::Conn;
43 using STG::Config;
44 using STG::JSON::Parser;
45 using STG::JSON::PairsParser;
46 using STG::JSON::EnumParser;
47 using STG::JSON::NodeParser;
48 using STG::JSON::Gen;
49 using STG::JSON::MapGen;
50 using STG::JSON::StringGen;
51
52 namespace
53 {
54
55 double CONN_TIMEOUT = 60;
56 double PING_TIMEOUT = 10;
57
58 enum Packet
59 {
60     PING,
61     PONG,
62     DATA
63 };
64
65 enum Stage
66 {
67     AUTHORIZE,
68     AUTHENTICATE,
69     PREACCT,
70     ACCOUNTING,
71     POSTAUTH
72 };
73
74 std::map<std::string, Packet> packetCodes;
75 std::map<std::string, Stage> stageCodes;
76
77 class PacketParser : public EnumParser<Packet>
78 {
79     public:
80         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
81             : EnumParser(next, packet, packetStr, packetCodes)
82         {
83             if (!packetCodes.empty())
84                 return;
85             packetCodes["ping"] = PING;
86             packetCodes["pong"] = PONG;
87             packetCodes["data"] = DATA;
88         }
89 };
90
91 class StageParser : public EnumParser<Stage>
92 {
93     public:
94         StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
95             : EnumParser(next, stage, stageStr, stageCodes)
96         {
97             if (!stageCodes.empty())
98                 return;
99             stageCodes["authorize"] = AUTHORIZE;
100             stageCodes["authenticate"] = AUTHENTICATE;
101             stageCodes["preacct"] = PREACCT;
102             stageCodes["accounting"] = ACCOUNTING;
103             stageCodes["postauth"] = POSTAUTH;
104         }
105 };
106
107 class TopParser : public NodeParser
108 {
109     public:
110         typedef void (*Callback) (void* /*data*/);
111         TopParser(Callback callback, void* data)
112             : m_packetParser(this, m_packet, m_packetStr),
113               m_stageParser(this, m_stage, m_stageStr),
114               m_pairsParser(this, m_data),
115               m_callback(callback), m_callbackData(data)
116         {}
117
118         virtual NodeParser* parseStartMap() { return this; }
119         virtual NodeParser* parseMapKey(const std::string& value)
120         {
121             std::string key = ToLower(value);
122
123             if (key == "packet")
124                 return &m_packetParser;
125             else if (key == "stage")
126                 return &m_stageParser;
127             else if (key == "pairs")
128                 return &m_pairsParser;
129
130             return this;
131         }
132         virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; }
133
134         const std::string& packetStr() const { return m_packetStr; }
135         Packet packet() const { return m_packet; }
136         const std::string& stageStr() const { return m_stageStr; }
137         Stage stage() const { return m_stage; }
138         const Config::Pairs& data() const { return m_data; }
139
140     private:
141         std::string m_packetStr;
142         Packet m_packet;
143         std::string m_stageStr;
144         Stage m_stage;
145         Config::Pairs m_data;
146
147         PacketParser m_packetParser;
148         StageParser m_stageParser;
149         PairsParser m_pairsParser;
150
151         Callback m_callback;
152         void* m_callbackData;
153 };
154
155 class ProtoParser : public Parser
156 {
157     public:
158         ProtoParser(TopParser::Callback callback, void* data)
159             : Parser( &m_topParser ),
160               m_topParser(callback, data)
161         {}
162
163         const std::string& packetStr() const { return m_topParser.packetStr(); }
164         Packet packet() const { return m_topParser.packet(); }
165         const std::string& stageStr() const { return m_topParser.stageStr(); }
166         Stage stage() const { return m_topParser.stage(); }
167         const Config::Pairs& data() const { return m_topParser.data(); }
168
169     private:
170         TopParser m_topParser;
171 };
172
173 class PacketGen : public Gen
174 {
175     public:
176         PacketGen(const std::string& type)
177             : m_type(type)
178         {
179             m_gen.add("packet", m_type);
180         }
181         void run(yajl_gen_t* handle) const
182         {
183             m_gen.run(handle);
184         }
185         PacketGen& add(const std::string& key, const std::string& value)
186         {
187             m_gen.add(key, new StringGen(value));
188             return *this;
189         }
190         PacketGen& add(const std::string& key, MapGen* map)
191         {
192             m_gen.add(key, map);
193             return *this;
194         }
195         PacketGen& add(const std::string& key, MapGen& map)
196         {
197             m_gen.add(key, map);
198             return *this;
199         }
200     private:
201         MapGen m_gen;
202         StringGen m_type;
203 };
204
205 }
206
207 class Conn::Impl
208 {
209     public:
210         Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
211         ~Impl();
212
213         int sock() const { return m_sock; }
214
215         bool read();
216         bool tick();
217
218         bool isOk() const { return m_ok; }
219
220     private:
221         USERS& m_users;
222         PLUGIN_LOGGER& m_logger;
223         const Config& m_config;
224         int m_sock;
225         std::string m_remote;
226         bool m_ok;
227         time_t m_lastPing;
228         time_t m_lastActivity;
229         ProtoParser m_parser;
230
231         static void process(void* data);
232         void processPing();
233         void processPong();
234         void processData();
235         bool answer(const USER& user);
236         bool answerNo();
237         bool sendPing();
238         bool sendPong();
239
240         static bool write(void* data, const char* buf, size_t size);
241 };
242
243 Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
244     : m_impl(new Impl(users, logger, config, fd, remote))
245 {
246 }
247
248 Conn::~Conn()
249 {
250 }
251
252 int Conn::sock() const
253 {
254     return m_impl->sock();
255 }
256
257 bool Conn::read()
258 {
259     return m_impl->read();
260 }
261
262 bool Conn::tick()
263 {
264     return m_impl->tick();
265 }
266
267 bool Conn::isOk() const
268 {
269     return m_impl->isOk();
270 }
271
272 Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
273     : m_users(users),
274       m_logger(logger),
275       m_config(config),
276       m_sock(fd),
277       m_remote(remote),
278       m_ok(true),
279       m_lastPing(time(NULL)),
280       m_lastActivity(m_lastPing),
281       m_parser(&Conn::Impl::process, this)
282 {
283 }
284
285 Conn::Impl::~Impl()
286 {
287     close(m_sock);
288 }
289
290 bool Conn::Impl::read()
291 {
292     static std::vector<char> buffer(1024);
293     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
294     if (res < 0)
295     {
296         m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
297         m_ok = false;
298         return false;
299     }
300     printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
301     m_lastActivity = time(NULL);
302     if (res == 0)
303     {
304         m_ok = false;
305         return true;
306     }
307     return m_parser.append(buffer.data(), res);
308 }
309
310 bool Conn::Impl::tick()
311 {
312     time_t now = time(NULL);
313     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
314     {
315         int delta = difftime(now, m_lastActivity);
316         printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta);
317         m_logger("Connection to " + m_remote + " timed out.");
318         m_ok = false;
319         return false;
320     }
321     if (difftime(now, m_lastPing) > PING_TIMEOUT)
322     {
323         int delta = difftime(now, m_lastPing);
324         printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta);
325         sendPing();
326     }
327     return true;
328 }
329
330 void Conn::Impl::process(void* data)
331 {
332     Impl& impl = *static_cast<Impl*>(data);
333     switch (impl.m_parser.packet())
334     {
335         case PING:
336             impl.processPing();
337             return;
338         case PONG:
339             impl.processPong();
340             return;
341         case DATA:
342             impl.processData();
343             return;
344     }
345     printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
346     impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
347 }
348
349 void Conn::Impl::processPing()
350 {
351     printfd(__FILE__, "Got ping. Sending pong...\n");
352     sendPong();
353 }
354
355 void Conn::Impl::processPong()
356 {
357     printfd(__FILE__, "Got pong.\n");
358     m_lastActivity = time(NULL);
359 }
360
361 void Conn::Impl::processData()
362 {
363     printfd(__FILE__, "Got data.\n");
364     int handle = m_users.OpenSearch();
365
366     USER_PTR user = NULL;
367     bool match = false;
368     while (m_users.SearchNext(handle, &user) == 0)
369     {
370         if (user == NULL)
371             continue;
372
373         match = true;
374         for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it)
375         {
376             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
377             if (pos == m_parser.data().end())
378             {
379                 match = false;
380                 break;
381             }
382             if (user->GetParamValue(it->second) != pos->second)
383             {
384                 match = false;
385                 break;
386             }
387         }
388         if (!match)
389             continue;
390         answer(*user);
391         break;
392     }
393
394     if (!match)
395         answerNo();
396
397     m_users.CloseSearch(handle);
398 }
399
400 bool Conn::Impl::answer(const USER& user)
401 {
402     printfd(__FILE__, "Got match. Sending answer...\n");
403     MapGen reply;
404     for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it)
405         reply.add(it->first, new StringGen(user.GetParamValue(it->second)));
406
407     MapGen modify;
408     for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it)
409         modify.add(it->first, new StringGen(user.GetParamValue(it->second)));
410
411     PacketGen gen("data");
412     gen.add("result", "ok")
413        .add("reply", reply)
414        .add("modify", modify);
415
416     m_lastPing = time(NULL);
417
418     return generate(gen, &Conn::Impl::write, this);
419 }
420
421 bool Conn::Impl::answerNo()
422 {
423     printfd(__FILE__, "No match. Sending answer...\n");
424     PacketGen gen("data");
425     gen.add("result", "no");
426
427     m_lastPing = time(NULL);
428
429     return generate(gen, &Conn::Impl::write, this);
430 }
431
432 bool Conn::Impl::sendPing()
433 {
434     PacketGen gen("ping");
435
436     m_lastPing = time(NULL);
437
438     return generate(gen, &Conn::Impl::write, this);
439 }
440
441 bool Conn::Impl::sendPong()
442 {
443     PacketGen gen("pong");
444
445     m_lastPing = time(NULL);
446
447     return generate(gen, &Conn::Impl::write, this);
448 }
449
450 bool Conn::Impl::write(void* data, const char* buf, size_t size)
451 {
452     std::string json(buf, size);
453     printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str());
454     Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
455     while (size > 0)
456     {
457         ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL);
458         if (res < 0)
459         {
460             conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
461             conn.m_ok = false;
462             return false;
463         }
464         size -= res;
465     }
466     return true;
467 }