]> git.stg.codes - stg.git/blob - projects/rlm_stg/stg_client.cpp
6987976b7d44bf011ec4bb8d406fd538601dcb26
[stg.git] / projects / rlm_stg / stg_client.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "stg_client.h"
22
23 #include "stg/common.h"
24
25 #include <boost/bind.hpp>
26
27 #include <stdexcept>
28
29 namespace {
30
31 STG_CLIENT* stgClient = NULL;
32
33 unsigned fromType(STG_CLIENT::TYPE type)
34 {
35     return static_cast<unsigned>(type);
36 }
37
38 STG::SGCP::TransportType toTransport(const std::string& value)
39 {
40     std::string type = ToLower(value);
41     if (type == "unix") return STG::SGCP::UNIX;
42     else if (type == "udp") return STG::SGCP::UDP;
43     else if (type == "tcp") return STG::SGCP::TCP;
44     throw ChannelConfig::Error("Invalid transport type. Should be 'unix', 'udp' or 'tcp'.");
45 }
46
47 }
48
49 ChannelConfig::ChannelConfig(std::string addr)
50     : transport(STG::SGCP::TCP)
51 {
52     // unix:pass@/var/run/stg.sock
53     // tcp:secret@192.168.0.1:12345
54     // udp:key@isp.com.ua:54321
55
56     size_t pos = addr.find_first_of(':');
57     if (pos == std::string::npos)
58         throw Error("Missing transport name.");
59     transport = toTransport(addr.substr(0, pos));
60     addr = addr.substr(pos + 1);
61     if (addr.empty())
62         throw Error("Missing address to connect to.");
63     pos = addr.find_first_of('@');
64     if (pos != std::string::npos) {
65         key = addr.substr(0, pos);
66         addr = addr.substr(pos + 1);
67         if (addr.empty())
68             throw Error("Missing address to connect to.");
69     }
70     if (transport == STG::SGCP::UNIX)
71     {
72         address = addr;
73         return;
74     }
75     pos = addr.find_first_of(':');
76     if (pos == std::string::npos)
77         throw Error("Missing port.");
78     address = addr.substr(0, pos);
79     if (str2x(addr.substr(pos + 1), port))
80         throw Error("Invalid port value.");
81 }
82
83 STG_CLIENT::STG_CLIENT(const std::string& address)
84     : m_config(address),
85       m_proto(m_config.transport, m_config.key),
86       m_thread(boost::bind(&STG_CLIENT::m_run, this))
87 {
88 }
89
90 STG_CLIENT::~STG_CLIENT()
91 {
92     stop();
93 }
94
95 bool STG_CLIENT::stop()
96 {
97     return m_proto.stop();
98 }
99
100 RESULT STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
101 {
102     m_writeHeader(type, userName, password);
103     m_writePairBlock(pairs);
104     RESULT result;
105     result.modify = m_readPairBlock();
106     result.reply = m_readPairBlock();
107     return result;
108 }
109
110 STG_CLIENT* STG_CLIENT::get()
111 {
112     return stgClient;
113 }
114
115 bool STG_CLIENT::configure(const std::string& address)
116 {
117     if ( stgClient != NULL && stgClient->stop() )
118         delete stgClient;
119     try {
120         stgClient = new STG_CLIENT(address);
121         return true;
122     } catch (const ChannelConfig::Error& ex) {
123         // TODO: Log it
124     }
125     return false;
126 }
127
128 void STG_CLIENT::m_writeHeader(TYPE type, const std::string& userName, const std::string& password)
129 {
130     try {
131         m_proto.writeAll<uint64_t>(fromType(type));
132         m_proto.writeAll(userName);
133         m_proto.writeAll(password);
134     } catch (const STG::SGCP::Proto::Error& ex) {
135         throw Error(ex.what());
136     }
137 }
138
139 void STG_CLIENT::m_writePairBlock(const PAIRS& pairs)
140 {
141     try {
142         m_proto.writeAll<uint64_t>(pairs.size());
143         for (size_t i = 0; i < pairs.size(); ++i) {
144             m_proto.writeAll(pairs[i].first);
145             m_proto.writeAll(pairs[i].second);
146         }
147     } catch (const STG::SGCP::Proto::Error& ex) {
148         throw Error(ex.what());
149     }
150 }
151
152 PAIRS STG_CLIENT::m_readPairBlock()
153 {
154     try {
155         size_t count = m_proto.readAll<uint64_t>();
156         if (count == 0)
157             return PAIRS();
158         PAIRS res(count);
159         for (size_t i = 0; i < count; ++i) {
160             res[i].first = m_proto.readAll<std::string>();
161             res[i].second = m_proto.readAll<std::string>();
162         }
163         return res;
164     } catch (const STG::SGCP::Proto::Error& ex) {
165         throw Error(ex.what());
166     }
167 }
168
169 void STG_CLIENT::m_run()
170 {
171     m_proto.connect(m_config.address, m_config.port);
172     m_proto.run();
173 }