]> git.stg.codes - stg.git/blob - projects/stargazer/plugins/other/radius/conn.cpp
392624a6ef4846dd8d4c424baca068638e4ee4a2
[stg.git] / projects / stargazer / plugins / other / radius / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "config.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/users.h"
28 #include "stg/user.h"
29 #include "stg/logger.h"
30 #include "stg/common.h"
31
32 #include <yajl/yajl_gen.h>
33
34 #include <map>
35 #include <cstring>
36 #include <cerrno>
37
38 #include <unistd.h>
39
40 using STG::Conn;
41 using STG::Config;
42 using STG::JSON::Parser;
43 using STG::JSON::PairsParser;
44 using STG::JSON::EnumParser;
45 using STG::JSON::NodeParser;
46 using STG::JSON::Gen;
47 using STG::JSON::MapGen;
48 using STG::JSON::StringGen;
49
50 namespace
51 {
52
53 double CONN_TIMEOUT = 5;
54 double PING_TIMEOUT = 1;
55
56 enum Packet
57 {
58     PING,
59     PONG,
60     DATA
61 };
62
63 enum Stage
64 {
65     AUTHORIZE,
66     AUTHENTICATE,
67     PREACCT,
68     ACCOUNTING,
69     POSTAUTH
70 };
71
72 std::map<std::string, Packet> packetCodes;
73 std::map<std::string, Stage> stageCodes;
74
75 class PacketParser : public EnumParser<Packet>
76 {
77     public:
78         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
79             : EnumParser(next, packet, packetStr, packetCodes)
80         {
81             if (!packetCodes.empty())
82                 return;
83             packetCodes["ping"] = PING;
84             packetCodes["pong"] = PONG;
85             packetCodes["data"] = DATA;
86         }
87 };
88
89 class StageParser : public EnumParser<Stage>
90 {
91     public:
92         StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
93             : EnumParser(next, stage, stageStr, stageCodes)
94         {
95             if (!stageCodes.empty())
96                 return;
97             stageCodes["authorize"] = AUTHORIZE;
98             stageCodes["authenticate"] = AUTHENTICATE;
99             stageCodes["preacct"] = PREACCT;
100             stageCodes["accounting"] = ACCOUNTING;
101             stageCodes["postauth"] = POSTAUTH;
102         }
103 };
104
105 class TopParser : public NodeParser
106 {
107     public:
108         TopParser()
109             : m_packetParser(this, m_packet, m_packetStr),
110               m_stageParser(this, m_stage, m_stageStr),
111               m_pairsParser(this, m_data)
112         {}
113
114         virtual NodeParser* parseStartMap() { return this; }
115         virtual NodeParser* parseMapKey(const std::string& value)
116         {
117             std::string key = ToLower(value);
118
119             if (key == "packet")
120                 return &m_packetParser;
121             else if (key == "stage")
122                 return &m_stageParser;
123             else if (key == "pairs")
124                 return &m_pairsParser;
125
126             return this;
127         }
128         virtual NodeParser* parseEndMap() { return this; }
129
130         const std::string& packetStr() const { return m_packetStr; }
131         Packet packet() const { return m_packet; }
132         const std::string& stageStr() const { return m_stageStr; }
133         Stage stage() const { return m_stage; }
134         const Config::Pairs& data() const { return m_data; }
135
136     private:
137         std::string m_packetStr;
138         Packet m_packet;
139         std::string m_stageStr;
140         Stage m_stage;
141         Config::Pairs m_data;
142
143         PacketParser m_packetParser;
144         StageParser m_stageParser;
145         PairsParser m_pairsParser;
146 };
147
148 class ProtoParser : public Parser
149 {
150     public:
151         ProtoParser() : Parser( &m_topParser ) {}
152
153         const std::string& packetStr() const { return m_topParser.packetStr(); }
154         Packet packet() const { return m_topParser.packet(); }
155         const std::string& stageStr() const { return m_topParser.stageStr(); }
156         Stage stage() const { return m_topParser.stage(); }
157         const Config::Pairs& data() const { return m_topParser.data(); }
158
159     private:
160         TopParser m_topParser;
161 };
162
163 class PacketGen : public Gen
164 {
165     public:
166         PacketGen(const std::string& type)
167             : m_type(type)
168         {
169             m_gen.add("packet", m_type);
170         }
171         void run(yajl_gen_t* handle) const
172         {
173             m_gen.run(handle);
174         }
175         PacketGen& add(const std::string& key, const std::string& value)
176         {
177             m_gen.add(key, new StringGen(value));
178             return *this;
179         }
180         PacketGen& add(const std::string& key, MapGen* map)
181         {
182             m_gen.add(key, map);
183             return *this;
184         }
185     private:
186         MapGen m_gen;
187         StringGen m_type;
188 };
189
190 }
191
192 class Conn::Impl
193 {
194     public:
195         Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote);
196         ~Impl();
197
198         int sock() const { return m_sock; }
199
200         bool read();
201         bool tick();
202
203         bool isOk() const { return m_ok; }
204
205     private:
206         USERS& m_users;
207         PLUGIN_LOGGER& m_logger;
208         const Config& m_config;
209         int m_sock;
210         std::string m_remote;
211         bool m_ok;
212         time_t m_lastPing;
213         time_t m_lastActivity;
214         ProtoParser m_parser;
215
216         bool process();
217         bool processPing();
218         bool processPong();
219         bool processData();
220         bool answer(const USER& user);
221         bool answerNo();
222         bool sendPing();
223         bool sendPong();
224
225         static bool write(void* data, const char* buf, size_t size);
226 };
227
228 Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
229     : m_impl(new Impl(users, logger, config, fd, remote))
230 {
231 }
232
233 Conn::~Conn()
234 {
235 }
236
237 int Conn::sock() const
238 {
239     return m_impl->sock();
240 }
241
242 bool Conn::read()
243 {
244     return m_impl->read();
245 }
246
247 bool Conn::tick()
248 {
249     return m_impl->tick();
250 }
251
252 bool Conn::isOk() const
253 {
254     return m_impl->isOk();
255 }
256
257 Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, const Config& config, int fd, const std::string& remote)
258     : m_users(users),
259       m_logger(logger),
260       m_config(config),
261       m_sock(fd),
262       m_remote(remote),
263       m_ok(true),
264       m_lastPing(time(NULL)),
265       m_lastActivity(m_lastPing)
266 {
267 }
268
269 Conn::Impl::~Impl()
270 {
271     close(m_sock);
272 }
273
274 bool Conn::Impl::read()
275 {
276     static std::vector<char> buffer(1024);
277     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
278     if (res < 0)
279     {
280         m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
281         m_ok = false;
282         return false;
283     }
284     m_lastActivity = time(NULL);
285     if (res == 0)
286     {
287         if (!m_parser.done())
288         {
289             m_ok = false;
290             m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
291             return false;
292         }
293         return process();
294     }
295     return m_parser.append(buffer.data(), res);
296 }
297
298 bool Conn::Impl::tick()
299 {
300     time_t now = time(NULL);
301     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
302     {
303         m_logger("Connection to " + m_remote + " timed out.");
304         m_ok = false;
305         return false;
306     }
307     if (difftime(now, m_lastPing) > PING_TIMEOUT)
308         sendPing();
309     return true;
310 }
311
312 bool Conn::Impl::process()
313 {
314     switch (m_parser.packet())
315     {
316         case PING:
317             return processPing();
318         case PONG:
319             return processPong();
320         case DATA:
321             return processData();
322     }
323     m_logger("Received invalid packet type: " + m_parser.packetStr());
324     return false;
325 }
326
327 bool Conn::Impl::processPing()
328 {
329     return sendPong();
330 }
331
332 bool Conn::Impl::processPong()
333 {
334     m_lastActivity = time(NULL);
335     return true;
336 }
337
338 bool Conn::Impl::processData()
339 {
340     int handle = m_users.OpenSearch();
341
342     USER_PTR user = NULL;
343     bool match = true;
344     while (m_users.SearchNext(handle, &user))
345     {
346         if (user == NULL)
347             continue;
348
349         match = true;
350         for (Config::Pairs::const_iterator it = m_config.match.begin(); it != m_config.match.end(); ++it)
351         {
352             Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
353             if (pos == m_parser.data().end())
354             {
355                 match = false;
356                 break;
357             }
358             if (user->GetParamValue(it->second) != pos->second)
359             {
360                 match = false;
361                 break;
362             }
363         }
364         if (!match)
365             continue;
366         answer(*user);
367         break;
368     }
369
370     if (!match)
371         answerNo();
372
373     m_users.CloseSearch(handle);
374
375     return true;
376 }
377
378 bool Conn::Impl::answer(const USER& user)
379 {
380     boost::scoped_ptr<MapGen> reply(new MapGen);
381     for (Config::Pairs::const_iterator it = m_config.reply.begin(); it != m_config.reply.end(); ++it)
382         reply->add(it->first, new StringGen(user.GetParamValue(it->second)));
383
384     boost::scoped_ptr<MapGen> modify(new MapGen);
385     for (Config::Pairs::const_iterator it = m_config.modify.begin(); it != m_config.modify.end(); ++it)
386         modify->add(it->first, new StringGen(user.GetParamValue(it->second)));
387
388     PacketGen gen("data");
389     gen.add("result", "ok")
390        .add("reply", reply.get())
391        .add("modify", modify.get());
392
393     m_lastPing = time(NULL);
394
395     return generate(gen, &Conn::Impl::write, this);
396 }
397
398 bool Conn::Impl::answerNo()
399 {
400     PacketGen gen("data");
401     gen.add("result", "ok");
402
403     m_lastPing = time(NULL);
404
405     return generate(gen, &Conn::Impl::write, this);
406 }
407
408 bool Conn::Impl::sendPing()
409 {
410     PacketGen gen("ping");
411
412     m_lastPing = time(NULL);
413
414     return generate(gen, &Conn::Impl::write, this);
415 }
416
417 bool Conn::Impl::sendPong()
418 {
419     PacketGen gen("pong");
420
421     m_lastPing = time(NULL);
422
423     return generate(gen, &Conn::Impl::write, this);
424 }
425
426 bool Conn::Impl::write(void* data, const char* buf, size_t size)
427 {
428     Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
429     while (size > 0)
430     {
431         ssize_t res = ::write(conn.m_sock, buf, size);
432         if (res < 0)
433         {
434             conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
435             conn.m_ok = false;
436             return false;
437         }
438         size -= res;
439     }
440     return true;
441 }