]> git.stg.codes - stg.git/blob - projects/rlm_stg/conn.cpp
Moved connection-related functions into a separate file.
[stg.git] / projects / rlm_stg / conn.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "conn.h"
22
23 #include "radlog.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/locker.h"
28
29 #include <cerrno>
30 #include <cstring>
31
32 #include <sys/types.h>
33 #include <sys/socket.h>
34 #include <sys/un.h> // UNIX
35 #include <netinet/in.h> // IP
36 #include <netinet/tcp.h> // TCP
37 #include <netdb.h>
38
39 namespace RLM = STG::RLM;
40
41 using RLM::Conn;
42 using STG::JSON::Parser;
43 using STG::JSON::PairsParser;
44 using STG::JSON::EnumParser;
45 using STG::JSON::NodeParser;
46 using STG::JSON::Gen;
47 using STG::JSON::MapGen;
48 using STG::JSON::StringGen;
49
50 namespace
51 {
52
53 double CONN_TIMEOUT = 60;
54 double PING_TIMEOUT = 10;
55
56 struct ChannelConfig {
57     struct Error : std::runtime_error {
58         Error(const std::string& message) : runtime_error(message) {}
59     };
60
61     ChannelConfig(std::string address);
62
63     std::string transport;
64     std::string key;
65     std::string address;
66     std::string portStr;
67     uint16_t port;
68 };
69
70 std::string toStage(RLM::REQUEST_TYPE type)
71 {
72     switch (type)
73     {
74         case RLM::AUTHORIZE: return "authorize";
75         case RLM::AUTHENTICATE: return "authenticate";
76         case RLM::POST_AUTH: return "postauth";
77         case RLM::PRE_ACCT: return "preacct";
78         case RLM::ACCOUNT: return "accounting";
79     }
80     return "";
81 }
82
83 enum Packet
84 {
85     PING,
86     PONG,
87     DATA
88 };
89
90 std::map<std::string, Packet> packetCodes;
91 std::map<std::string, bool> resultCodes;
92
93 class PacketParser : public EnumParser<Packet>
94 {
95     public:
96         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
97             : EnumParser(next, packet, packetStr, packetCodes)
98         {
99             if (!packetCodes.empty())
100                 return;
101             packetCodes["ping"] = PING;
102             packetCodes["pong"] = PONG;
103             packetCodes["data"] = DATA;
104         }
105 };
106
107 class ResultParser : public EnumParser<bool>
108 {
109     public:
110         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
111             : EnumParser(next, result, resultStr, resultCodes)
112         {
113             if (!resultCodes.empty())
114                 return;
115             resultCodes["no"] = false;
116             resultCodes["ok"] = true;
117         }
118 };
119
120 class TopParser : public NodeParser
121 {
122     public:
123         typedef void (*Callback) (void* /*data*/);
124         TopParser(Callback callback, void* data)
125             : m_packet(PING),
126               m_result(false),
127               m_packetParser(this, m_packet, m_packetStr),
128               m_resultParser(this, m_result, m_resultStr),
129               m_replyParser(this, m_reply),
130               m_modifyParser(this, m_modify),
131               m_callback(callback), m_data(data)
132         {}
133
134         virtual NodeParser* parseStartMap() { return this; }
135         virtual NodeParser* parseMapKey(const std::string& value)
136         {
137             std::string key = ToLower(value);
138
139             if (key == "packet")
140                 return &m_packetParser;
141             else if (key == "result")
142                 return &m_resultParser;
143             else if (key == "reply")
144                 return &m_replyParser;
145             else if (key == "modify")
146                 return &m_modifyParser;
147
148             return this;
149         }
150         virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
151
152         const std::string& packetStr() const { return m_packetStr; }
153         Packet packet() const { return m_packet; }
154         const std::string& resultStr() const { return m_resultStr; }
155         bool result() const { return m_result; }
156         const PairsParser::Pairs& reply() const { return m_reply; }
157         const PairsParser::Pairs& modify() const { return m_modify; }
158
159     private:
160         std::string m_packetStr;
161         Packet m_packet;
162         std::string m_resultStr;
163         bool m_result;
164         PairsParser::Pairs m_reply;
165         PairsParser::Pairs m_modify;
166
167         PacketParser m_packetParser;
168         ResultParser m_resultParser;
169         PairsParser m_replyParser;
170         PairsParser m_modifyParser;
171
172         Callback m_callback;
173         void* m_data;
174 };
175
176 class ProtoParser : public Parser
177 {
178     public:
179         ProtoParser(TopParser::Callback callback, void* data)
180             : Parser( &m_topParser ),
181               m_topParser(callback, data)
182         {}
183
184         const std::string& packetStr() const { return m_topParser.packetStr(); }
185         Packet packet() const { return m_topParser.packet(); }
186         const std::string& resultStr() const { return m_topParser.resultStr(); }
187         bool result() const { return m_topParser.result(); }
188         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
189         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
190
191     private:
192         TopParser m_topParser;
193 };
194
195 class PacketGen : public Gen
196 {
197     public:
198         PacketGen(const std::string& type)
199             : m_type(type)
200         {
201             m_gen.add("packet", m_type);
202         }
203         void run(yajl_gen_t* handle) const
204         {
205             m_gen.run(handle);
206         }
207         PacketGen& add(const std::string& key, const std::string& value)
208         {
209             m_gen.add(key, new StringGen(value));
210             return *this;
211         }
212         PacketGen& add(const std::string& key, MapGen& map)
213         {
214             m_gen.add(key, map);
215             return *this;
216         }
217     private:
218         MapGen m_gen;
219         StringGen m_type;
220 };
221
222 }
223
224 class Conn::Impl
225 {
226 public:
227     Impl(const std::string& address, Callback callback, void* data);
228     ~Impl();
229
230     bool stop();
231     bool connected() const { return m_connected; }
232
233     bool request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
234
235 private:
236     ChannelConfig m_config;
237
238     int m_sock;
239
240     bool m_running;
241     bool m_stopped;
242
243     time_t m_lastPing;
244     time_t m_lastActivity;
245
246     pthread_t m_thread;
247     pthread_mutex_t m_mutex;
248
249     Callback m_callback;
250     void* m_data;
251
252     ProtoParser m_parser;
253
254     bool m_connected;
255
256     void m_writeHeader(REQUEST_TYPE type, const std::string& userName, const std::string& password);
257     void m_writePairBlock(const PAIRS& source);
258     PAIRS m_readPairBlock();
259
260     static void* run(void* );
261
262     void runImpl();
263
264     int connect();
265     int connectTCP();
266     int connectUNIX();
267
268     bool read();
269     bool tick();
270
271     static void process(void* data);
272     void processPing();
273     void processPong();
274     void processData();
275     bool sendPing();
276     bool sendPong();
277
278     static bool write(void* data, const char* buf, size_t size);
279 };
280
281 ChannelConfig::ChannelConfig(std::string addr)
282 {
283     // unix:pass@/var/run/stg.sock
284     // tcp:secret@192.168.0.1:12345
285     // udp:key@isp.com.ua:54321
286
287     size_t pos = addr.find_first_of(':');
288     if (pos == std::string::npos)
289         throw Error("Missing transport name.");
290     transport = ToLower(addr.substr(0, pos));
291     addr = addr.substr(pos + 1);
292     if (addr.empty())
293         throw Error("Missing address to connect to.");
294     pos = addr.find_first_of('@');
295     if (pos != std::string::npos) {
296         key = addr.substr(0, pos);
297         addr = addr.substr(pos + 1);
298         if (addr.empty())
299             throw Error("Missing address to connect to.");
300     }
301     if (transport == "unix")
302     {
303         address = addr;
304         return;
305     }
306     pos = addr.find_first_of(':');
307     if (pos == std::string::npos)
308         throw Error("Missing port.");
309     address = addr.substr(0, pos);
310     portStr = addr.substr(pos + 1);
311     if (str2x(portStr, port))
312         throw Error("Invalid port value.");
313 }
314
315 Conn::Impl::Impl(const std::string& address, Callback callback, void* data)
316     : m_config(address),
317       m_sock(connect()),
318       m_running(false),
319       m_stopped(true),
320       m_lastPing(time(NULL)),
321       m_lastActivity(m_lastPing),
322       m_callback(callback),
323       m_data(data),
324       m_parser(&Conn::Impl::process, this),
325       m_connected(true)
326 {
327     pthread_mutex_init(&m_mutex, NULL);
328     int res = pthread_create(&m_thread, NULL, &Conn::Impl::run, this);
329     if (res != 0)
330         throw Error("Failed to create thread: " + std::string(strerror(errno)));
331 }
332
333 Conn::Impl::~Impl()
334 {
335     stop();
336     shutdown(m_sock, SHUT_RDWR);
337     close(m_sock);
338     pthread_mutex_destroy(&m_mutex);
339 }
340
341 bool Conn::Impl::stop()
342 {
343     m_connected = false;
344
345     if (m_stopped)
346         return true;
347
348     m_running = false;
349
350     for (size_t i = 0; i < 25 && !m_stopped; i++) {
351         struct timespec ts = {0, 200000000};
352         nanosleep(&ts, NULL);
353     }
354
355     if (m_stopped) {
356         pthread_join(m_thread, NULL);
357         return true;
358     }
359
360     return false;
361 }
362
363 bool Conn::Impl::request(REQUEST_TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
364 {
365     MapGen map;
366     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
367         map.add(it->first, new StringGen(it->second));
368     map.add("Radius-Username", new StringGen(userName));
369     map.add("Radius-Userpass", new StringGen(password));
370
371     PacketGen gen("data");
372     gen.add("stage", toStage(type))
373        .add("pairs", map);
374
375     STG_LOCKER lock(m_mutex);
376
377     m_lastPing = time(NULL);
378
379     return generate(gen, &Conn::Impl::write, this);
380 }
381
382 void Conn::Impl::runImpl()
383 {
384     m_running = true;
385
386     while (m_running) {
387         fd_set fds;
388
389         FD_ZERO(&fds);
390         FD_SET(m_sock, &fds);
391
392         struct timeval tv;
393         tv.tv_sec = 0;
394         tv.tv_usec = 500000;
395
396         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
397         if (res < 0)
398         {
399             if (errno == EINTR)
400                 continue;
401             RadLog("'select' is failed: %s", strerror(errno));
402             break;
403         }
404
405         if (!m_running)
406             break;
407
408         STG_LOCKER lock(m_mutex);
409
410         if (res > 0)
411         {
412             if (FD_ISSET(m_sock, &fds))
413                 m_running = read();
414         }
415         else
416             m_running = tick();
417     }
418
419     m_connected = false;
420     m_stopped = true;
421 }
422
423 int Conn::Impl::connect()
424 {
425     if (m_config.transport == "tcp")
426         return connectTCP();
427     else if (m_config.transport == "unix")
428         return connectUNIX();
429     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
430 }
431
432 int Conn::Impl::connectTCP()
433 {
434     addrinfo hints;
435     memset(&hints, 0, sizeof(addrinfo));
436
437     hints.ai_family = AF_INET;       /* Allow IPv4 */
438     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
439     hints.ai_flags = 0;     /* For wildcard IP address */
440     hints.ai_protocol = 0;           /* Any protocol */
441     hints.ai_canonname = NULL;
442     hints.ai_addr = NULL;
443     hints.ai_next = NULL;
444
445     addrinfo* ais = NULL;
446     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
447     if (res != 0)
448         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
449
450     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
451     {
452         int fd = socket(AF_INET, SOCK_STREAM, 0);
453         if (fd == -1)
454         {
455             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
456             freeaddrinfo(ais);
457             throw error;
458         }
459         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
460         {
461             shutdown(fd, SHUT_RDWR);
462             close(fd);
463             RadLog("'connect' is failed: %s", strerror(errno));
464             continue;
465         }
466         freeaddrinfo(ais);
467         return fd;
468     }
469
470     freeaddrinfo(ais);
471
472     throw Error("Failed to resolve '" + m_config.address);
473 };
474
475 int Conn::Impl::connectUNIX()
476 {
477     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
478     if (fd == -1)
479         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
480     struct sockaddr_un addr;
481     memset(&addr, 0, sizeof(addr));
482     addr.sun_family = AF_UNIX;
483     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
484     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
485     {
486         Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
487         shutdown(fd, SHUT_RDWR);
488         close(fd);
489         throw error;
490     }
491     return fd;
492 }
493
494 bool Conn::Impl::read()
495 {
496     static std::vector<char> buffer(1024);
497     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
498     if (res < 0)
499     {
500         RadLog("Failed to read data: %s", strerror(errno));
501         return false;
502     }
503     m_lastActivity = time(NULL);
504     RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
505     if (res == 0)
506     {
507         m_parser.last();
508         return false;
509     }
510     return m_parser.append(buffer.data(), res);
511 }
512
513 bool Conn::Impl::tick()
514 {
515     time_t now = time(NULL);
516     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
517     {
518         int delta = difftime(now, m_lastActivity);
519         RadLog("Connection timeout: %d sec.", delta);
520         //m_logger("Connection to " + m_remote + " timed out.");
521         return false;
522     }
523     if (difftime(now, m_lastPing) > PING_TIMEOUT)
524     {
525         int delta = difftime(now, m_lastPing);
526         RadLog("Ping timeout: %d sec. Sending ping...", delta);
527         sendPing();
528     }
529     return true;
530 }
531
532 void Conn::Impl::process(void* data)
533 {
534     Impl& impl = *static_cast<Impl*>(data);
535     switch (impl.m_parser.packet())
536     {
537         case PING:
538             impl.processPing();
539             return;
540         case PONG:
541             impl.processPong();
542             return;
543         case DATA:
544             impl.processData();
545             return;
546     }
547     RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
548 }
549
550 void Conn::Impl::processPing()
551 {
552     RadLog("Got ping, sending pong.");
553     sendPong();
554 }
555
556 void Conn::Impl::processPong()
557 {
558     RadLog("Got pong.");
559     m_lastActivity = time(NULL);
560 }
561
562 void Conn::Impl::processData()
563 {
564     RESULT data;
565     RadLog("Got data.");
566     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
567         data.reply.push_back(std::make_pair(it->first, it->second));
568     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
569         data.modify.push_back(std::make_pair(it->first, it->second));
570     m_callback(m_data, data, m_parser.result());
571 }
572
573 bool Conn::Impl::sendPing()
574 {
575     PacketGen gen("ping");
576
577     m_lastPing = time(NULL);
578
579     return generate(gen, &Conn::Impl::write, this);
580 }
581
582 bool Conn::Impl::sendPong()
583 {
584     PacketGen gen("pong");
585
586     m_lastPing = time(NULL);
587
588     return generate(gen, &Conn::Impl::write, this);
589 }
590
591 bool Conn::Impl::write(void* data, const char* buf, size_t size)
592 {
593     RadLog("Sending JSON:");
594     std::string json(buf, size);
595     RadLog("%s", json.c_str());
596     Conn::Impl& impl = *static_cast<Conn::Impl*>(data);
597     while (size > 0)
598     {
599         ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
600         if (res < 0)
601         {
602             impl.m_connected = false;
603             RadLog("Failed to write data: %s.", strerror(errno));
604             return false;
605         }
606         size -= res;
607     }
608     return true;
609 }
610
611 void* Conn::Impl::run(void* data)
612 {
613     Impl& impl = *static_cast<Impl*>(data);
614     impl.runImpl();
615     return NULL;
616 }