]> git.stg.codes - stg.git/blob - projects/rlm_stg/stg_client.cpp
Implemented rlm_stg.
[stg.git] / projects / rlm_stg / stg_client.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "stg_client.h"
22
23 #include "stg/json_parser.h"
24 #include "stg/json_generator.h"
25 #include "stg/common.h"
26
27 #include <map>
28 #include <utility>
29 #include <cerrno>
30 #include <cstring>
31
32 #include <sys/types.h>
33 #include <sys/socket.h>
34 #include <sys/un.h> // UNIX
35 #include <netinet/in.h> // IP
36 #include <netinet/tcp.h> // TCP
37 #include <netdb.h>
38
39 using STG::JSON::Parser;
40 using STG::JSON::PairsParser;
41 using STG::JSON::EnumParser;
42 using STG::JSON::NodeParser;
43 using STG::JSON::Gen;
44 using STG::JSON::MapGen;
45 using STG::JSON::StringGen;
46
47 namespace {
48
49 double CONN_TIMEOUT = 5;
50 double PING_TIMEOUT = 1;
51
52 STG_CLIENT* stgClient = NULL;
53
54 std::string toStage(STG_CLIENT::TYPE type)
55 {
56     switch (type)
57     {
58         case STG_CLIENT::AUTHORIZE: return "authorize";
59         case STG_CLIENT::AUTHENTICATE: return "authenticate";
60         case STG_CLIENT::POST_AUTH: return "postauth";
61         case STG_CLIENT::PRE_ACCT: return "preacct";
62         case STG_CLIENT::ACCOUNT: return "accounting";
63     }
64     return "";
65 }
66
67 enum Packet
68 {
69     PING,
70     PONG,
71     DATA
72 };
73
74 std::map<std::string, Packet> packetCodes;
75 std::map<std::string, bool> resultCodes;
76
77 class PacketParser : public EnumParser<Packet>
78 {
79     public:
80         PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
81             : EnumParser(next, packet, packetStr, packetCodes)
82         {
83             if (!packetCodes.empty())
84                 return;
85             packetCodes["ping"] = PING;
86             packetCodes["pong"] = PONG;
87             packetCodes["data"] = DATA;
88         }
89 };
90
91 class ResultParser : public EnumParser<bool>
92 {
93     public:
94         ResultParser(NodeParser* next, bool& result, std::string& resultStr)
95             : EnumParser(next, result, resultStr, resultCodes)
96         {
97             if (!resultCodes.empty())
98                 return;
99             resultCodes["no"] = false;
100             resultCodes["ok"] = true;
101         }
102 };
103
104 class TopParser : public NodeParser
105 {
106     public:
107         TopParser()
108             : m_packetParser(this, m_packet, m_packetStr),
109               m_resultParser(this, m_result, m_resultStr),
110               m_replyParser(this, m_reply),
111               m_modifyParser(this, m_modify)
112         {}
113
114         virtual NodeParser* parseStartMap() { return this; }
115         virtual NodeParser* parseMapKey(const std::string& value)
116         {
117             std::string key = ToLower(value);
118
119             if (key == "packet")
120                 return &m_packetParser;
121             else if (key == "result")
122                 return &m_resultParser;
123             else if (key == "reply")
124                 return &m_replyParser;
125             else if (key == "modify")
126                 return &m_modifyParser;
127
128             return this;
129         }
130         virtual NodeParser* parseEndMap() { return this; }
131
132         const std::string& packetStr() const { return m_packetStr; }
133         Packet packet() const { return m_packet; }
134         const std::string& resultStr() const { return m_resultStr; }
135         bool result() const { return m_result; }
136         const PairsParser::Pairs& reply() const { return m_reply; }
137         const PairsParser::Pairs& modify() const { return m_modify; }
138
139     private:
140         std::string m_packetStr;
141         Packet m_packet;
142         std::string m_resultStr;
143         bool m_result;
144         PairsParser::Pairs m_reply;
145         PairsParser::Pairs m_modify;
146
147         PacketParser m_packetParser;
148         ResultParser m_resultParser;
149         PairsParser m_replyParser;
150         PairsParser m_modifyParser;
151 };
152
153 class ProtoParser : public Parser
154 {
155     public:
156         ProtoParser() : Parser( &m_topParser ) {}
157
158         const std::string& packetStr() const { return m_topParser.packetStr(); }
159         Packet packet() const { return m_topParser.packet(); }
160         const std::string& resultStr() const { return m_topParser.resultStr(); }
161         bool result() const { return m_topParser.result(); }
162         const PairsParser::Pairs& reply() const { return m_topParser.reply(); }
163         const PairsParser::Pairs& modify() const { return m_topParser.modify(); }
164
165     private:
166         TopParser m_topParser;
167 };
168
169 class PacketGen : public Gen
170 {
171     public:
172         PacketGen(const std::string& type)
173             : m_type(type)
174         {
175             m_gen.add("packet", m_type);
176         }
177         void run(yajl_gen_t* handle) const
178         {
179             m_gen.run(handle);
180         }
181         PacketGen& add(const std::string& key, const std::string& value)
182         {
183             m_gen.add(key, new StringGen(value));
184             return *this;
185         }
186         PacketGen& add(const std::string& key, MapGen* map)
187         {
188             m_gen.add(key, map);
189             return *this;
190         }
191     private:
192         MapGen m_gen;
193         StringGen m_type;
194 };
195
196 }
197
198 class STG_CLIENT::Impl
199 {
200 public:
201     Impl(const std::string& address, Callback callback, void* data);
202     ~Impl();
203
204     bool stop();
205
206     bool request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs);
207
208 private:
209     ChannelConfig m_config;
210
211     int m_sock;
212
213     bool m_running;
214     bool m_stopped;
215
216     time_t m_lastPing;
217     time_t m_lastActivity;
218
219     pthread_t m_thread;
220     pthread_mutex_t m_mutex;
221
222     Callback m_callback;
223     void* m_data;
224
225     ProtoParser m_parser;
226
227     void m_writeHeader(TYPE type, const std::string& userName, const std::string& password);
228     void m_writePairBlock(const PAIRS& source);
229     PAIRS m_readPairBlock();
230
231     static void* run(void* );
232
233     void runImpl();
234
235     int connect();
236     int connectTCP();
237     int connectUNIX();
238
239     bool read();
240     bool tick();
241
242     bool process();
243     bool processPing();
244     bool processPong();
245     bool processData();
246     bool sendPing();
247     bool sendPong();
248
249     static bool write(void* data, const char* buf, size_t size);
250 };
251
252 ChannelConfig::ChannelConfig(std::string addr)
253 {
254     // unix:pass@/var/run/stg.sock
255     // tcp:secret@192.168.0.1:12345
256     // udp:key@isp.com.ua:54321
257
258     size_t pos = addr.find_first_of(':');
259     if (pos == std::string::npos)
260         throw Error("Missing transport name.");
261     transport = ToLower(addr.substr(0, pos));
262     addr = addr.substr(pos + 1);
263     if (addr.empty())
264         throw Error("Missing address to connect to.");
265     pos = addr.find_first_of('@');
266     if (pos != std::string::npos) {
267         key = addr.substr(0, pos);
268         addr = addr.substr(pos + 1);
269         if (addr.empty())
270             throw Error("Missing address to connect to.");
271     }
272     if (transport == "unix")
273     {
274         address = addr;
275         return;
276     }
277     pos = addr.find_first_of(':');
278     if (pos == std::string::npos)
279         throw Error("Missing port.");
280     address = addr.substr(0, pos);
281     portStr = addr.substr(pos + 1);
282     if (str2x(portStr, port))
283         throw Error("Invalid port value.");
284 }
285
286 STG_CLIENT::STG_CLIENT(const std::string& address, Callback callback, void* data)
287     : m_impl(new Impl(address, callback, data))
288 {
289 }
290
291 STG_CLIENT::~STG_CLIENT()
292 {
293 }
294
295 bool STG_CLIENT::stop()
296 {
297     return m_impl->stop();
298 }
299
300 bool STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
301 {
302     return m_impl->request(type, userName, password, pairs);
303 }
304
305 STG_CLIENT* STG_CLIENT::get()
306 {
307     return stgClient;
308 }
309
310 bool STG_CLIENT::configure(const std::string& address, Callback callback, void* data)
311 {
312     if ( stgClient != NULL && stgClient->stop() )
313         delete stgClient;
314     try {
315         stgClient = new STG_CLIENT(address, callback, data);
316         return true;
317     } catch (const ChannelConfig::Error& ex) {
318         // TODO: Log it
319     }
320     return false;
321 }
322
323 STG_CLIENT::Impl::Impl(const std::string& address, Callback callback, void* data)
324     : m_config(address),
325       m_sock(connect()),
326       m_running(false),
327       m_stopped(true),
328       m_lastPing(time(NULL)),
329       m_lastActivity(m_lastPing),
330       m_callback(callback),
331       m_data(data)
332 {
333     int res = pthread_create(&m_thread, NULL, run, this);
334     if (res != 0)
335         throw Error("Failed to create thread: " + std::string(strerror(errno)));
336 }
337
338 STG_CLIENT::Impl::~Impl()
339 {
340     stop();
341     shutdown(m_sock, SHUT_RDWR);
342     close(m_sock);
343 }
344
345 bool STG_CLIENT::Impl::stop()
346 {
347     if (m_stopped)
348         return true;
349
350     m_running = false;
351
352     for (size_t i = 0; i < 25 && !m_stopped; i++) {
353         struct timespec ts = {0, 200000000};
354         nanosleep(&ts, NULL);
355     }
356
357     if (m_stopped) {
358         pthread_join(m_thread, NULL);
359         return true;
360     }
361
362     return false;
363 }
364
365 bool STG_CLIENT::Impl::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
366 {
367     boost::scoped_ptr<MapGen> map(new MapGen);
368     for (PAIRS::const_iterator it = pairs.begin(); it != pairs.end(); ++it)
369         map->add(it->first, new StringGen(it->second));
370     map->add("Radius-Username", new StringGen(userName));
371     map->add("Radius-Userpass", new StringGen(password));
372
373     PacketGen gen("data");
374     gen.add("stage", toStage(type))
375        .add("pairs", map.get());
376
377     m_lastPing = time(NULL);
378
379     return generate(gen, &STG_CLIENT::Impl::write, this);
380 }
381
382 void STG_CLIENT::Impl::runImpl()
383 {
384     m_running = true;
385
386     while (m_running) {
387         fd_set fds;
388
389         FD_ZERO(&fds);
390         FD_SET(m_sock, &fds);
391
392         struct timeval tv;
393         tv.tv_sec = 0;
394         tv.tv_usec = 500000;
395
396         int res = select(m_sock + 1, &fds, NULL, NULL, &tv);
397         if (res < 0)
398         {
399             //m_error = std::string("'select' is failed: '") + strerror(errno) + "'.";
400             //m_logger(m_error);
401             break;
402         }
403
404         if (!m_running)
405             break;
406
407         if (res > 0)
408         {
409             if (FD_ISSET(m_sock, &fds))
410                 m_running = read();
411         }
412         else
413             m_running = tick();
414     }
415
416     m_stopped = true;
417 }
418
419 int STG_CLIENT::Impl::connect()
420 {
421     if (m_config.transport == "tcp")
422         return connectTCP();
423     else if (m_config.transport == "unix")
424         return connectUNIX();
425     throw Error("Invalid transport type: '" + m_config.transport + "'. Should be 'tcp' or 'unix'.");
426 }
427
428 int STG_CLIENT::Impl::connectTCP()
429 {
430     addrinfo hints;
431     memset(&hints, 0, sizeof(addrinfo));
432
433     hints.ai_family = AF_INET;       /* Allow IPv4 */
434     hints.ai_socktype = SOCK_STREAM; /* Stream socket */
435     hints.ai_flags = 0;     /* For wildcard IP address */
436     hints.ai_protocol = 0;           /* Any protocol */
437     hints.ai_canonname = NULL;
438     hints.ai_addr = NULL;
439     hints.ai_next = NULL;
440
441     addrinfo* ais = NULL;
442     int res = getaddrinfo(m_config.address.c_str(), m_config.portStr.c_str(), &hints, &ais);
443     if (res != 0)
444         throw Error("Error resolvin address '" + m_config.address + "': " + gai_strerror(res));
445
446     for (addrinfo* ai = ais; ai != NULL; ai = ai->ai_next)
447     {
448         int fd = socket(AF_INET, SOCK_STREAM, 0);
449         if (fd == -1)
450         {
451             Error error(std::string("Error creating TCP socket: ") + strerror(errno));
452             freeaddrinfo(ais);
453             throw error;
454         }
455         if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == -1)
456         {
457             shutdown(fd, SHUT_RDWR);
458             close(fd);
459             // TODO: log it.
460             continue;
461         }
462         freeaddrinfo(ais);
463         return fd;
464     }
465
466     freeaddrinfo(ais);
467
468     throw Error("Failed to resolve '" + m_config.address);
469 };
470
471 int STG_CLIENT::Impl::connectUNIX()
472 {
473     int fd = socket(AF_UNIX, SOCK_STREAM, 0);
474     if (fd == -1)
475         throw Error(std::string("Error creating UNIX socket: ") + strerror(errno));
476     struct sockaddr_un addr;
477     memset(&addr, 0, sizeof(addr));
478     addr.sun_family = AF_UNIX;
479     strncpy(addr.sun_path, m_config.address.c_str(), m_config.address.length());
480     if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == -1)
481     {
482         Error error(std::string("Error binding UNIX socket: ") + strerror(errno));
483         shutdown(fd, SHUT_RDWR);
484         close(fd);
485         throw error;
486     }
487     return fd;
488 }
489
490 bool STG_CLIENT::Impl::read()
491 {
492     static std::vector<char> buffer(1024);
493     ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
494     if (res < 0)
495     {
496         //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
497         return false;
498     }
499     m_lastActivity = time(NULL);
500     if (res == 0)
501     {
502         if (!m_parser.done())
503         {
504             //m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
505             return false;
506         }
507         return process();
508     }
509     return m_parser.append(buffer.data(), res);
510 }
511
512 bool STG_CLIENT::Impl::tick()
513 {
514     time_t now = time(NULL);
515     if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
516     {
517         //m_logger("Connection to " + m_remote + " timed out.");
518         return false;
519     }
520     if (difftime(now, m_lastPing) > PING_TIMEOUT)
521         sendPing();
522     return true;
523 }
524
525 bool STG_CLIENT::Impl::process()
526 {
527     switch (m_parser.packet())
528     {
529         case PING:
530             return processPing();
531         case PONG:
532             return processPong();
533         case DATA:
534             return processData();
535     }
536     //m_logger("Received invalid packet type: " + m_parser.packetStr());
537     return false;
538 }
539
540 bool STG_CLIENT::Impl::processPing()
541 {
542     return sendPong();
543 }
544
545 bool STG_CLIENT::Impl::processPong()
546 {
547     m_lastActivity = time(NULL);
548     return true;
549 }
550
551 bool STG_CLIENT::Impl::processData()
552 {
553     RESULT result;
554     for (PairsParser::Pairs::const_iterator it = m_parser.reply().begin(); it != m_parser.reply().end(); ++it)
555         result.reply.push_back(std::make_pair(it->first, it->second));
556     for (PairsParser::Pairs::const_iterator it = m_parser.modify().begin(); it != m_parser.modify().end(); ++it)
557         result.modify.push_back(std::make_pair(it->first, it->second));
558     return m_callback(m_data, result);
559 }
560
561 bool STG_CLIENT::Impl::sendPing()
562 {
563     PacketGen gen("ping");
564
565     m_lastPing = time(NULL);
566
567     return generate(gen, &STG_CLIENT::Impl::write, this);
568 }
569
570 bool STG_CLIENT::Impl::sendPong()
571 {
572     PacketGen gen("pong");
573
574     m_lastPing = time(NULL);
575
576     return generate(gen, &STG_CLIENT::Impl::write, this);
577 }
578
579 bool STG_CLIENT::Impl::write(void* data, const char* buf, size_t size)
580 {
581     STG_CLIENT::Impl& impl = *static_cast<STG_CLIENT::Impl*>(data);
582     while (size > 0)
583     {
584         ssize_t res = ::write(impl.m_sock, buf, size);
585         if (res < 0)
586         {
587             //conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));
588             return false;
589         }
590         size -= res;
591     }
592     return true;
593 }
594
595 void* STG_CLIENT::Impl::run(void* data)
596 {
597     Impl& impl = *static_cast<Impl*>(data);
598     impl.runImpl();
599     return NULL;
600 }