]> git.stg.codes - stg.git/blob - projects/rlm_stg/stg_client.cpp
137fb61a22a4a1039029f43904f35aaf658741d7
[stg.git] / projects / rlm_stg / stg_client.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "stg_client.h"
22
23 #include "radlog.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/common.h"
28
29 #include <map>
30 #include <utility>
31 #include <cerrno>
32 #include <cstring>
33
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 #include <sys/un.h> // UNIX
37 #include <netinet/in.h> // IP
38 #include <netinet/tcp.h> // TCP
39 #include <netdb.h>
40
41 using STG::JSON::Parser;
42 using STG::JSON::PairsParser;
43 using STG::JSON::EnumParser;
44 using STG::JSON::NodeParser;
45 using STG::JSON::Gen;
46 using STG::JSON::MapGen;
47 using STG::JSON::StringGen;
48
49 namespace {
50
51 double CONN_TIMEOUT = 60;
52 double PING_TIMEOUT = 10;
53
54 STG_CLIENT* stgClient = NULL;
55
56 std::string toStage(STG_CLIENT::TYPE type)
57 {
58     switch (type)
59     {
60         case STG_CLIENT::AUTHORIZE: return "authorize";
61         case STG_CLIENT::AUTHENTICATE: return "authenticate";
62         case STG_CLIENT::POST_AUTH: return "postauth";
63         case STG_CLIENT::PRE_ACCT: return "preacct";
64         case STG_CLIENT::ACCOUNT: return "accounting";
65     }
66     return "";
67 }
68
69 enum Packet
70 {
71     PING,
72     PONG,
73     DATA
74 };
75
76 std::map<std::string, Packet> packetCodes;
77 std::map<std::string, bool> resultCodes;
78
79 class PacketParser : public EnumParser<Packet>
80 {
81     public:
82         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
83             : EnumParser(next, packet, packetStr, packetCodes)
84         {
85             if (!packetCodes.empty())
86                 return;
87             packetCodes["ping"] = PING;
88             packetCodes["pong"] = PONG;
89             packetCodes["data"] = DATA;
90         }
91 };
92
93 class ResultParser : public EnumParser<bool>
94 {
95     public:
96         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
97             : EnumParser(next, result, resultStr, resultCodes)
98         {
99             if (!resultCodes.empty())
100                 return;
101             resultCodes["no"] = false;
102             resultCodes["ok"] = true;
103         }
104 };
105
106 class TopParser : public NodeParser
107 {
108     public:
109         typedef void (*Callback) (void* /*data*/);
110         TopParser(Callback callback, void* data)
111             : m_packetParser(this, m_packet, m_packetStr),
112               m_resultParser(this, m_result, m_resultStr),
113               m_replyParser(this, m_reply),
114               m_modifyParser(this, m_modify),
115               m_callback(callback), m_data(data)
116         {}
117
118         virtual NodeParser* parseStartMap() { return this; }
119         virtual NodeParser* parseMapKey(const std::string& value)
120         {
121             std::string key = ToLower(value);
122
123             if (key == "packet")
124                 return &m_packetParser;
125             else if (key == "result")
126                 return &m_resultParser;
127             else if (key == "reply")
128                 return &m_replyParser;
129             else if (key == "modify")
130                 return &m_modifyParser;
131
132             return this;
133         }
134         virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
135
136         const std::string& packetStr() const { return m_packetStr; }
137         Packet packet() const { return m_packet; }
138         const std::string& resultStr() const { return m_resultStr; }
139         bool result() const { return m_result; }
140         const PairsParser::Pairs& reply() const { return m_reply; }
141         const PairsParser::Pairs& modify() const { return m_modify; }
142
143     private:
144         std::string m_packetStr;
145         Packet m_packet;
146         std::string m_resultStr;
147         bool m_result;
148         PairsParser::Pairs m_reply;
149         PairsParser::Pairs m_modify;
150
151         PacketParser m_packetParser;
152         ResultParser m_resultParser;
153         PairsParser m_replyParser;
154         PairsParser m_modifyParser;
155
156         Callback m_callback;
157         void* m_data;
158 };
159
160 class ProtoParser : public Parser
161 {
162     public:
163         ProtoParser(TopParser::Callback callback, void* data)
164             : Parser( &m_topParser ),
165               m_topParser(callback, data)
166         {}
167
168         const std::string& packetStr() const { return m_topParser.packetStr(); }
169         Packet packet() const { return m_topParser.packet(); }
170         const std::string& resultStr() const { return m_topParser.resultStr(); }
171         bool result() const { return m_topParser.result(); }
172         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
173         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
174
175     private:
176         TopParser m_topParser;
177 };
178
179 class PacketGen : public Gen
180 {
181     public:
182         PacketGen(const std::string& type)
183             : m_type(type)
184         {
185             m_gen.add("packet", m_type);
186         }
187         void run(yajl_gen_t* handle) const
188         {
189             m_gen.run(handle);
190         }
191         PacketGen& add(const std::string& key, const std::string& value)
192         {
193             m_gen.add(key, new StringGen(value));
194             return *this;
195         }
196         PacketGen& add(const std::string& key, MapGen& map)
197         {
198             m_gen.add(key, map);
199             return *this;
200         }
201     private:
202         MapGen m_gen;
203         StringGen m_type;
204 };
205
206 }
207
208 class STG_CLIENT::Impl
209 {
210 public:
211     Impl(const std::string& address, Callback callback, void* data);
212     Impl(const Impl& rhs);
213     ~Impl();
214
215     bool stop();
216     bool connected() const { return m_connected; }
217
218     bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
219
220 private:
221     ChannelConfig m_config;
222
223     int m_sock;
224
225     bool m_running;
226     bool m_stopped;
227
228     time_t m_lastPing;
229     time_t m_lastActivity;
230
231     pthread_t m_thread;
232     pthread_mutex_t m_mutex;
233
234     Callback m_callback;
235     void* m_data;
236
237     ProtoParser m_parser;
238
239     bool m_connected;
240
241     void m_writeHeader(TYPE type, const std::string& userName, const std::string& password);
242     void m_writePairBlock(const PAIRS& source);
243     PAIRS m_readPairBlock();
244
245     static void* run(void* );
246
247     void runImpl();
248
249     int connect();
250     int connectTCP();
251     int connectUNIX();
252
253     bool read();
254     bool tick();
255
256     static void process(void* data);
257     void processPing();
258     void processPong();
259     void processData();
260     bool sendPing();
261     bool sendPong();
262
263     static bool write(void* data, const char* buf, size_t size);
264 };
265
266 ChannelConfig::ChannelConfig(std::string addr)
267 {
268     // unix:pass@/var/run/stg.sock
269     // tcp:secret@192.168.0.1:12345
270     // udp:key@isp.com.ua:54321
271
272     size_t pos = addr.find_first_of(':');
273     if (pos == std::string::npos)
274         throw Error("Missing transport name.");
275     transport = ToLower(addr.substr(0, pos));
276     addr = addr.substr(pos + 1);
277     if (addr.empty())
278         throw Error("Missing address to connect to.");
279     pos = addr.find_first_of('@');
280     if (pos != std::string::npos) {
281         key = addr.substr(0, pos);
282         addr = addr.substr(pos + 1);
283         if (addr.empty())
284             throw Error("Missing address to connect to.");
285     }
286     if (transport == "unix")
287     {
288         address = addr;
289         return;
290     }
291     pos = addr.find_first_of(':');
292     if (pos == std::string::npos)
293         throw Error("Missing port.");
294     address = addr.substr(0, pos);
295     portStr = addr.substr(pos + 1);
296     if (str2x(portStr, port))
297         throw Error("Invalid port value.");
298 }
299
300 STG_CLIENT::STG_CLIENT(const std::string& address, Callback callback, void* data)
301     : m_impl(new Impl(address, callback, data))
302 {
303 }
304
305 STG_CLIENT::STG_CLIENT(const STG_CLIENT& rhs)
306     : m_impl(new Impl(*rhs.m_impl))
307 {
308 }
309
310 STG_CLIENT::~STG_CLIENT()
311 {
312 }
313
314 bool STG_CLIENT::stop()
315 {
316     return m_impl->stop();
317 }
318
319 bool STG_CLIENT::connected() const
320 {
321     return m_impl->connected();
322 }
323
324 bool STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
325 {
326     return m_impl->request(type, userName, password, pairs);
327 }
328
329 STG_CLIENT* STG_CLIENT::get()
330 {
331     return stgClient;
332 }
333
334 bool STG_CLIENT::configure(const std::string& address, Callback callback, void* data)
335 {
336     if ( stgClient != NULL && stgClient->stop() )
337         delete stgClient;
338     try {
339         stgClient = new STG_CLIENT(address, callback, data);
340         return true;
341     } catch (const ChannelConfig::Error& ex) {
342         // TODO: Log it
343         RadLog("Client configuration error: %s.", ex.what());
344     }
345     return false;
346 }
347
348 bool STG_CLIENT::reconnect()
349 {
350     if (stgClient == NULL)
351     {
352         RadLog("Connection is not configured.");
353         return false;
354     }
355     if (!stgClient->stop())
356     {
357         RadLog("Failed to drop previous connection.");
358         return false;
359     }
360     try {
361         STG_CLIENT* old = stgClient;
362         stgClient = new STG_CLIENT(*old);
363         delete old;
364         return true;
365     } catch (const ChannelConfig::Error& ex) {
366         // TODO: Log it
367         RadLog("Client configuration error: %s.", ex.what());
368     }
369     return false;
370 }
371
372 STG_CLIENT::Impl::Impl(const std::string& address, Callback callback, void* data)
373     : m_config(address),
374       m_sock(connect()),
375       m_running(false),
376       m_stopped(true),
377       m_lastPing(time(NULL)),
378       m_lastActivity(m_lastPing),
379       m_callback(callback),
380       m_data(data),
381       m_parser(&STG_CLIENT::Impl::process, this),
382       m_connected(true)
383 {
384     int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this);
385     if (res != 0)
386         throw Error("Failed to create thread: " + std::string(strerror(errno)));
387 }
388
389 STG_CLIENT::Impl::Impl(const Impl& rhs)
390     : m_config(rhs.m_config),
391       m_sock(connect()),
392       m_running(false),
393       m_stopped(true),
394       m_lastPing(time(NULL)),
395       m_lastActivity(m_lastPing),
396       m_callback(rhs.m_callback),
397       m_data(rhs.m_data),
398       m_parser(&STG_CLIENT::Impl::process, this),
399       m_connected(true)
400 {
401     int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this);
402     if (res != 0)
403         throw Error("Failed to create thread: " + std::string(strerror(errno)));
404 }
405
406 STG_CLIENT::Impl::~Impl()
407 {
408     stop();
409     shutdown(m_sock, SHUT_RDWR);
410     close(m_sock);
411 }
412
413 bool STG_CLIENT::Impl::stop()
414 {
415     m_connected = false;
416
417     if (m_stopped)
418         return true;
419
420     m_running = false;
421
422     for (size_t i = 0; i < 25 && !m_stopped; i++) {
423         struct timespec ts = {0, 200000000};
424         nanosleep(&ts, NULL);
425     }
426
427     if (m_stopped) {
428         pthread_join(m_thread, NULL);
429         return true;
430     }
431
432     return false;
433 }
434
435 bool STG_CLIENT::Impl::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
436 {
437     MapGen map;
438     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
439         map.add(it->first, new StringGen(it->second));
440     map.add("Radius-Username", new StringGen(userName));
441     map.add("Radius-Userpass", new StringGen(password));
442
443     PacketGen gen("data");
444     gen.add("stage", toStage(type))
445        .add("pairs", map);
446
447     m_lastPing = time(NULL);
448
449     return generate(gen, &STG_CLIENT::Impl::write, this);
450 }
451
452 void STG_CLIENT::Impl::runImpl()
453 {
454     m_running = true;
455
456     while (m_running) {
457         fd_set fds;
458
459         FD_ZERO(&fds);
460         FD_SET(m_sock, &fds);
461
462         struct timeval tv;
463         tv.tv_sec = 0;
464         tv.tv_usec = 500000;
465
466         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
467         if (res < 0)
468         {
469             if (errno == EINTR)
470                 continue;
471             RadLog("'select' is failed: %s", strerror(errno));
472             //m_error = std::string("'select' is failed: '") + strerror(errno) + "'.";
473             //m_logger(m_error);
474             break;
475         }
476
477         if (!m_running)
478             break;
479
480         if (res > 0)
481         {
482             if (FD_ISSET(m_sock, &fds))
483                 m_running = read();
484         }
485         else
486             m_running = tick();
487     }
488
489     m_connected = false;
490     m_stopped = true;
491 }
492
493 int STG_CLIENT::Impl::connect()
494 {
495     if (m_config.transport == "tcp")
496         return connectTCP();
497     else if (m_config.transport == "unix")
498         return connectUNIX();
499     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
500 }
501
502 int STG_CLIENT::Impl::connectTCP()
503 {
504     addrinfo hints;
505     memset(&hints, 0, sizeof(addrinfo));
506
507     hints.ai_family = AF_INET;       /* Allow IPv4 */
508     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
509     hints.ai_flags = 0;     /* For wildcard IP address */
510     hints.ai_protocol = 0;           /* Any protocol */
511     hints.ai_canonname = NULL;
512     hints.ai_addr = NULL;
513     hints.ai_next = NULL;
514
515     addrinfo* ais = NULL;
516     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
517     if (res != 0)
518         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
519
520     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
521     {
522         int fd = socket(AF_INET, SOCK_STREAM, 0);
523         if (fd == -1)
524         {
525             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
526             freeaddrinfo(ais);
527             throw error;
528         }
529         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
530         {
531             shutdown(fd, SHUT_RDWR);
532             close(fd);
533             RadLog("'connect' is failed: %s", strerror(errno));
534             // TODO: log it.
535             continue;
536         }
537         freeaddrinfo(ais);
538         return fd;
539     }
540
541     freeaddrinfo(ais);
542
543     throw Error("Failed to resolve '" + m_config.address);
544 };
545
546 int STG_CLIENT::Impl::connectUNIX()
547 {
548     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
549     if (fd == -1)
550         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
551     struct sockaddr_un addr;
552     memset(&addr, 0, sizeof(addr));
553     addr.sun_family = AF_UNIX;
554     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
555     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
556     {
557         Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
558         shutdown(fd, SHUT_RDWR);
559         close(fd);
560         throw error;
561     }
562     return fd;
563 }
564
565 bool STG_CLIENT::Impl::read()
566 {
567     static std::vector<char> buffer(1024);
568     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
569     if (res < 0)
570     {
571         RadLog("Failed to read data: ", strerror(errno));
572         //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
573         return false;
574     }
575     m_lastActivity = time(NULL);
576     RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
577     if (res == 0)
578     {
579         m_parser.last();
580         return false;
581     }
582     return m_parser.append(buffer.data(), res);
583 }
584
585 bool STG_CLIENT::Impl::tick()
586 {
587     time_t now = time(NULL);
588     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
589     {
590         int delta = difftime(now, m_lastActivity);
591         RadLog("Connection timeout: %d sec.", delta);
592         //m_logger("Connection to " + m_remote + " timed out.");
593         return false;
594     }
595     if (difftime(now, m_lastPing) > PING_TIMEOUT)
596     {
597         int delta = difftime(now, m_lastPing);
598         RadLog("Ping timeout: %d sec. Sending ping...", delta);
599         sendPing();
600     }
601     return true;
602 }
603
604 void STG_CLIENT::Impl::process(void* data)
605 {
606     Impl& impl = *static_cast<Impl*>(data);
607     switch (impl.m_parser.packet())
608     {
609         case PING:
610             impl.processPing();
611             return;
612         case PONG:
613             impl.processPong();
614             return;
615         case DATA:
616             impl.processData();
617             return;
618     }
619     RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
620 }
621
622 void STG_CLIENT::Impl::processPing()
623 {
624     RadLog("Got ping, sending pong.");
625     sendPong();
626 }
627
628 void STG_CLIENT::Impl::processPong()
629 {
630     RadLog("Got pong.");
631     m_lastActivity = time(NULL);
632 }
633
634 void STG_CLIENT::Impl::processData()
635 {
636     RESULT data;
637     RadLog("Got data.");
638     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
639         data.reply.push_back(std::make_pair(it->first, it->second));
640     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
641         data.modify.push_back(std::make_pair(it->first, it->second));
642     m_callback(m_data, data, m_parser.result());
643 }
644
645 bool STG_CLIENT::Impl::sendPing()
646 {
647     PacketGen gen("ping");
648
649     m_lastPing = time(NULL);
650
651     return generate(gen, &STG_CLIENT::Impl::write, this);
652 }
653
654 bool STG_CLIENT::Impl::sendPong()
655 {
656     PacketGen gen("pong");
657
658     m_lastPing = time(NULL);
659
660     return generate(gen, &STG_CLIENT::Impl::write, this);
661 }
662
663 bool STG_CLIENT::Impl::write(void* data, const char* buf, size_t size)
664 {
665     RadLog("Sending JSON:");
666     std::string json(buf, size);
667     RadLog("%s", json.c_str());
668     STG_CLIENT::Impl& impl = *static_cast<STG_CLIENT::Impl*>(data);
669     while (size > 0)
670     {
671         ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
672         if (res < 0)
673         {
674             impl.m_connected = false;
675             RadLog("Failed to write data: %s.", strerror(errno));
676             //conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
677             return false;
678         }
679         size -= res;
680     }
681     return true;
682 }
683
684 void* STG_CLIENT::Impl::run(void* data)
685 {
686     Impl& impl = *static_cast<Impl*>(data);
687     impl.runImpl();
688     return NULL;
689 }