]> git.stg.codes - stg.git/blob - rlm_stg/conn.cpp
3589a9012d3c7d0a5f93a9c074b3ff5236a33538
[stg.git] / rlm_stg / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "radlog.h"
24 #include "stgpair.h"
25
26 #include "stg/json_parser.h"
27 #include "stg/json_generator.h"
28 #include "stg/locker.h"
29
30 #include <cerrno>
31 #include <cstring>
32
33 #include <sys/types.h>
34 #include <sys/socket.h>
35 #include <sys/un.h> // UNIX
36 #include <netinet/in.h> // IP
37 #include <netinet/tcp.h> // TCP
38 #include <netdb.h>
39
40 namespace RLM = STG::RLM;
41
42 using RLM::Conn;
43 using STG::JSON::Parser;
44 using STG::JSON::PairsParser;
45 using STG::JSON::EnumParser;
46 using STG::JSON::NodeParser;
47 using STG::JSON::Gen;
48 using STG::JSON::MapGen;
49 using STG::JSON::StringGen;
50
51 namespace
52 {
53
54 double CONN_TIMEOUT = 60;
55 double PING_TIMEOUT = 10;
56
57 struct ChannelConfig {
58     struct Error : std::runtime_error {
59         explicit Error(const std::string& message) : runtime_error(message) {}
60     };
61
62     explicit ChannelConfig(std::string address);
63
64     std::string transport;
65     std::string key;
66     std::string address;
67     std::string portStr;
68     uint16_t port;
69 };
70
71 std::string toStage(RLM::REQUEST_TYPE type)
72 {
73     switch (type)
74     {
75         case RLM::AUTHORIZE: return "authorize";
76         case RLM::AUTHENTICATE: return "authenticate";
77         case RLM::POST_AUTH: return "postauth";
78         case RLM::PRE_ACCT: return "preacct";
79         case RLM::ACCOUNT: return "accounting";
80     }
81     return "";
82 }
83
84 enum Packet
85 {
86     PING,
87     PONG,
88     DATA
89 };
90
91 std::map<std::string, Packet> packetCodes;
92 std::map<std::string, bool> resultCodes;
93 std::map<std::string, int> returnCodes;
94
95 class PacketParser : public EnumParser<Packet>
96 {
97     public:
98         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
99             : EnumParser(next, packet, packetStr, packetCodes)
100         {
101             if (!packetCodes.empty())
102                 return;
103             packetCodes["ping"] = PING;
104             packetCodes["pong"] = PONG;
105             packetCodes["data"] = DATA;
106         }
107 };
108
109 class ResultParser : public EnumParser<bool>
110 {
111     public:
112         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
113             : EnumParser(next, result, resultStr, resultCodes)
114         {
115             if (!resultCodes.empty())
116                 return;
117             resultCodes["no"] = false;
118             resultCodes["ok"] = true;
119         }
120 };
121
122 class ReturnCodeParser : public EnumParser<int>
123 {
124     public:
125         ReturnCodeParser(NodeParser* next, int& returnCode, std::string& returnCodeStr)
126             : EnumParser(next, returnCode, returnCodeStr, returnCodes)
127         {
128             if (!returnCodes.empty())
129                 return;
130             returnCodes["reject"]   = STG_REJECT;
131             returnCodes["fail"]     = STG_FAIL;
132             returnCodes["ok"]       = STG_OK;
133             returnCodes["handled"]  = STG_HANDLED;
134             returnCodes["invalid"]  = STG_INVALID;
135             returnCodes["userlock"] = STG_USERLOCK;
136             returnCodes["notfound"] = STG_NOTFOUND;
137             returnCodes["noop"]     = STG_NOOP;
138             returnCodes["updated"]  = STG_UPDATED;
139         }
140 };
141
142 class TopParser : public NodeParser
143 {
144     public:
145         typedef void (*Callback) (void* /*data*/);
146         TopParser(Callback callback, void* data)
147             : m_packet(PING),
148               m_result(false),
149               m_returnCode(STG_REJECT),
150               m_packetParser(this, m_packet, m_packetStr),
151               m_resultParser(this, m_result, m_resultStr),
152               m_returnCodeParser(this, m_returnCode, m_returnCodeStr),
153               m_replyParser(this, m_reply),
154               m_modifyParser(this, m_modify),
155               m_callback(callback), m_data(data)
156         {}
157
158         virtual NodeParser* parseStartMap() { return this; }
159         virtual NodeParser* parseMapKey(const std::string& value)
160         {
161             std::string key = ToLower(value);
162
163             if (key == "packet")
164                 return &m_packetParser;
165             else if (key == "result")
166                 return &m_resultParser;
167             else if (key == "reply")
168                 return &m_replyParser;
169             else if (key == "modify")
170                 return &m_modifyParser;
171             else if (key == "return_code")
172                 return &m_returnCodeParser;
173
174             return this;
175         }
176         virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
177
178         const std::string& packetStr() const { return m_packetStr; }
179         Packet packet() const { return m_packet; }
180         const std::string& resultStr() const { return m_resultStr; }
181         bool result() const { return m_result; }
182         const std::string& returnCodeStr() const { return m_returnCodeStr; }
183         int returnCode() const { return m_returnCode; }
184         const PairsParser::Pairs& reply() const { return m_reply; }
185         const PairsParser::Pairs& modify() const { return m_modify; }
186
187     private:
188         std::string m_packetStr;
189         Packet m_packet;
190         std::string m_resultStr;
191         bool m_result;
192         std::string m_returnCodeStr;
193         int m_returnCode;
194         PairsParser::Pairs m_reply;
195         PairsParser::Pairs m_modify;
196
197         PacketParser m_packetParser;
198         ResultParser m_resultParser;
199         ReturnCodeParser m_returnCodeParser;
200         PairsParser m_replyParser;
201         PairsParser m_modifyParser;
202
203         Callback m_callback;
204         void* m_data;
205 };
206
207 class ProtoParser : public Parser
208 {
209     public:
210         ProtoParser(TopParser::Callback callback, void* data)
211             : Parser( &m_topParser ),
212               m_topParser(callback, data)
213         {}
214
215         const std::string& packetStr() const { return m_topParser.packetStr(); }
216         Packet packet() const { return m_topParser.packet(); }
217         const std::string& resultStr() const { return m_topParser.resultStr(); }
218         bool result() const { return m_topParser.result(); }
219         const std::string& returnCodeStr() const { return m_topParser.returnCodeStr(); }
220         int returnCode() const { return m_topParser.returnCode(); }
221         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
222         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
223
224     private:
225         TopParser m_topParser;
226 };
227
228 class PacketGen : public Gen
229 {
230     public:
231         explicit PacketGen(const std::string& type)
232             : m_type(type)
233         {
234             m_gen.add("packet", m_type);
235         }
236         void run(yajl_gen_t* handle) const
237         {
238             m_gen.run(handle);
239         }
240         PacketGen& add(const std::string& key, const std::string& value)
241         {
242             m_gen.add(key, new StringGen(value));
243             return *this;
244         }
245         PacketGen& add(const std::string& key, MapGen& map)
246         {
247             m_gen.add(key, map);
248             return *this;
249         }
250     private:
251         MapGen m_gen;
252         StringGen m_type;
253 };
254
255 }
256
257 class Conn::Impl
258 {
259 public:
260     Impl(const std::string& address, Callback callback, void* data);
261     ~Impl();
262
263     bool stop();
264     bool connected() const { return m_connected; }
265
266     bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
267
268 private:
269     ChannelConfig m_config;
270
271     int m_sock;
272
273     bool m_running;
274     bool m_stopped;
275
276     time_t m_lastPing;
277     time_t m_lastActivity;
278
279     pthread_t m_thread;
280     pthread_mutex_t m_mutex;
281
282     Callback m_callback;
283     void* m_data;
284
285     ProtoParser m_parser;
286
287     bool m_connected;
288
289     void m_writeHeader(REQUEST_TYPE type, const std::string& userName, const std::string& password);
290     void m_writePairBlock(const PAIRS& source);
291     PAIRS m_readPairBlock();
292
293     static void* run(void* );
294
295     void runImpl();
296
297     bool start();
298
299     int connect();
300     int connectTCP();
301     int connectUNIX();
302
303     bool read();
304     bool tick();
305
306     static void process(void* data);
307     void processPing();
308     void processPong();
309     void processData();
310     bool sendPing();
311     bool sendPong();
312
313     static bool write(void* data, const char* buf, size_t size);
314 };
315
316 ChannelConfig::ChannelConfig(std::string addr)
317 {
318     // unix:pass@/var/run/stg.sock
319     // tcp:secret@192.168.0.1:12345
320     // udp:key@isp.com.ua:54321
321
322     size_t pos = addr.find_first_of(':');
323     if (pos == std::string::npos)
324         throw Error("Missing transport name.");
325     transport = ToLower(addr.substr(0, pos));
326     addr = addr.substr(pos + 1);
327     if (addr.empty())
328         throw Error("Missing address to connect to.");
329     pos = addr.find_first_of('@');
330     if (pos != std::string::npos) {
331         key = addr.substr(0, pos);
332         addr = addr.substr(pos + 1);
333         if (addr.empty())
334             throw Error("Missing address to connect to.");
335     }
336     if (transport == "unix")
337     {
338         address = addr;
339         return;
340     }
341     pos = addr.find_first_of(':');
342     if (pos == std::string::npos)
343         throw Error("Missing port.");
344     address = addr.substr(0, pos);
345     portStr = addr.substr(pos + 1);
346     if (str2x(portStr, port))
347         throw Error("Invalid port value.");
348 }
349
350 Conn::Conn(const std::string& address, Callback callback, void* data)
351     : m_impl(new Impl(address, callback, data))
352 {
353 }
354
355 Conn::~Conn()
356 {
357 }
358
359 bool Conn::stop()
360 {
361     return m_impl->stop();
362 }
363
364 bool Conn::connected() const
365 {
366     return m_impl->connected();
367 }
368
369 bool Conn::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
370 {
371     return m_impl->request(type, userName, password, pairs);
372 }
373
374 Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
375     : m_config(address),
376       m_sock(connect()),
377       m_running(false),
378       m_stopped(true),
379       m_lastPing(time(NULL)),
380       m_lastActivity(m_lastPing),
381       m_callback(callback),
382       m_data(data),
383       m_parser(&Conn::Impl::process, this),
384       m_connected(true)
385 {
386     pthread_mutex_init(&m_mutex, NULL);
387 }
388
389 Conn::Impl::~Impl()
390 {
391     stop();
392     shutdown(m_sock, SHUT_RDWR);
393     close(m_sock);
394     pthread_mutex_destroy(&m_mutex);
395 }
396
397 bool Conn::Impl::stop()
398 {
399     m_connected = false;
400
401     if (m_stopped)
402         return true;
403
404     m_running = false;
405
406     for (size_t i = 0; i < 25 && !m_stopped; i++) {
407         struct timespec ts = {0, 200000000};
408         nanosleep(&ts, NULL);
409     }
410
411     if (m_stopped) {
412         pthread_join(m_thread, NULL);
413         return true;
414     }
415
416     return false;
417 }
418
419 bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
420 {
421     if (!m_running)
422         if (!start())
423             return false;
424     MapGen map;
425     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
426         map.add(it->first, new StringGen(it->second));
427     map.add("Radius-Username", new StringGen(userName));
428     map.add("Radius-Userpass", new StringGen(password));
429
430     PacketGen gen("data");
431     gen.add("stage", toStage(type))
432        .add("pairs", map);
433
434     STG_LOCKER lock(m_mutex);
435
436     m_lastPing = time(NULL);
437
438     return generate(gen, &Conn::Impl::write, this);
439 }
440
441 void Conn::Impl::runImpl()
442 {
443     m_running = true;
444
445     while (m_running) {
446         fd_set fds;
447
448         FD_ZERO(&fds);
449         FD_SET(m_sock, &fds);
450
451         struct timeval tv;
452         tv.tv_sec = 0;
453         tv.tv_usec = 500000;
454
455         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
456         if (res < 0)
457         {
458             if (errno == EINTR)
459                 continue;
460             RadLog("'select' is failed: %s", strerror(errno));
461             break;
462         }
463
464
465         if (!m_running)
466             break;
467
468         STG_LOCKER lock(m_mutex);
469
470         if (res > 0)
471         {
472             if (FD_ISSET(m_sock, &fds))
473                 m_running = read();
474         }
475         else
476             m_running = tick();
477     }
478
479     m_connected = false;
480     m_stopped = true;
481 }
482
483 bool Conn::Impl::start()
484 {
485     int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
486     if (res != 0)
487         return false;
488     return true;
489 }
490
491 int Conn::Impl::connect()
492 {
493     if (m_config.transport == "tcp")
494         return connectTCP();
495     else if (m_config.transport == "unix")
496         return connectUNIX();
497     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
498 }
499
500 int Conn::Impl::connectTCP()
501 {
502     addrinfo hints;
503     memset(&hints, 0, sizeof(addrinfo));
504
505     hints.ai_family = AF_INET;       /* Allow IPv4 */
506     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
507     hints.ai_flags = 0;     /* For wildcard IP address */
508     hints.ai_protocol = 0;           /* Any protocol */
509     hints.ai_canonname = NULL;
510     hints.ai_addr = NULL;
511     hints.ai_next = NULL;
512
513     addrinfo* ais = NULL;
514     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
515     if (res != 0)
516         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
517
518     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
519     {
520         int fd = socket(AF_INET, SOCK_STREAM, 0);
521         if (fd == -1)
522         {
523             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
524             freeaddrinfo(ais);
525             throw error;
526         }
527         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
528         {
529             shutdown(fd, SHUT_RDWR);
530             close(fd);
531             RadLog("'connect' is failed: %s", strerror(errno));
532             continue;
533         }
534         freeaddrinfo(ais);
535         return fd;
536     }
537
538     freeaddrinfo(ais);
539
540     throw Error("Failed to resolve '" + m_config.address);
541 };
542
543 int Conn::Impl::connectUNIX()
544 {
545     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
546     if (fd == -1)
547         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
548     struct sockaddr_un addr;
549     memset(&addr, 0, sizeof(addr));
550     addr.sun_family = AF_UNIX;
551     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
552     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
553     {
554         Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
555         shutdown(fd, SHUT_RDWR);
556         close(fd);
557         throw error;
558     }
559     return fd;
560 }
561
562 bool Conn::Impl::read()
563 {
564     static std::vector<char> buffer(1024);
565     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
566     if (res < 0)
567     {
568         RadLog("Failed to read data: %s", strerror(errno));
569         return false;
570     }
571     m_lastActivity = time(NULL);
572     RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
573     if (res == 0)
574     {
575         m_parser.last();
576         return false;
577     }
578     return m_parser.append(buffer.data(), res);
579 }
580
581 bool Conn::Impl::tick()
582 {
583     time_t now = time(NULL);
584     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
585     {
586         int delta = difftime(now, m_lastActivity);
587         RadLog("Connection timeout: %d sec.", delta);
588         //m_logger("Connection to " + m_remote + " timed out.");
589         return false;
590     }
591     if (difftime(now, m_lastPing) > PING_TIMEOUT)
592     {
593         int delta = difftime(now, m_lastPing);
594         RadLog("Ping timeout: %d sec. Sending ping...", delta);
595         sendPing();
596     }
597     return true;
598 }
599
600 void Conn::Impl::process(void* data)
601 {
602     Impl& impl = *static_cast<Impl*>(data);
603     switch (impl.m_parser.packet())
604     {
605         case PING:
606             impl.processPing();
607             return;
608         case PONG:
609             impl.processPong();
610             return;
611         case DATA:
612             impl.processData();
613             return;
614     }
615     RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
616 }
617
618 void Conn::Impl::processPing()
619 {
620     sendPong();
621 }
622
623 void Conn::Impl::processPong()
624 {
625     m_lastActivity = time(NULL);
626 }
627
628 void Conn::Impl::processData()
629 {
630     RESULT data;
631     if (m_parser.result())
632     {
633         for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
634             data.reply.push_back(std::make_pair(it->first, it->second));
635         for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
636             data.modify.push_back(std::make_pair(it->first, it->second));
637         data.returnCode = STG_UPDATED;
638     }
639     else
640         data.returnCode = m_parser.returnCode();
641     m_callback(m_data, data);
642 }
643
644 bool Conn::Impl::sendPing()
645 {
646     PacketGen gen("ping");
647
648     m_lastPing = time(NULL);
649
650     return generate(gen, &Conn::Impl::write, this);
651 }
652
653 bool Conn::Impl::sendPong()
654 {
655     PacketGen gen("pong");
656
657     m_lastPing = time(NULL);
658
659     return generate(gen, &Conn::Impl::write, this);
660 }
661
662 bool Conn::Impl::write(void* data, const char* buf, size_t size)
663 {
664     std::string json(buf, size);
665     RadLog("Sending JSON: %s", json.c_str());
666     Conn::Impl& impl = *static_cast<Conn::Impl*>(data);
667     while (size > 0)
668     {
669         ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
670         if (res < 0)
671         {
672             impl.m_connected = false;
673             RadLog("Failed to write data: %s.", strerror(errno));
674             return false;
675         }
676         size -= res;
677     }
678     return true;
679 }
680
681 void* Conn::Impl::run(void* data)
682 {
683     Impl& impl = *static_cast<Impl*>(data);
684     impl.runImpl();
685     return NULL;
686 }