]> git.stg.codes - stg.git/blob - projects/rlm_stg/conn.cpp
13a9cea3294285022aaf1a93f25e797997215b41
[stg.git] / projects / rlm_stg / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "radlog.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/locker.h"
28
29 #include <cerrno>
30 #include <cstring>
31
32 #include <sys/types.h>
33 #include <sys/socket.h>
34 #include <sys/un.h> // UNIX
35 #include <netinet/in.h> // IP
36 #include <netinet/tcp.h> // TCP
37 #include <netdb.h>
38
39 namespace RLM = STG::RLM;
40
41 using RLM::Conn;
42 using STG::JSON::Parser;
43 using STG::JSON::PairsParser;
44 using STG::JSON::EnumParser;
45 using STG::JSON::NodeParser;
46 using STG::JSON::Gen;
47 using STG::JSON::MapGen;
48 using STG::JSON::StringGen;
49
50 namespace
51 {
52
53 double CONN_TIMEOUT = 60;
54 double PING_TIMEOUT = 10;
55
56 struct ChannelConfig {
57     struct Error : std::runtime_error {
58         explicit Error(const std::string& message) : runtime_error(message) {}
59     };
60
61     explicit ChannelConfig(std::string address);
62
63     std::string transport;
64     std::string key;
65     std::string address;
66     std::string portStr;
67     uint16_t port;
68 };
69
70 std::string toStage(RLM::REQUEST_TYPE type)
71 {
72     switch (type)
73     {
74         case RLM::AUTHORIZE: return "authorize";
75         case RLM::AUTHENTICATE: return "authenticate";
76         case RLM::POST_AUTH: return "postauth";
77         case RLM::PRE_ACCT: return "preacct";
78         case RLM::ACCOUNT: return "accounting";
79     }
80     return "";
81 }
82
83 enum Packet
84 {
85     PING,
86     PONG,
87     DATA
88 };
89
90 std::map<std::string, Packet> packetCodes;
91 std::map<std::string, bool> resultCodes;
92
93 class PacketParser : public EnumParser<Packet>
94 {
95     public:
96         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
97             : EnumParser(next, packet, packetStr, packetCodes)
98         {
99             if (!packetCodes.empty())
100                 return;
101             packetCodes["ping"] = PING;
102             packetCodes["pong"] = PONG;
103             packetCodes["data"] = DATA;
104         }
105 };
106
107 class ResultParser : public EnumParser<bool>
108 {
109     public:
110         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
111             : EnumParser(next, result, resultStr, resultCodes)
112         {
113             if (!resultCodes.empty())
114                 return;
115             resultCodes["no"] = false;
116             resultCodes["ok"] = true;
117         }
118 };
119
120 class TopParser : public NodeParser
121 {
122     public:
123         typedef void (*Callback) (void* /*data*/);
124         TopParser(Callback callback, void* data)
125             : m_packet(PING),
126               m_result(false),
127               m_packetParser(this, m_packet, m_packetStr),
128               m_resultParser(this, m_result, m_resultStr),
129               m_replyParser(this, m_reply),
130               m_modifyParser(this, m_modify),
131               m_callback(callback), m_data(data)
132         {}
133
134         virtual NodeParser* parseStartMap() { return this; }
135         virtual NodeParser* parseMapKey(const std::string& value)
136         {
137             std::string key = ToLower(value);
138
139             if (key == "packet")
140                 return &m_packetParser;
141             else if (key == "result")
142                 return &m_resultParser;
143             else if (key == "reply")
144                 return &m_replyParser;
145             else if (key == "modify")
146                 return &m_modifyParser;
147
148             return this;
149         }
150         virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
151
152         const std::string& packetStr() const { return m_packetStr; }
153         Packet packet() const { return m_packet; }
154         const std::string& resultStr() const { return m_resultStr; }
155         bool result() const { return m_result; }
156         const PairsParser::Pairs& reply() const { return m_reply; }
157         const PairsParser::Pairs& modify() const { return m_modify; }
158
159     private:
160         std::string m_packetStr;
161         Packet m_packet;
162         std::string m_resultStr;
163         bool m_result;
164         PairsParser::Pairs m_reply;
165         PairsParser::Pairs m_modify;
166
167         PacketParser m_packetParser;
168         ResultParser m_resultParser;
169         PairsParser m_replyParser;
170         PairsParser m_modifyParser;
171
172         Callback m_callback;
173         void* m_data;
174 };
175
176 class ProtoParser : public Parser
177 {
178     public:
179         ProtoParser(TopParser::Callback callback, void* data)
180             : Parser( &m_topParser ),
181               m_topParser(callback, data)
182         {}
183
184         const std::string& packetStr() const { return m_topParser.packetStr(); }
185         Packet packet() const { return m_topParser.packet(); }
186         const std::string& resultStr() const { return m_topParser.resultStr(); }
187         bool result() const { return m_topParser.result(); }
188         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
189         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
190
191     private:
192         TopParser m_topParser;
193 };
194
195 class PacketGen : public Gen
196 {
197     public:
198         explicit PacketGen(const std::string& type)
199             : m_type(type)
200         {
201             m_gen.add("packet", m_type);
202         }
203         void run(yajl_gen_t* handle) const
204         {
205             m_gen.run(handle);
206         }
207         PacketGen& add(const std::string& key, const std::string& value)
208         {
209             m_gen.add(key, new StringGen(value));
210             return *this;
211         }
212         PacketGen& add(const std::string& key, MapGen& map)
213         {
214             m_gen.add(key, map);
215             return *this;
216         }
217     private:
218         MapGen m_gen;
219         StringGen m_type;
220 };
221
222 }
223
224 class Conn::Impl
225 {
226 public:
227     Impl(const std::string& address, Callback callback, void* data);
228     ~Impl();
229
230     bool stop();
231     bool connected() const { return m_connected; }
232
233     bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
234
235 private:
236     ChannelConfig m_config;
237
238     int m_sock;
239
240     bool m_running;
241     bool m_stopped;
242
243     time_t m_lastPing;
244     time_t m_lastActivity;
245
246     pthread_t m_thread;
247     pthread_mutex_t m_mutex;
248
249     Callback m_callback;
250     void* m_data;
251
252     ProtoParser m_parser;
253
254     bool m_connected;
255
256     void m_writeHeader(REQUEST_TYPE type, const std::string& userName, const std::string& password);
257     void m_writePairBlock(const PAIRS& source);
258     PAIRS m_readPairBlock();
259
260     static void* run(void* );
261
262     void runImpl();
263
264     bool start();
265
266     int connect();
267     int connectTCP();
268     int connectUNIX();
269
270     bool read();
271     bool tick();
272
273     static void process(void* data);
274     void processPing();
275     void processPong();
276     void processData();
277     bool sendPing();
278     bool sendPong();
279
280     static bool write(void* data, const char* buf, size_t size);
281 };
282
283 ChannelConfig::ChannelConfig(std::string addr)
284 {
285     // unix:pass@/var/run/stg.sock
286     // tcp:secret@192.168.0.1:12345
287     // udp:key@isp.com.ua:54321
288
289     size_t pos = addr.find_first_of(':');
290     if (pos == std::string::npos)
291         throw Error("Missing transport name.");
292     transport = ToLower(addr.substr(0, pos));
293     addr = addr.substr(pos + 1);
294     if (addr.empty())
295         throw Error("Missing address to connect to.");
296     pos = addr.find_first_of('@');
297     if (pos != std::string::npos) {
298         key = addr.substr(0, pos);
299         addr = addr.substr(pos + 1);
300         if (addr.empty())
301             throw Error("Missing address to connect to.");
302     }
303     if (transport == "unix")
304     {
305         address = addr;
306         return;
307     }
308     pos = addr.find_first_of(':');
309     if (pos == std::string::npos)
310         throw Error("Missing port.");
311     address = addr.substr(0, pos);
312     portStr = addr.substr(pos + 1);
313     if (str2x(portStr, port))
314         throw Error("Invalid port value.");
315 }
316
317 Conn::Conn(const std::string& address, Callback callback, void* data)
318     : m_impl(new Impl(address, callback, data))
319 {
320 }
321
322 Conn::~Conn()
323 {
324 }
325
326 bool Conn::stop()
327 {
328     return m_impl->stop();
329 }
330
331 bool Conn::connected() const
332 {
333     return m_impl->connected();
334 }
335
336 bool Conn::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
337 {
338     return m_impl->request(type, userName, password, pairs);
339 }
340
341 Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
342     : m_config(address),
343       m_sock(connect()),
344       m_running(false),
345       m_stopped(true),
346       m_lastPing(time(NULL)),
347       m_lastActivity(m_lastPing),
348       m_callback(callback),
349       m_data(data),
350       m_parser(&Conn::Impl::process, this),
351       m_connected(true)
352 {
353     pthread_mutex_init(&m_mutex, NULL);
354 }
355
356 Conn::Impl::~Impl()
357 {
358     stop();
359     shutdown(m_sock, SHUT_RDWR);
360     close(m_sock);
361     pthread_mutex_destroy(&m_mutex);
362 }
363
364 bool Conn::Impl::stop()
365 {
366     m_connected = false;
367
368     if (m_stopped)
369         return true;
370
371     m_running = false;
372
373     for (size_t i = 0; i < 25 && !m_stopped; i++) {
374         struct timespec ts = {0, 200000000};
375         nanosleep(&ts, NULL);
376     }
377
378     if (m_stopped) {
379         pthread_join(m_thread, NULL);
380         return true;
381     }
382
383     return false;
384 }
385
386 bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
387 {
388     if (!m_running)
389         if (!start())
390             return false;
391     MapGen map;
392     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
393         map.add(it->first, new StringGen(it->second));
394     map.add("Radius-Username", new StringGen(userName));
395     map.add("Radius-Userpass", new StringGen(password));
396
397     PacketGen gen("data");
398     gen.add("stage", toStage(type))
399        .add("pairs", map);
400
401     STG_LOCKER lock(m_mutex);
402
403     m_lastPing = time(NULL);
404
405     return generate(gen, &Conn::Impl::write, this);
406 }
407
408 void Conn::Impl::runImpl()
409 {
410     m_running = true;
411
412     while (m_running) {
413         fd_set fds;
414
415         FD_ZERO(&fds);
416         FD_SET(m_sock, &fds);
417
418         struct timeval tv;
419         tv.tv_sec = 0;
420         tv.tv_usec = 500000;
421
422         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
423         if (res < 0)
424         {
425             if (errno == EINTR)
426                 continue;
427             RadLog("'select' is failed: %s", strerror(errno));
428             break;
429         }
430
431
432         if (!m_running)
433             break;
434
435         STG_LOCKER lock(m_mutex);
436
437         if (res > 0)
438         {
439             if (FD_ISSET(m_sock, &fds))
440                 m_running = read();
441         }
442         else
443             m_running = tick();
444     }
445
446     m_connected = false;
447     m_stopped = true;
448 }
449
450 bool Conn::Impl::start()
451 {
452     int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
453     if (res != 0)
454         return false;
455     return true;
456 }
457
458 int Conn::Impl::connect()
459 {
460     if (m_config.transport == "tcp")
461         return connectTCP();
462     else if (m_config.transport == "unix")
463         return connectUNIX();
464     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
465 }
466
467 int Conn::Impl::connectTCP()
468 {
469     addrinfo hints;
470     memset(&hints, 0, sizeof(addrinfo));
471
472     hints.ai_family = AF_INET;       /* Allow IPv4 */
473     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
474     hints.ai_flags = 0;     /* For wildcard IP address */
475     hints.ai_protocol = 0;           /* Any protocol */
476     hints.ai_canonname = NULL;
477     hints.ai_addr = NULL;
478     hints.ai_next = NULL;
479
480     addrinfo* ais = NULL;
481     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
482     if (res != 0)
483         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
484
485     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
486     {
487         int fd = socket(AF_INET, SOCK_STREAM, 0);
488         if (fd == -1)
489         {
490             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
491             freeaddrinfo(ais);
492             throw error;
493         }
494         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
495         {
496             shutdown(fd, SHUT_RDWR);
497             close(fd);
498             RadLog("'connect' is failed: %s", strerror(errno));
499             continue;
500         }
501         freeaddrinfo(ais);
502         return fd;
503     }
504
505     freeaddrinfo(ais);
506
507     throw Error("Failed to resolve '" + m_config.address);
508 };
509
510 int Conn::Impl::connectUNIX()
511 {
512     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
513     if (fd == -1)
514         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
515     struct sockaddr_un addr;
516     memset(&addr, 0, sizeof(addr));
517     addr.sun_family = AF_UNIX;
518     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
519     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
520     {
521         Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
522         shutdown(fd, SHUT_RDWR);
523         close(fd);
524         throw error;
525     }
526     return fd;
527 }
528
529 bool Conn::Impl::read()
530 {
531     static std::vector<char> buffer(1024);
532     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
533     if (res < 0)
534     {
535         RadLog("Failed to read data: %s", strerror(errno));
536         return false;
537     }
538     m_lastActivity = time(NULL);
539     RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
540     if (res == 0)
541     {
542         m_parser.last();
543         return false;
544     }
545     return m_parser.append(buffer.data(), res);
546 }
547
548 bool Conn::Impl::tick()
549 {
550     time_t now = time(NULL);
551     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
552     {
553         int delta = difftime(now, m_lastActivity);
554         RadLog("Connection timeout: %d sec.", delta);
555         //m_logger("Connection to " + m_remote + " timed out.");
556         return false;
557     }
558     if (difftime(now, m_lastPing) > PING_TIMEOUT)
559     {
560         int delta = difftime(now, m_lastPing);
561         RadLog("Ping timeout: %d sec. Sending ping...", delta);
562         sendPing();
563     }
564     return true;
565 }
566
567 void Conn::Impl::process(void* data)
568 {
569     Impl& impl = *static_cast<Impl*>(data);
570     switch (impl.m_parser.packet())
571     {
572         case PING:
573             impl.processPing();
574             return;
575         case PONG:
576             impl.processPong();
577             return;
578         case DATA:
579             impl.processData();
580             return;
581     }
582     RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
583 }
584
585 void Conn::Impl::processPing()
586 {
587     sendPong();
588 }
589
590 void Conn::Impl::processPong()
591 {
592     m_lastActivity = time(NULL);
593 }
594
595 void Conn::Impl::processData()
596 {
597     RESULT data;
598     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
599         data.reply.push_back(std::make_pair(it->first, it->second));
600     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
601         data.modify.push_back(std::make_pair(it->first, it->second));
602     m_callback(m_data, data, m_parser.result());
603 }
604
605 bool Conn::Impl::sendPing()
606 {
607     PacketGen gen("ping");
608
609     m_lastPing = time(NULL);
610
611     return generate(gen, &Conn::Impl::write, this);
612 }
613
614 bool Conn::Impl::sendPong()
615 {
616     PacketGen gen("pong");
617
618     m_lastPing = time(NULL);
619
620     return generate(gen, &Conn::Impl::write, this);
621 }
622
623 bool Conn::Impl::write(void* data, const char* buf, size_t size)
624 {
625     std::string json(buf, size);
626     RadLog("Sending JSON: %s", json.c_str());
627     Conn::Impl& impl = *static_cast<Conn::Impl*>(data);
628     while (size > 0)
629     {
630         ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
631         if (res < 0)
632         {
633             impl.m_connected = false;
634             RadLog("Failed to write data: %s.", strerror(errno));
635             return false;
636         }
637         size -= res;
638     }
639     return true;
640 }
641
642 void* Conn::Impl::run(void* data)
643 {
644     Impl& impl = *static_cast<Impl*>(data);
645     impl.runImpl();
646     return NULL;
647 }