]> git.stg.codes - stg.git/blob - projects/rlm_stg/stg_client.cpp
Handle EINTR in mod_radius correctly.
[stg.git] / projects / rlm_stg / stg_client.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "stg_client.h"
22
23 #include "radlog.h"
24
25 #include "stg/json_parser.h"
26 #include "stg/json_generator.h"
27 #include "stg/common.h"
28
29 #include <map>
30 #include <utility>
31 #include <cerrno>
32 #include <cstring>
33
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 #include <sys/un.h> // UNIX
37 #include <netinet/in.h> // IP
38 #include <netinet/tcp.h> // TCP
39 #include <netdb.h>
40
41 using STG::JSON::Parser;
42 using STG::JSON::PairsParser;
43 using STG::JSON::EnumParser;
44 using STG::JSON::NodeParser;
45 using STG::JSON::Gen;
46 using STG::JSON::MapGen;
47 using STG::JSON::StringGen;
48
49 namespace {
50
51 double CONN_TIMEOUT = 60;
52 double PING_TIMEOUT = 10;
53
54 STG_CLIENT* stgClient = NULL;
55
56 std::string toStage(STG_CLIENT::TYPE type)
57 {
58     switch (type)
59     {
60         case STG_CLIENT::AUTHORIZE: return "authorize";
61         case STG_CLIENT::AUTHENTICATE: return "authenticate";
62         case STG_CLIENT::POST_AUTH: return "postauth";
63         case STG_CLIENT::PRE_ACCT: return "preacct";
64         case STG_CLIENT::ACCOUNT: return "accounting";
65     }
66     return "";
67 }
68
69 enum Packet
70 {
71     PING,
72     PONG,
73     DATA
74 };
75
76 std::map<std::string, Packet> packetCodes;
77 std::map<std::string, bool> resultCodes;
78
79 class PacketParser : public EnumParser<Packet>
80 {
81     public:
82         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
83             : EnumParser(next, packet, packetStr, packetCodes)
84         {
85             if (!packetCodes.empty())
86                 return;
87             packetCodes["ping"] = PING;
88             packetCodes["pong"] = PONG;
89             packetCodes["data"] = DATA;
90         }
91 };
92
93 class ResultParser : public EnumParser<bool>
94 {
95     public:
96         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
97             : EnumParser(next, result, resultStr, resultCodes)
98         {
99             if (!resultCodes.empty())
100                 return;
101             resultCodes["no"] = false;
102             resultCodes["ok"] = true;
103         }
104 };
105
106 class TopParser : public NodeParser
107 {
108     public:
109         typedef void (*Callback) (void* /*data*/);
110         TopParser(Callback callback, void* data)
111             : m_packetParser(this, m_packet, m_packetStr),
112               m_resultParser(this, m_result, m_resultStr),
113               m_replyParser(this, m_reply),
114               m_modifyParser(this, m_modify),
115               m_callback(callback), m_data(data)
116         {}
117
118         virtual NodeParser* parseStartMap() { return this; }
119         virtual NodeParser* parseMapKey(const std::string& value)
120         {
121             std::string key = ToLower(value);
122
123             if (key == "packet")
124                 return &m_packetParser;
125             else if (key == "result")
126                 return &m_resultParser;
127             else if (key == "reply")
128                 return &m_replyParser;
129             else if (key == "modify")
130                 return &m_modifyParser;
131
132             return this;
133         }
134         virtual NodeParser* parseEndMap() { m_callback(m_data); return this; }
135
136         const std::string& packetStr() const { return m_packetStr; }
137         Packet packet() const { return m_packet; }
138         const std::string& resultStr() const { return m_resultStr; }
139         bool result() const { return m_result; }
140         const PairsParser::Pairs& reply() const { return m_reply; }
141         const PairsParser::Pairs& modify() const { return m_modify; }
142
143     private:
144         std::string m_packetStr;
145         Packet m_packet;
146         std::string m_resultStr;
147         bool m_result;
148         PairsParser::Pairs m_reply;
149         PairsParser::Pairs m_modify;
150
151         PacketParser m_packetParser;
152         ResultParser m_resultParser;
153         PairsParser m_replyParser;
154         PairsParser m_modifyParser;
155
156         Callback m_callback;
157         void* m_data;
158 };
159
160 class ProtoParser : public Parser
161 {
162     public:
163         ProtoParser(TopParser::Callback callback, void* data)
164             : Parser( &m_topParser ),
165               m_topParser(callback, data)
166         {}
167
168         const std::string& packetStr() const { return m_topParser.packetStr(); }
169         Packet packet() const { return m_topParser.packet(); }
170         const std::string& resultStr() const { return m_topParser.resultStr(); }
171         bool result() const { return m_topParser.result(); }
172         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
173         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
174
175     private:
176         TopParser m_topParser;
177 };
178
179 class PacketGen : public Gen
180 {
181     public:
182         PacketGen(const std::string& type)
183             : m_type(type)
184         {
185             m_gen.add("packet", m_type);
186         }
187         void run(yajl_gen_t* handle) const
188         {
189             m_gen.run(handle);
190         }
191         PacketGen& add(const std::string& key, const std::string& value)
192         {
193             m_gen.add(key, new StringGen(value));
194             return *this;
195         }
196         PacketGen& add(const std::string& key, MapGen& map)
197         {
198             m_gen.add(key, map);
199             return *this;
200         }
201     private:
202         MapGen m_gen;
203         StringGen m_type;
204 };
205
206 }
207
208 class STG_CLIENT::Impl
209 {
210 public:
211     Impl(const std::string& address, Callback callback, void* data);
212     ~Impl();
213
214     bool stop();
215
216     bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
217
218 private:
219     ChannelConfig m_config;
220
221     int m_sock;
222
223     bool m_running;
224     bool m_stopped;
225
226     time_t m_lastPing;
227     time_t m_lastActivity;
228
229     pthread_t m_thread;
230     pthread_mutex_t m_mutex;
231
232     Callback m_callback;
233     void* m_data;
234
235     ProtoParser m_parser;
236
237     void m_writeHeader(TYPE type, const std::string& userName, const std::string& password);
238     void m_writePairBlock(const PAIRS& source);
239     PAIRS m_readPairBlock();
240
241     static void* run(void* );
242
243     void runImpl();
244
245     int connect();
246     int connectTCP();
247     int connectUNIX();
248
249     bool read();
250     bool tick();
251
252     static void process(void* data);
253     void processPing();
254     void processPong();
255     void processData();
256     bool sendPing();
257     bool sendPong();
258
259     static bool write(void* data, const char* buf, size_t size);
260 };
261
262 ChannelConfig::ChannelConfig(std::string addr)
263 {
264     // unix:pass@/var/run/stg.sock
265     // tcp:secret@192.168.0.1:12345
266     // udp:key@isp.com.ua:54321
267
268     size_t pos = addr.find_first_of(':');
269     if (pos == std::string::npos)
270         throw Error("Missing transport name.");
271     transport = ToLower(addr.substr(0, pos));
272     addr = addr.substr(pos + 1);
273     if (addr.empty())
274         throw Error("Missing address to connect to.");
275     pos = addr.find_first_of('@');
276     if (pos != std::string::npos) {
277         key = addr.substr(0, pos);
278         addr = addr.substr(pos + 1);
279         if (addr.empty())
280             throw Error("Missing address to connect to.");
281     }
282     if (transport == "unix")
283     {
284         address = addr;
285         return;
286     }
287     pos = addr.find_first_of(':');
288     if (pos == std::string::npos)
289         throw Error("Missing port.");
290     address = addr.substr(0, pos);
291     portStr = addr.substr(pos + 1);
292     if (str2x(portStr, port))
293         throw Error("Invalid port value.");
294 }
295
296 STG_CLIENT::STG_CLIENT(const std::string& address, Callback callback, void* data)
297     : m_impl(new Impl(address, callback, data))
298 {
299 }
300
301 STG_CLIENT::~STG_CLIENT()
302 {
303 }
304
305 bool STG_CLIENT::stop()
306 {
307     return m_impl->stop();
308 }
309
310 bool STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
311 {
312     return m_impl->request(type, userName, password, pairs);
313 }
314
315 STG_CLIENT* STG_CLIENT::get()
316 {
317     return stgClient;
318 }
319
320 bool STG_CLIENT::configure(const std::string& address, Callback callback, void* data)
321 {
322     if ( stgClient != NULL && stgClient->stop() )
323         delete stgClient;
324     try {
325         stgClient = new STG_CLIENT(address, callback, data);
326         return true;
327     } catch (const ChannelConfig::Error& ex) {
328         // TODO: Log it
329         RadLog("Client configuration error: %s.", ex.what());
330     }
331     return false;
332 }
333
334 STG_CLIENT::Impl::Impl(const std::string& address, Callback callback, void* data)
335     : m_config(address),
336       m_sock(connect()),
337       m_running(false),
338       m_stopped(true),
339       m_lastPing(time(NULL)),
340       m_lastActivity(m_lastPing),
341       m_callback(callback),
342       m_data(data),
343       m_parser(&STG_CLIENT::Impl::process, this)
344 {
345     int res = pthread_create(&m_thread, NULL, &STG_CLIENT::Impl::run, this);
346     if (res != 0)
347         throw Error("Failed to create thread: " + std::string(strerror(errno)));
348 }
349
350 STG_CLIENT::Impl::~Impl()
351 {
352     stop();
353     shutdown(m_sock, SHUT_RDWR);
354     close(m_sock);
355 }
356
357 bool STG_CLIENT::Impl::stop()
358 {
359     if (m_stopped)
360         return true;
361
362     m_running = false;
363
364     for (size_t i = 0; i < 25 && !m_stopped; i++) {
365         struct timespec ts = {0, 200000000};
366         nanosleep(&ts, NULL);
367     }
368
369     if (m_stopped) {
370         pthread_join(m_thread, NULL);
371         return true;
372     }
373
374     return false;
375 }
376
377 bool STG_CLIENT::Impl::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
378 {
379     MapGen map;
380     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
381         map.add(it->first, new StringGen(it->second));
382     map.add("Radius-Username", new StringGen(userName));
383     map.add("Radius-Userpass", new StringGen(password));
384
385     PacketGen gen("data");
386     gen.add("stage", toStage(type))
387        .add("pairs", map);
388
389     m_lastPing = time(NULL);
390
391     return generate(gen, &STG_CLIENT::Impl::write, this);
392 }
393
394 void STG_CLIENT::Impl::runImpl()
395 {
396     m_running = true;
397
398     while (m_running) {
399         fd_set fds;
400
401         FD_ZERO(&fds);
402         FD_SET(m_sock, &fds);
403
404         struct timeval tv;
405         tv.tv_sec = 0;
406         tv.tv_usec = 500000;
407
408         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
409         if (res < 0)
410         {
411             if (errno == EINTR)
412                 continue;
413             RadLog("'select' is failed: %s", strerror(errno));
414             //m_error = std::string("'select' is failed: '") + strerror(errno) + "'.";
415             //m_logger(m_error);
416             break;
417         }
418
419         if (!m_running)
420             break;
421
422         if (res > 0)
423         {
424             if (FD_ISSET(m_sock, &fds))
425                 m_running = read();
426         }
427         else
428             m_running = tick();
429     }
430
431     m_stopped = true;
432 }
433
434 int STG_CLIENT::Impl::connect()
435 {
436     if (m_config.transport == "tcp")
437         return connectTCP();
438     else if (m_config.transport == "unix")
439         return connectUNIX();
440     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
441 }
442
443 int STG_CLIENT::Impl::connectTCP()
444 {
445     addrinfo hints;
446     memset(&hints, 0, sizeof(addrinfo));
447
448     hints.ai_family = AF_INET;       /* Allow IPv4 */
449     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
450     hints.ai_flags = 0;     /* For wildcard IP address */
451     hints.ai_protocol = 0;           /* Any protocol */
452     hints.ai_canonname = NULL;
453     hints.ai_addr = NULL;
454     hints.ai_next = NULL;
455
456     addrinfo* ais = NULL;
457     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
458     if (res != 0)
459         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
460
461     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
462     {
463         int fd = socket(AF_INET, SOCK_STREAM, 0);
464         if (fd == -1)
465         {
466             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
467             freeaddrinfo(ais);
468             throw error;
469         }
470         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
471         {
472             shutdown(fd, SHUT_RDWR);
473             close(fd);
474             RadLog("'connect' is failed: %s", strerror(errno));
475             // TODO: log it.
476             continue;
477         }
478         freeaddrinfo(ais);
479         return fd;
480     }
481
482     freeaddrinfo(ais);
483
484     throw Error("Failed to resolve '" + m_config.address);
485 };
486
487 int STG_CLIENT::Impl::connectUNIX()
488 {
489     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
490     if (fd == -1)
491         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
492     struct sockaddr_un addr;
493     memset(&addr, 0, sizeof(addr));
494     addr.sun_family = AF_UNIX;
495     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
496     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
497     {
498         Error error(std::string("Error connecting UNIX socket: ") + strerror(errno));
499         shutdown(fd, SHUT_RDWR);
500         close(fd);
501         throw error;
502     }
503     return fd;
504 }
505
506 bool STG_CLIENT::Impl::read()
507 {
508     static std::vector<char> buffer(1024);
509     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
510     if (res < 0)
511     {
512         RadLog("Failed to read data: ", strerror(errno));
513         //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
514         return false;
515     }
516     m_lastActivity = time(NULL);
517     RadLog("Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
518     if (res == 0)
519     {
520         m_parser.last();
521         return false;
522     }
523     return m_parser.append(buffer.data(), res);
524 }
525
526 bool STG_CLIENT::Impl::tick()
527 {
528     time_t now = time(NULL);
529     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
530     {
531         int delta = difftime(now, m_lastActivity);
532         RadLog("Connection timeout: %d sec.", delta);
533         //m_logger("Connection to " + m_remote + " timed out.");
534         return false;
535     }
536     if (difftime(now, m_lastPing) > PING_TIMEOUT)
537     {
538         int delta = difftime(now, m_lastPing);
539         RadLog("Ping timeout: %d sec. Sending ping...", delta);
540         sendPing();
541     }
542     return true;
543 }
544
545 void STG_CLIENT::Impl::process(void* data)
546 {
547     Impl& impl = *static_cast<Impl*>(data);
548     switch (impl.m_parser.packet())
549     {
550         case PING:
551             impl.processPing();
552             return;
553         case PONG:
554             impl.processPong();
555             return;
556         case DATA:
557             impl.processData();
558             return;
559     }
560     RadLog("Received invalid packet type: '%s'.", impl.m_parser.packetStr().c_str());
561 }
562
563 void STG_CLIENT::Impl::processPing()
564 {
565     RadLog("Got ping, sending pong.");
566     sendPong();
567 }
568
569 void STG_CLIENT::Impl::processPong()
570 {
571     RadLog("Got pong.");
572     m_lastActivity = time(NULL);
573 }
574
575 void STG_CLIENT::Impl::processData()
576 {
577     RESULT data;
578     RadLog("Got data.");
579     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
580         data.reply.push_back(std::make_pair(it->first, it->second));
581     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
582         data.modify.push_back(std::make_pair(it->first, it->second));
583     m_callback(m_data, data, m_parser.result());
584 }
585
586 bool STG_CLIENT::Impl::sendPing()
587 {
588     PacketGen gen("ping");
589
590     m_lastPing = time(NULL);
591
592     return generate(gen, &STG_CLIENT::Impl::write, this);
593 }
594
595 bool STG_CLIENT::Impl::sendPong()
596 {
597     PacketGen gen("pong");
598
599     m_lastPing = time(NULL);
600
601     return generate(gen, &STG_CLIENT::Impl::write, this);
602 }
603
604 bool STG_CLIENT::Impl::write(void* data, const char* buf, size_t size)
605 {
606     RadLog("Sending JSON:");
607     std::string json(buf, size);
608     RadLog("%s", json.c_str());
609     STG_CLIENT::Impl& impl = *static_cast<STG_CLIENT::Impl*>(data);
610     while (size > 0)
611     {
612         ssize_t res = ::send(impl.m_sock, buf, size, MSG_NOSIGNAL);
613         if (res < 0)
614         {
615             RadLog("Failed to write data: %s.", strerror(errno));
616             //conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
617             return false;
618         }
619         size -= res;
620     }
621     return true;
622 }
623
624 void* STG_CLIENT::Impl::run(void* data)
625 {
626     Impl& impl = *static_cast<Impl*>(data);
627     impl.runImpl();
628     return NULL;
629 }