]> git.stg.codes - stg.git/blob - projects/rlm_stg/conn.cpp
ed0b7a61a16d4c7ae590d351213d47cdc6594b4c
[stg.git] / projects / rlm_stg / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "radlog.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/locker.h"
28
29 #include <cerrno>
30 #include <cstring>
31
32 #include <sys/types.h>
33 #include <sys/socket.h>
34 #include <sys/un.h> // UNIX
35 #include <netinet/in.h> // IP
36 #include <netinet/tcp.h> // TCP
37 #include <netdb.h>
38
39 namespace RLM = STG::RLM;
40
41 using RLM::Conn;
42 using STG::JSON::Parser;
43 using STG::JSON::PairsParser;
44 using STG::JSON::EnumParser;
45 using STG::JSON::NodeParser;
46 using STG::JSON::Gen;
47 using STG::JSON::MapGen;
48 using STG::JSON::StringGen;
49
50 namespace
51 {
52
53 double CONN_TIMEOUT = 60;
54 double PING_TIMEOUT = 10;
55
56 struct ChannelConfig {
57     struct Error : std::runtime_error {
58         Error(const std::string& message) : runtime_error(message) {}
59     };
60
61     ChannelConfig(std::string address);
62
63     std::string transport;
64     std::string key;
65     std::string address;
66     std::string portStr;
67     uint16_t port;
68 };
69
70 std::string toStage(RLM::REQUEST_TYPE type)
71 {
72     switch (type)
73     {
74         case RLM::AUTHORIZE: return "authorize";
75         case RLM::AUTHENTICATE: return "authenticate";
76         case RLM::POST_AUTH: return "postauth";
77         case RLM::PRE_ACCT: return "preacct";
78         case RLM::ACCOUNT: return "accounting";
79     }
80     return "";
81 }
82
83 enum Packet
84 {
85     PING,
86     PONG,
87     DATA
88 };
89
90 std::map<std::string, Packet> packetCodes;
91 std::map<std::string, bool> resultCodes;
92
93 class PacketParser : public EnumParser<Packet>
94 {
95     public:
96         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
97             : EnumParser(next, packet, packetStr, packetCodes)
98         {
99             if (!packetCodes.empty())
100                 return;
101             packetCodes["ping"] = PING;
102             packetCodes["pong"] = PONG;
103             packetCodes["data"] = DATA;
104         }
105 };
106
107 class ResultParser : public EnumParser<bool>
108 {
109     public:
110         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
111             : EnumParser(next, result, resultStr, resultCodes)
112         {
113             if (!resultCodes.empty())
114                 return;
115             resultCodes["no"] = false;
116             resultCodes["ok"] = true;
117         }
118 };
119
120 class TopParser : public NodeParser
121 {
122     public:
123         typedef void (*Callback) (void* /*data*/);
124         TopParser(Callback callback, void* data)
125             : m_packet(PING),
126               m_result(false),
127               m_packetParser(this, m_packet, m_packetStr),
128               m_resultParser(this, m_result, m_resultStr),
129               m_replyParser(this, m_reply),
130               m_modifyParser(this, m_modify),
131               m_callback(callback), m_data(data)
132         {}
133
134         virtual NodeParser* parseStartMap() { return this; }
135         virtual NodeParser* parseMapKey(const std::string& value)
136         {
137             std::string key = ToLower(value);
138
139             if (key == "packet")
140                 return &m_packetParser;
141             else if (key == "result")
142                 return &m_resultParser;
143             else if (key == "reply")
144                 return &m_replyParser;
145             else if (key == "modify")
146                 return &m_modifyParser;
147
148             return this;
149         }
150         virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
151
152         const std::string& packetStr() const { return m_packetStr; }
153         Packet packet() const { return m_packet; }
154         const std::string& resultStr() const { return m_resultStr; }
155         bool result() const { return m_result; }
156         const PairsParser::Pairs& reply() const { return m_reply; }
157         const PairsParser::Pairs& modify() const { return m_modify; }
158
159     private:
160         std::string m_packetStr;
161         Packet m_packet;
162         std::string m_resultStr;
163         bool m_result;
164         PairsParser::Pairs m_reply;
165         PairsParser::Pairs m_modify;
166
167         PacketParser m_packetParser;
168         ResultParser m_resultParser;
169         PairsParser m_replyParser;
170         PairsParser m_modifyParser;
171
172         Callback m_callback;
173         void* m_data;
174 };
175
176 class ProtoParser : public Parser
177 {
178     public:
179         ProtoParser(TopParser::Callback callback, void* data)
180             : Parser( &m_topParser ),
181               m_topParser(callback, data)
182         {}
183
184         const std::string& packetStr() const { return m_topParser.packetStr(); }
185         Packet packet() const { return m_topParser.packet(); }
186         const std::string& resultStr() const { return m_topParser.resultStr(); }
187         bool result() const { return m_topParser.result(); }
188         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
189         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
190
191     private:
192         TopParser m_topParser;
193 };
194
195 class PacketGen : public Gen
196 {
197     public:
198         PacketGen(const std::string& type)
199             : m_type(type)
200         {
201             m_gen.add("packet", m_type);
202         }
203         void run(yajl_gen_t* handle) const
204         {
205             m_gen.run(handle);
206         }
207         PacketGen& add(const std::string& key, const std::string& value)
208         {
209             m_gen.add(key, new StringGen(value));
210             return *this;
211         }
212         PacketGen& add(const std::string& key, MapGen& map)
213         {
214             m_gen.add(key, map);
215             return *this;
216         }
217     private:
218         MapGen m_gen;
219         StringGen m_type;
220 };
221
222 }
223
224 class Conn::Impl
225 {
226 public:
227     Impl(const std::string& address, Callback callback, void* data);
228     ~Impl();
229
230     bool stop();
231     bool connected() const { return m_connected; }
232
233     bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
234
235 private:
236     ChannelConfig m_config;
237
238     int m_sock;
239
240     bool m_running;
241     bool m_stopped;
242
243     time_t m_lastPing;
244     time_t m_lastActivity;
245
246     pthread_t m_thread;
247     pthread_mutex_t m_mutex;
248
249     Callback m_callback;
250     void* m_data;
251
252     ProtoParser m_parser;
253
254     bool m_connected;
255
256     void m_writeHeader(REQUEST_TYPE type, const std::string& userName, const std::string& password);
257     void m_writePairBlock(const PAIRS& source);
258     PAIRS m_readPairBlock();
259
260     static void* run(void* );
261
262     void runImpl();
263
264     int connect();
265     int connectTCP();
266     int connectUNIX();
267
268     bool read();
269     bool tick();
270
271     static void process(void* data);
272     void processPing();
273     void processPong();
274     void processData();
275     bool sendPing();
276     bool sendPong();
277
278     static bool write(void* data, const char* buf, size_t size);
279 };
280
281 ChannelConfig::ChannelConfig(std::string addr)
282 {
283     // unix:pass@/var/run/stg.sock
284     // tcp:secret@192.168.0.1:12345
285     // udp:key@isp.com.ua:54321
286
287     size_t pos = addr.find_first_of(':');
288     if (pos == std::string::npos)
289         throw Error("Missing transport name.");
290     transport = ToLower(addr.substr(0, pos));
291     addr = addr.substr(pos + 1);
292     if (addr.empty())
293         throw Error("Missing address to connect to.");
294     pos = addr.find_first_of('@');
295     if (pos != std::string::npos) {
296         key = addr.substr(0, pos);
297         addr = addr.substr(pos + 1);
298         if (addr.empty())
299             throw Error("Missing address to connect to.");
300     }
301     if (transport == "unix")
302     {
303         address = addr;
304         return;
305     }
306     pos = addr.find_first_of(':');
307     if (pos == std::string::npos)
308         throw Error("Missing port.");
309     address = addr.substr(0, pos);
310     portStr = addr.substr(pos + 1);
311     if (str2x(portStr, port))
312         throw Error("Invalid port value.");
313 }
314
315 Conn::Conn(const std::string& address, Callback callback, void* data)
316     : m_impl(new Impl(address, callback, data))
317 {
318 }
319
320 Conn::~Conn()
321 {
322 }
323
324 bool Conn::stop()
325 {
326     return m_impl->stop();
327 }
328
329 bool Conn::connected() const
330 {
331     return m_impl->connected();
332 }
333
334 bool Conn::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
335 {
336     return m_impl->request(type, userName, password, pairs);
337 }
338
339 Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
340     : m_config(address),
341       m_sock(connect()),
342       m_running(false),
343       m_stopped(true),
344       m_lastPing(time(NULL)),
345       m_lastActivity(m_lastPing),
346       m_callback(callback),
347       m_data(data),
348       m_parser(&Conn::Impl::process, this),
349       m_connected(true)
350 {
351     pthread_mutex_init(&m_mutex, NULL);
352     int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
353     if (res != 0)
354         throw Error("Failed to create thread: " + std::string(strerror(errno)));
355 }
356
357 Conn::Impl::~Impl()
358 {
359     stop();
360     shutdown(m_sock, SHUT_RDWR);
361     close(m_sock);
362     pthread_mutex_destroy(&m_mutex);
363 }
364
365 bool Conn::Impl::stop()
366 {
367     m_connected = false;
368
369     if (m_stopped)
370         return true;
371
372     m_running = false;
373
374     for (size_t i = 0; i < 25 && !m_stopped; i++) {
375         struct timespec ts = {0, 200000000};
376         nanosleep(&ts, NULL);
377     }
378
379     if (m_stopped) {
380         pthread_join(m_thread, NULL);
381         return true;
382     }
383
384     return false;
385 }
386
387 bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
388 {
389     MapGen map;
390     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
391         map.add(it->first, new StringGen(it->second));
392     map.add("Radius-Username", new StringGen(userName));
393     map.add("Radius-Userpass", new StringGen(password));
394
395     PacketGen gen("data");
396     gen.add("stage", toStage(type))
397        .add("pairs", map);
398
399     STG_LOCKER lock(m_mutex);
400
401     m_lastPing = time(NULL);
402
403     return generate(gen, &Conn::Impl::write, this);
404 }
405
406 void Conn::Impl::runImpl()
407 {
408     m_running = true;
409
410     while (m_running) {
411         fd_set fds;
412
413         FD_ZERO(&fds);
414         FD_SET(m_sock, &fds);
415
416         struct timeval tv;
417         tv.tv_sec = 0;
418         tv.tv_usec = 500000;
419
420         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
421         if (res < 0)
422         {
423             if (errno == EINTR)
424                 continue;
425             RadLog("'select' is failed: %s", strerror(errno));
426             break;
427         }
428
429         if (!m_running)
430             break;
431
432         STG_LOCKER lock(m_mutex);
433
434         if (res > 0)
435         {
436             if (FD_ISSET(m_sock, &fds))
437                 m_running = read();
438         }
439         else
440             m_running = tick();
441     }
442
443     m_connected = false;
444     m_stopped = true;
445 }
446
447 int Conn::Impl::connect()
448 {
449     if (m_config.transport == "tcp")
450         return connectTCP();
451     else if (m_config.transport == "unix")
452         return connectUNIX();
453     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
454 }
455
456 int Conn::Impl::connectTCP()
457 {
458     addrinfo hints;
459     memset(&hints, 0, sizeof(addrinfo));
460
461     hints.ai_family = AF_INET;       /* Allow IPv4 */
462     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
463     hints.ai_flags = 0;     /* For wildcard IP address */
464     hints.ai_protocol = 0;           /* Any protocol */
465     hints.ai_canonname = NULL;
466     hints.ai_addr = NULL;
467     hints.ai_next = NULL;
468
469     addrinfo* ais = NULL;
470     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
471     if (res != 0)
472         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
473
474     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
475     {
476         int fd = socket(AF_INET, SOCK_STREAM, 0);
477         if (fd == -1)
478         {
479             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
480             freeaddrinfo(ais);
481             throw error;
482         }
483         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
484         {
485             shutdown(fd, SHUT_RDWR);
486             close(fd);
487             RadLog("'connect' is failed: %s", strerror(errno));
488             continue;
489         }
490         freeaddrinfo(ais);
491         return fd;
492     }
493
494     freeaddrinfo(ais);
495
496     throw Error("Failed to resolve '" + m_config.address);
497 };
498
499 int Conn::Impl::connectUNIX()
500 {
501     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
502     if (fd == -1)
503         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
504     struct sockaddr_un addr;
505     memset(&addr, 0, sizeof(addr));
506     addr.sun_family = AF_UNIX;
507     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
508     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
509     {
510         Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
511         shutdown(fd, SHUT_RDWR);
512         close(fd);
513         throw error;
514     }
515     return fd;
516 }
517
518 bool Conn::Impl::read()
519 {
520     static std::vector<char> buffer(1024);
521     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
522     if (res < 0)
523     {
524         RadLog("Failed to read data: %s", strerror(errno));
525         return false;
526     }
527     m_lastActivity = time(NULL);
528     RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
529     if (res == 0)
530     {
531         m_parser.last();
532         return false;
533     }
534     return m_parser.append(buffer.data(), res);
535 }
536
537 bool Conn::Impl::tick()
538 {
539     time_t now = time(NULL);
540     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
541     {
542         int delta = difftime(now, m_lastActivity);
543         RadLog("Connection timeout: %d sec.", delta);
544         //m_logger("Connection to " + m_remote + " timed out.");
545         return false;
546     }
547     if (difftime(now, m_lastPing) > PING_TIMEOUT)
548     {
549         int delta = difftime(now, m_lastPing);
550         RadLog("Ping timeout: %d sec. Sending ping...", delta);
551         sendPing();
552     }
553     return true;
554 }
555
556 void Conn::Impl::process(void* data)
557 {
558     Impl& impl = *static_cast<Impl*>(data);
559     switch (impl.m_parser.packet())
560     {
561         case PING:
562             impl.processPing();
563             return;
564         case PONG:
565             impl.processPong();
566             return;
567         case DATA:
568             impl.processData();
569             return;
570     }
571     RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
572 }
573
574 void Conn::Impl::processPing()
575 {
576     RadLog("Got ping, sending pong.");
577     sendPong();
578 }
579
580 void Conn::Impl::processPong()
581 {
582     RadLog("Got pong.");
583     m_lastActivity = time(NULL);
584 }
585
586 void Conn::Impl::processData()
587 {
588     RESULT data;
589     RadLog("Got data.");
590     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
591         data.reply.push_back(std::make_pair(it->first, it->second));
592     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
593         data.modify.push_back(std::make_pair(it->first, it->second));
594     m_callback(m_data, data, m_parser.result());
595 }
596
597 bool Conn::Impl::sendPing()
598 {
599     PacketGen gen("ping");
600
601     m_lastPing = time(NULL);
602
603     return generate(gen, &Conn::Impl::write, this);
604 }
605
606 bool Conn::Impl::sendPong()
607 {
608     PacketGen gen("pong");
609
610     m_lastPing = time(NULL);
611
612     return generate(gen, &Conn::Impl::write, this);
613 }
614
615 bool Conn::Impl::write(void* data, const char* buf, size_t size)
616 {
617     RadLog("Sending JSON:");
618     std::string json(buf, size);
619     RadLog("%s", json.c_str());
620     Conn::Impl& impl = *static_cast<Conn::Impl*>(data);
621     while (size > 0)
622     {
623         ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
624         if (res < 0)
625         {
626             impl.m_connected = false;
627             RadLog("Failed to write data: %s.", strerror(errno));
628             return false;
629         }
630         size -= res;
631     }
632     return true;
633 }
634
635 void* Conn::Impl::run(void* data)
636 {
637     Impl& impl = *static_cast<Impl*>(data);
638     impl.runImpl();
639     return NULL;
640 }