]> git.stg.codes - stg.git/blob - projects/rlm_stg/conn.cpp
3ee720d19f2fae0b904b75266db225c2cf9fb3d1
[stg.git] / projects / rlm_stg / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "radlog.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/locker.h"
28
29 #include <cerrno>
30 #include <cstring>
31
32 #include <sys/types.h>
33 #include <sys/socket.h>
34 #include <sys/un.h> // UNIX
35 #include <netinet/in.h> // IP
36 #include <netinet/tcp.h> // TCP
37 #include <netdb.h>
38
39 namespace RLM = STG::RLM;
40
41 using RLM::Conn;
42 using STG::JSON::Parser;
43 using STG::JSON::PairsParser;
44 using STG::JSON::EnumParser;
45 using STG::JSON::NodeParser;
46 using STG::JSON::Gen;
47 using STG::JSON::MapGen;
48 using STG::JSON::StringGen;
49
50 namespace
51 {
52
53 double CONN_TIMEOUT = 60;
54 double PING_TIMEOUT = 10;
55
56 struct ChannelConfig {
57     struct Error : std::runtime_error {
58         explicit Error(const std::string& message) : runtime_error(message) {}
59     };
60
61     explicit ChannelConfig(std::string address);
62
63     std::string transport;
64     std::string key;
65     std::string address;
66     std::string portStr;
67     uint16_t port;
68 };
69
70 std::string toStage(RLM::REQUEST_TYPE type)
71 {
72     switch (type)
73     {
74         case RLM::AUTHORIZE: return "authorize";
75         case RLM::AUTHENTICATE: return "authenticate";
76         case RLM::POST_AUTH: return "postauth";
77         case RLM::PRE_ACCT: return "preacct";
78         case RLM::ACCOUNT: return "accounting";
79     }
80     return "";
81 }
82
83 enum Packet
84 {
85     PING,
86     PONG,
87     DATA
88 };
89
90 std::map<std::string, Packet> packetCodes;
91 std::map<std::string, bool> resultCodes;
92
93 class PacketParser : public EnumParser<Packet>
94 {
95     public:
96         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
97             : EnumParser(next, packet, packetStr, packetCodes)
98         {
99             if (!packetCodes.empty())
100                 return;
101             packetCodes["ping"] = PING;
102             packetCodes["pong"] = PONG;
103             packetCodes["data"] = DATA;
104         }
105 };
106
107 class ResultParser : public EnumParser<bool>
108 {
109     public:
110         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
111             : EnumParser(next, result, resultStr, resultCodes)
112         {
113             if (!resultCodes.empty())
114                 return;
115             resultCodes["no"] = false;
116             resultCodes["ok"] = true;
117         }
118 };
119
120 class TopParser : public NodeParser
121 {
122     public:
123         typedef void (*Callback) (void* /*data*/);
124         TopParser(Callback callback, void* data)
125             : m_packet(PING),
126               m_result(false),
127               m_packetParser(this, m_packet, m_packetStr),
128               m_resultParser(this, m_result, m_resultStr),
129               m_replyParser(this, m_reply),
130               m_modifyParser(this, m_modify),
131               m_callback(callback), m_data(data)
132         {}
133
134         virtual NodeParser* parseStartMap() { return this; }
135         virtual NodeParser* parseMapKey(const std::string& value)
136         {
137             std::string key = ToLower(value);
138
139             if (key == "packet")
140                 return &m_packetParser;
141             else if (key == "result")
142                 return &m_resultParser;
143             else if (key == "reply")
144                 return &m_replyParser;
145             else if (key == "modify")
146                 return &m_modifyParser;
147
148             return this;
149         }
150         virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
151
152         const std::string& packetStr() const { return m_packetStr; }
153         Packet packet() const { return m_packet; }
154         const std::string& resultStr() const { return m_resultStr; }
155         bool result() const { return m_result; }
156         const PairsParser::Pairs& reply() const { return m_reply; }
157         const PairsParser::Pairs& modify() const { return m_modify; }
158
159     private:
160         std::string m_packetStr;
161         Packet m_packet;
162         std::string m_resultStr;
163         bool m_result;
164         PairsParser::Pairs m_reply;
165         PairsParser::Pairs m_modify;
166
167         PacketParser m_packetParser;
168         ResultParser m_resultParser;
169         PairsParser m_replyParser;
170         PairsParser m_modifyParser;
171
172         Callback m_callback;
173         void* m_data;
174 };
175
176 class ProtoParser : public Parser
177 {
178     public:
179         ProtoParser(TopParser::Callback callback, void* data)
180             : Parser( &m_topParser ),
181               m_topParser(callback, data)
182         {}
183
184         const std::string& packetStr() const { return m_topParser.packetStr(); }
185         Packet packet() const { return m_topParser.packet(); }
186         const std::string& resultStr() const { return m_topParser.resultStr(); }
187         bool result() const { return m_topParser.result(); }
188         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
189         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
190
191     private:
192         TopParser m_topParser;
193 };
194
195 class PacketGen : public Gen
196 {
197     public:
198         explicit PacketGen(const std::string& type)
199             : m_type(type)
200         {
201             m_gen.add("packet", m_type);
202         }
203         void run(yajl_gen_t* handle) const
204         {
205             m_gen.run(handle);
206         }
207         PacketGen& add(const std::string& key, const std::string& value)
208         {
209             m_gen.add(key, new StringGen(value));
210             return *this;
211         }
212         PacketGen& add(const std::string& key, MapGen& map)
213         {
214             m_gen.add(key, map);
215             return *this;
216         }
217     private:
218         MapGen m_gen;
219         StringGen m_type;
220 };
221
222 }
223
224 class Conn::Impl
225 {
226 public:
227     Impl(const std::string& address, Callback callback, void* data);
228     ~Impl();
229
230     bool stop();
231     bool connected() const { return m_connected; }
232
233     bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
234
235 private:
236     ChannelConfig m_config;
237
238     int m_sock;
239
240     bool m_running;
241     bool m_stopped;
242
243     time_t m_lastPing;
244     time_t m_lastActivity;
245
246     pthread_t m_thread;
247     pthread_mutex_t m_mutex;
248
249     Callback m_callback;
250     void* m_data;
251
252     ProtoParser m_parser;
253
254     bool m_connected;
255
256     void m_writeHeader(REQUEST_TYPE type, const std::string& userName, const std::string& password);
257     void m_writePairBlock(const PAIRS& source);
258     PAIRS m_readPairBlock();
259
260     static void* run(void* );
261
262     void runImpl();
263
264     int connect();
265     int connectTCP();
266     int connectUNIX();
267
268     bool read();
269     bool tick();
270
271     static void process(void* data);
272     void processPing();
273     void processPong();
274     void processData();
275     bool sendPing();
276     bool sendPong();
277
278     static bool write(void* data, const char* buf, size_t size);
279 };
280
281 ChannelConfig::ChannelConfig(std::string addr)
282 {
283     // unix:pass@/var/run/stg.sock
284     // tcp:secret@192.168.0.1:12345
285     // udp:key@isp.com.ua:54321
286
287     size_t pos = addr.find_first_of(':');
288     if (pos == std::string::npos)
289         throw Error("Missing transport name.");
290     transport = ToLower(addr.substr(0, pos));
291     addr = addr.substr(pos + 1);
292     if (addr.empty())
293         throw Error("Missing address to connect to.");
294     pos = addr.find_first_of('@');
295     if (pos != std::string::npos) {
296         key = addr.substr(0, pos);
297         addr = addr.substr(pos + 1);
298         if (addr.empty())
299             throw Error("Missing address to connect to.");
300     }
301     if (transport == "unix")
302     {
303         address = addr;
304         return;
305     }
306     pos = addr.find_first_of(':');
307     if (pos == std::string::npos)
308         throw Error("Missing port.");
309     address = addr.substr(0, pos);
310     portStr = addr.substr(pos + 1);
311     if (str2x(portStr, port))
312         throw Error("Invalid port value.");
313 }
314
315 Conn::Conn(const std::string& address, Callback callback, void* data)
316     : m_impl(new Impl(address, callback, data))
317 {
318 }
319
320 Conn::~Conn()
321 {
322 }
323
324 bool Conn::stop()
325 {
326     return m_impl->stop();
327 }
328
329 bool Conn::connected() const
330 {
331     return m_impl->connected();
332 }
333
334 bool Conn::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
335 {
336     return m_impl->request(type, userName, password, pairs);
337 }
338
339 Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
340     : m_config(address),
341       m_sock(connect()),
342       m_running(false),
343       m_stopped(true),
344       m_lastPing(time(NULL)),
345       m_lastActivity(m_lastPing),
346       m_callback(callback),
347       m_data(data),
348       m_parser(&Conn::Impl::process, this),
349       m_connected(true)
350 {
351     RadLog("Created connection.");
352     pthread_mutex_init(&m_mutex, NULL);
353     int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
354     if (res != 0)
355         throw Error("Failed to create thread: " + std::string(strerror(errno)));
356 }
357
358 Conn::Impl::~Impl()
359 {
360     stop();
361     shutdown(m_sock, SHUT_RDWR);
362     close(m_sock);
363     pthread_mutex_destroy(&m_mutex);
364     RadLog("Deleted connection.");
365 }
366
367 bool Conn::Impl::stop()
368 {
369     m_connected = false;
370
371     if (m_stopped)
372         return true;
373
374     m_running = false;
375
376     for (size_t i = 0; i < 25 && !m_stopped; i++) {
377         struct timespec ts = {0, 200000000};
378         nanosleep(&ts, NULL);
379     }
380
381     if (m_stopped) {
382         pthread_join(m_thread, NULL);
383         return true;
384     }
385
386     return false;
387 }
388
389 bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
390 {
391     MapGen map;
392     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
393         map.add(it->first, new StringGen(it->second));
394     map.add("Radius-Username", new StringGen(userName));
395     map.add("Radius-Userpass", new StringGen(password));
396
397     PacketGen gen("data");
398     gen.add("stage", toStage(type))
399        .add("pairs", map);
400
401     STG_LOCKER lock(m_mutex);
402
403     m_lastPing = time(NULL);
404
405     return generate(gen, &Conn::Impl::write, this);
406 }
407
408 void Conn::Impl::runImpl()
409 {
410     m_running = true;
411
412     RadLog("Run connection.");
413
414     while (m_running) {
415         fd_set fds;
416
417         FD_ZERO(&fds);
418         FD_SET(m_sock, &fds);
419
420         struct timeval tv;
421         tv.tv_sec = 0;
422         tv.tv_usec = 500000;
423
424         RadLog("Starting 'select'.");
425         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
426         RadLog("'select' result: %d.", res);
427         if (res < 0)
428         {
429             if (errno == EINTR)
430                 continue;
431             RadLog("'select' is failed: %s", strerror(errno));
432             break;
433         }
434
435
436         if (!m_running)
437             break;
438
439         STG_LOCKER lock(m_mutex);
440
441         if (res > 0)
442         {
443             RadLog("Got %d fds.", res);
444             if (FD_ISSET(m_sock, &fds))
445                 m_running = read();
446             RadLog("Read complete.");
447         }
448         else
449             m_running = tick();
450     }
451
452     RadLog("End running connection.");
453
454     m_connected = false;
455     m_stopped = true;
456 }
457
458 int Conn::Impl::connect()
459 {
460     if (m_config.transport == "tcp")
461         return connectTCP();
462     else if (m_config.transport == "unix")
463         return connectUNIX();
464     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
465 }
466
467 int Conn::Impl::connectTCP()
468 {
469     addrinfo hints;
470     memset(&hints, 0, sizeof(addrinfo));
471
472     hints.ai_family = AF_INET;       /* Allow IPv4 */
473     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
474     hints.ai_flags = 0;     /* For wildcard IP address */
475     hints.ai_protocol = 0;           /* Any protocol */
476     hints.ai_canonname = NULL;
477     hints.ai_addr = NULL;
478     hints.ai_next = NULL;
479
480     addrinfo* ais = NULL;
481     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
482     if (res != 0)
483         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
484
485     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
486     {
487         int fd = socket(AF_INET, SOCK_STREAM, 0);
488         if (fd == -1)
489         {
490             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
491             freeaddrinfo(ais);
492             throw error;
493         }
494         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
495         {
496             shutdown(fd, SHUT_RDWR);
497             close(fd);
498             RadLog("'connect' is failed: %s", strerror(errno));
499             continue;
500         }
501         freeaddrinfo(ais);
502         return fd;
503     }
504
505     freeaddrinfo(ais);
506
507     throw Error("Failed to resolve '" + m_config.address);
508 };
509
510 int Conn::Impl::connectUNIX()
511 {
512     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
513     if (fd == -1)
514         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
515     struct sockaddr_un addr;
516     memset(&addr, 0, sizeof(addr));
517     addr.sun_family = AF_UNIX;
518     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
519     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
520     {
521         Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
522         shutdown(fd, SHUT_RDWR);
523         close(fd);
524         throw error;
525     }
526     return fd;
527 }
528
529 bool Conn::Impl::read()
530 {
531     static std::vector<char> buffer(1024);
532     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
533     if (res < 0)
534     {
535         RadLog("Failed to read data: %s", strerror(errno));
536         return false;
537     }
538     m_lastActivity = time(NULL);
539     RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
540     if (res == 0)
541     {
542         m_parser.last();
543         return false;
544     }
545     return m_parser.append(buffer.data(), res);
546 }
547
548 bool Conn::Impl::tick()
549 {
550     time_t now = time(NULL);
551     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
552     {
553         int delta = difftime(now, m_lastActivity);
554         RadLog("Connection timeout: %d sec.", delta);
555         //m_logger("Connection to " + m_remote + " timed out.");
556         return false;
557     }
558     if (difftime(now, m_lastPing) > PING_TIMEOUT)
559     {
560         int delta = difftime(now, m_lastPing);
561         RadLog("Ping timeout: %d sec. Sending ping...", delta);
562         sendPing();
563     }
564     return true;
565 }
566
567 void Conn::Impl::process(void* data)
568 {
569     Impl& impl = *static_cast<Impl*>(data);
570     switch (impl.m_parser.packet())
571     {
572         case PING:
573             impl.processPing();
574             return;
575         case PONG:
576             impl.processPong();
577             return;
578         case DATA:
579             impl.processData();
580             return;
581     }
582     RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
583 }
584
585 void Conn::Impl::processPing()
586 {
587     RadLog("Got ping, sending pong.");
588     sendPong();
589 }
590
591 void Conn::Impl::processPong()
592 {
593     RadLog("Got pong.");
594     m_lastActivity = time(NULL);
595 }
596
597 void Conn::Impl::processData()
598 {
599     RESULT data;
600     RadLog("Got data.");
601     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
602         data.reply.push_back(std::make_pair(it->first, it->second));
603     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
604         data.modify.push_back(std::make_pair(it->first, it->second));
605     m_callback(m_data, data, m_parser.result());
606 }
607
608 bool Conn::Impl::sendPing()
609 {
610     PacketGen gen("ping");
611
612     m_lastPing = time(NULL);
613
614     return generate(gen, &Conn::Impl::write, this);
615 }
616
617 bool Conn::Impl::sendPong()
618 {
619     PacketGen gen("pong");
620
621     m_lastPing = time(NULL);
622
623     return generate(gen, &Conn::Impl::write, this);
624 }
625
626 bool Conn::Impl::write(void* data, const char* buf, size_t size)
627 {
628     RadLog("Sending JSON:");
629     std::string json(buf, size);
630     RadLog("%s", json.c_str());
631     Conn::Impl& impl = *static_cast<Conn::Impl*>(data);
632     while (size > 0)
633     {
634         ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
635         if (res < 0)
636         {
637             impl.m_connected = false;
638             RadLog("Failed to write data: %s.", strerror(errno));
639             return false;
640         }
641         RadLog("Send %d bytes.", res);
642         size -= res;
643     }
644     return true;
645 }
646
647 void* Conn::Impl::run(void* data)
648 {
649     Impl& impl = *static_cast<Impl*>(data);
650     impl.runImpl();
651     return NULL;
652 }