]> git.stg.codes - stg.git/blob - projects/rlm_stg/stg_client.cpp
Moved connection-related functions into a separate file.
[stg.git] / projects / rlm_stg / stg_client.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "stg_client.h"
22
23 #include "radlog.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/common.h"
28
29 #include <map>
30 #include <utility>
31 #include <cerrno>
32 #include <cstring>
33
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 #include <sys/un.h> // UNIX
37 #include <netinet/in.h> // IP
38 #include <netinet/tcp.h> // TCP
39 #include <netdb.h>
40
41 using STG::JSON::Parser;
42 using STG::JSON::PairsParser;
43 using STG::JSON::EnumParser;
44 using STG::JSON::NodeParser;
45 using STG::JSON::Gen;
46 using STG::JSON::MapGen;
47 using STG::JSON::StringGen;
48
49 namespace {
50
51 double CONN_TIMEOUT = 60;
52 double PING_TIMEOUT = 10;
53
54 STG_CLIENT* stgClient = NULL;
55
56 std::string toStage(STG_CLIENT::TYPE type)
57 {
58     switch (type)
59     {
60         case STG_CLIENT::AUTHORIZE: return "authorize";
61         case STG_CLIENT::AUTHENTICATE: return "authenticate";
62         case STG_CLIENT::POST_AUTH: return "postauth";
63         case STG_CLIENT::PRE_ACCT: return "preacct";
64         case STG_CLIENT::ACCOUNT: return "accounting";
65     }
66     return "";
67 }
68
69 enum Packet
70 {
71     PING,
72     PONG,
73     DATA
74 };
75
76 std::map<std::string, Packet> packetCodes;
77 std::map<std::string, bool> resultCodes;
78
79 class PacketParser : public EnumParser<Packet>
80 {
81     public:
82         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
83             : EnumParser(next, packet, packetStr, packetCodes)
84         {
85             if (!packetCodes.empty())
86                 return;
87             packetCodes["ping"] = PING;
88             packetCodes["pong"] = PONG;
89             packetCodes["data"] = DATA;
90         }
91 };
92
93 class ResultParser : public EnumParser<bool>
94 {
95     public:
96         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
97             : EnumParser(next, result, resultStr, resultCodes)
98         {
99             if (!resultCodes.empty())
100                 return;
101             resultCodes["no"] = false;
102             resultCodes["ok"] = true;
103         }
104 };
105
106 class TopParser : public NodeParser
107 {
108     public:
109         typedef void (*Callback) (void* /*data*/);
110         TopParser(Callback callback, void* data)
111             : m_packetParser(this, m_packet, m_packetStr),
112               m_resultParser(this, m_result, m_resultStr),
113               m_replyParser(this, m_reply),
114               m_modifyParser(this, m_modify),
115               m_callback(callback), m_data(data)
116         {}
117
118         virtual NodeParser* parseStartMap() { return this; }
119         virtual NodeParser* parseMapKey(const std::string& value)
120         {
121             std::string key = ToLower(value);
122
123             if (key == "packet")
124                 return &m_packetParser;
125             else if (key == "result")
126                 return &m_resultParser;
127             else if (key == "reply")
128                 return &m_replyParser;
129             else if (key == "modify")
130                 return &m_modifyParser;
131
132             return this;
133         }
134         virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
135
136         const std::string& packetStr() const { return m_packetStr; }
137         Packet packet() const { return m_packet; }
138         const std::string& resultStr() const { return m_resultStr; }
139         bool result() const { return m_result; }
140         const PairsParser::Pairs& reply() const { return m_reply; }
141         const PairsParser::Pairs& modify() const { return m_modify; }
142
143     private:
144         std::string m_packetStr;
145         Packet m_packet;
146         std::string m_resultStr;
147         bool m_result;
148         PairsParser::Pairs m_reply;
149         PairsParser::Pairs m_modify;
150
151         PacketParser m_packetParser;
152         ResultParser m_resultParser;
153         PairsParser m_replyParser;
154         PairsParser m_modifyParser;
155
156         Callback m_callback;
157         void* m_data;
158 };
159
160 class ProtoParser : public Parser
161 {
162     public:
163         ProtoParser(TopParser::Callback callback, void* data)
164             : Parser( &m_topParser ),
165               m_topParser(callback, data)
166         {}
167
168         const std::string& packetStr() const { return m_topParser.packetStr(); }
169         Packet packet() const { return m_topParser.packet(); }
170         const std::string& resultStr() const { return m_topParser.resultStr(); }
171         bool result() const { return m_topParser.result(); }
172         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
173         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
174
175     private:
176         TopParser m_topParser;
177 };
178
179 class PacketGen : public Gen
180 {
181     public:
182         PacketGen(const std::string& type)
183             : m_type(type)
184         {
185             m_gen.add("packet", m_type);
186         }
187         void run(yajl_gen_t* handle) const
188         {
189             m_gen.run(handle);
190         }
191         PacketGen& add(const std::string& key, const std::string& value)
192         {
193             m_gen.add(key, new StringGen(value));
194             return *this;
195         }
196         PacketGen& add(const std::string& key, MapGen& map)
197         {
198             m_gen.add(key, map);
199             return *this;
200         }
201     private:
202         MapGen m_gen;
203         StringGen m_type;
204 };
205
206 }
207
208 class STG_CLIENT::Impl
209 {
210 public:
211     Impl(const std::string& address, Callback callback, void* data);
212     Impl(const Impl& rhs);
213     ~Impl();
214
215     bool stop();
216     bool connected() const { return m_connected; }
217
218     bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
219
220 private:
221     ChannelConfig m_config;
222
223     int m_sock;
224
225     bool m_running;
226     bool m_stopped;
227
228     time_t m_lastPing;
229     time_t m_lastActivity;
230
231     pthread_t m_thread;
232     pthread_mutex_t m_mutex;
233
234     Callback m_callback;
235     void* m_data;
236
237     ProtoParser m_parser;
238
239     bool m_connected;
240
241     void m_writeHeader(TYPE type, const std::string& userName, const std::string& password);
242     void m_writePairBlock(const PAIRS& source);
243     PAIRS m_readPairBlock();
244
245     static void* run(void* );
246
247     void runImpl();
248
249     int connect();
250     int connectTCP();
251     int connectUNIX();
252
253     bool read();
254     bool tick();
255
256     static void process(void* data);
257     void processPing();
258     void processPong();
259     void processData();
260     bool sendPing();
261     bool sendPong();
262
263     static bool write(void* data, const char* buf, size_t size);
264 };
265
266 ChannelConfig::ChannelConfig(std::string addr)
267 {
268     // unix:pass@/var/run/stg.sock
269     // tcp:secret@192.168.0.1:12345
270     // udp:key@isp.com.ua:54321
271
272     size_t pos = addr.find_first_of(':');
273     if (pos == std::string::npos)
274         throw Error("Missing transport name.");
275     transport = ToLower(addr.substr(0, pos));
276     addr = addr.substr(pos + 1);
277     if (addr.empty())
278         throw Error("Missing address to connect to.");
279     pos = addr.find_first_of('@');
280     if (pos != std::string::npos) {
281         key = addr.substr(0, pos);
282         addr = addr.substr(pos + 1);
283         if (addr.empty())
284             throw Error("Missing address to connect to.");
285     }
286     if (transport == "unix")
287     {
288         address = addr;
289         return;
290     }
291     pos = addr.find_first_of(':');
292     if (pos == std::string::npos)
293         throw Error("Missing port.");
294     address = addr.substr(0, pos);
295     portStr = addr.substr(pos + 1);
296     if (str2x(portStr, port))
297         throw Error("Invalid port value.");
298 }
299
300 STG_CLIENT::STG_CLIENT(const std::string& address, Callback callback, void* data)
301     : m_impl(new Impl(address, callback, data))
302 {
303 }
304
305 STG_CLIENT::STG_CLIENT(const STG_CLIENT& rhs)
306     : m_impl(new Impl(*rhs.m_impl))
307 {
308 }
309
310 STG_CLIENT::~STG_CLIENT()
311 {
312 }
313
314 bool STG_CLIENT::stop()
315 {
316     return m_impl->stop();
317 }
318
319 bool STG_CLIENT::connected() const
320 {
321     return m_impl->connected();
322 }
323
324 bool STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
325 {
326     return m_impl->request(type, userName, password, pairs);
327 }
328
329 STG_CLIENT* STG_CLIENT::get()
330 {
331     return stgClient;
332 }
333
334 bool STG_CLIENT::configure(const std::string& address, Callback callback, void* data)
335 {
336     if ( stgClient != NULL && stgClient->stop() )
337         delete stgClient;
338     try {
339         stgClient = new STG_CLIENT(address, callback, data);
340         return true;
341     } catch (const std::exception& ex) {
342         // TODO: Log it
343         RadLog("Client configuration error: %s.", ex.what());
344     }
345     return false;
346 }
347
348 bool STG_CLIENT::reconnect()
349 {
350     if (stgClient == NULL)
351     {
352         RadLog("Connection is not configured.");
353         return false;
354     }
355     if (!stgClient->stop())
356     {
357         RadLog("Failed to drop previous connection.");
358         return false;
359     }
360     try {
361         STG_CLIENT* old = stgClient;
362         stgClient = new STG_CLIENT(*old);
363         delete old;
364         return true;
365     } catch (const ChannelConfig::Error& ex) {
366         // TODO: Log it
367         RadLog("Client configuration error: %s.", ex.what());
368     }
369     return false;
370 }
371
372 STG_CLIENT::Impl::Impl(const std::string& address, Callback callback, void* data)
373     : m_config(address),
374       m_sock(connect()),
375       m_running(false),
376       m_stopped(true),
377       m_lastPing(time(NULL)),
378       m_lastActivity(m_lastPing),
379       m_callback(callback),
380       m_data(data),
381       m_parser(&STG_CLIENT::Impl::process, this),
382       m_connected(true)
383 {
384     int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this);
385     if (res != 0)
386         throw Error("Failed to create thread: " + std::string(strerror(errno)));
387 }
388
389 STG_CLIENT::Impl::Impl(const Impl& rhs)
390     : m_config(rhs.m_config),
391       m_sock(connect()),
392       m_running(false),
393       m_stopped(true),
394       m_lastPing(time(NULL)),
395       m_lastActivity(m_lastPing),
396       m_callback(rhs.m_callback),
397       m_data(rhs.m_data),
398       m_parser(&STG_CLIENT::Impl::process, this),
399       m_connected(true)
400 {
401     int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this);
402     if (res != 0)
403         throw Error("Failed to create thread: " + std::string(strerror(errno)));
404 }
405
406 STG_CLIENT::Impl::~Impl()
407 {
408     stop();
409     shutdown(m_sock, SHUT_RDWR);
410     close(m_sock);
411 }
412
413 bool STG_CLIENT::Impl::stop()
414 {
415     m_connected = false;
416
417     if (m_stopped)
418         return true;
419
420     m_running = false;
421
422     for (size_t i = 0; i < 25 && !m_stopped; i++) {
423         struct timespec ts = {0, 200000000};
424         nanosleep(&ts, NULL);
425     }
426
427     if (m_stopped) {
428         pthread_join(m_thread, NULL);
429         return true;
430     }
431
432     return false;
433 }
434
435 bool STG_CLIENT::Impl::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
436 {
437     MapGen map;
438     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
439         map.add(it->first, new StringGen(it->second));
440     map.add("Radius-Username", new StringGen(userName));
441     map.add("Radius-Userpass", new StringGen(password));
442
443     PacketGen gen("data");
444     gen.add("stage", toStage(type))
445        .add("pairs", map);
446
447     m_lastPing = time(NULL);
448
449     return generate(gen, &STG_CLIENT::Impl::write, this);
450 }
451
452 void STG_CLIENT::Impl::runImpl()
453 {
454     m_running = true;
455
456     while (m_running) {
457         fd_set fds;
458
459         FD_ZERO(&fds);
460         FD_SET(m_sock, &fds);
461
462         struct timeval tv;
463         tv.tv_sec = 0;
464         tv.tv_usec = 500000;
465
466         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
467         if (res < 0)
468         {
469             if (errno == EINTR)
470                 continue;
471             RadLog("'select' is failed: %s", strerror(errno));
472             break;
473         }
474
475         if (!m_running)
476             break;
477
478         if (res > 0)
479         {
480             if (FD_ISSET(m_sock, &fds))
481                 m_running = read();
482         }
483         else
484             m_running = tick();
485     }
486
487     m_connected = false;
488     m_stopped = true;
489 }
490
491 int STG_CLIENT::Impl::connect()
492 {
493     if (m_config.transport == "tcp")
494         return connectTCP();
495     else if (m_config.transport == "unix")
496         return connectUNIX();
497     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
498 }
499
500 int STG_CLIENT::Impl::connectTCP()
501 {
502     addrinfo hints;
503     memset(&hints, 0, sizeof(addrinfo));
504
505     hints.ai_family = AF_INET;       /* Allow IPv4 */
506     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
507     hints.ai_flags = 0;     /* For wildcard IP address */
508     hints.ai_protocol = 0;           /* Any protocol */
509     hints.ai_canonname = NULL;
510     hints.ai_addr = NULL;
511     hints.ai_next = NULL;
512
513     addrinfo* ais = NULL;
514     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
515     if (res != 0)
516         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
517
518     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
519     {
520         int fd = socket(AF_INET, SOCK_STREAM, 0);
521         if (fd == -1)
522         {
523             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
524             freeaddrinfo(ais);
525             throw error;
526         }
527         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
528         {
529             shutdown(fd, SHUT_RDWR);
530             close(fd);
531             RadLog("'connect' is failed: %s", strerror(errno));
532             continue;
533         }
534         freeaddrinfo(ais);
535         return fd;
536     }
537
538     freeaddrinfo(ais);
539
540     throw Error("Failed to resolve '" + m_config.address);
541 };
542
543 int STG_CLIENT::Impl::connectUNIX()
544 {
545     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
546     if (fd == -1)
547         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
548     struct sockaddr_un addr;
549     memset(&addr, 0, sizeof(addr));
550     addr.sun_family = AF_UNIX;
551     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
552     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
553     {
554         Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
555         shutdown(fd, SHUT_RDWR);
556         close(fd);
557         throw error;
558     }
559     return fd;
560 }
561
562 bool STG_CLIENT::Impl::read()
563 {
564     static std::vector<char> buffer(1024);
565     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
566     if (res < 0)
567     {
568         RadLog("Failed to read data: %s", strerror(errno));
569         return false;
570     }
571     m_lastActivity = time(NULL);
572     RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
573     if (res == 0)
574     {
575         m_parser.last();
576         return false;
577     }
578     return m_parser.append(buffer.data(), res);
579 }
580
581 bool STG_CLIENT::Impl::tick()
582 {
583     time_t now = time(NULL);
584     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
585     {
586         int delta = difftime(now, m_lastActivity);
587         RadLog("Connection timeout: %d sec.", delta);
588         //m_logger("Connection to " + m_remote + " timed out.");
589         return false;
590     }
591     if (difftime(now, m_lastPing) > PING_TIMEOUT)
592     {
593         int delta = difftime(now, m_lastPing);
594         RadLog("Ping timeout: %d sec. Sending ping...", delta);
595         sendPing();
596     }
597     return true;
598 }
599
600 void STG_CLIENT::Impl::process(void* data)
601 {
602     Impl& impl = *static_cast<Impl*>(data);
603     switch (impl.m_parser.packet())
604     {
605         case PING:
606             impl.processPing();
607             return;
608         case PONG:
609             impl.processPong();
610             return;
611         case DATA:
612             impl.processData();
613             return;
614     }
615     RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
616 }
617
618 void STG_CLIENT::Impl::processPing()
619 {
620     RadLog("Got ping, sending pong.");
621     sendPong();
622 }
623
624 void STG_CLIENT::Impl::processPong()
625 {
626     RadLog("Got pong.");
627     m_lastActivity = time(NULL);
628 }
629
630 void STG_CLIENT::Impl::processData()
631 {
632     RESULT data;
633     RadLog("Got data.");
634     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
635         data.reply.push_back(std::make_pair(it->first, it->second));
636     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
637         data.modify.push_back(std::make_pair(it->first, it->second));
638     m_callback(m_data, data, m_parser.result());
639 }
640
641 bool STG_CLIENT::Impl::sendPing()
642 {
643     PacketGen gen("ping");
644
645     m_lastPing = time(NULL);
646
647     return generate(gen, &STG_CLIENT::Impl::write, this);
648 }
649
650 bool STG_CLIENT::Impl::sendPong()
651 {
652     PacketGen gen("pong");
653
654     m_lastPing = time(NULL);
655
656     return generate(gen, &STG_CLIENT::Impl::write, this);
657 }
658
659 bool STG_CLIENT::Impl::write(void* data, const char* buf, size_t size)
660 {
661     RadLog("Sending JSON:");
662     std::string json(buf, size);
663     RadLog("%s", json.c_str());
664     STG_CLIENT::Impl& impl = *static_cast<STG_CLIENT::Impl*>(data);
665     while (size > 0)
666     {
667         ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
668         if (res < 0)
669         {
670             impl.m_connected = false;
671             RadLog("Failed to write data: %s.", strerror(errno));
672             return false;
673         }
674         size -= res;
675     }
676     return true;
677 }
678
679 void* STG_CLIENT::Impl::run(void* data)
680 {
681     Impl& impl = *static_cast<Impl*>(data);
682     impl.runImpl();
683     return NULL;
684 }