]> git.stg.codes - stg.git/blob - projects/rlm_stg/stg_client.cpp
199aca905df76b3c5ba703fa21b35d7f42b907e1
[stg.git] / projects / rlm_stg / stg_client.cpp
1 /*
2  *    This program is free software; you can redistribute it and/or modify
3  *    it under the terms of the GNU General Public License as published by
4  *    the Free Software Foundation; either version 2 of the License, or
5  *    (at your option) any later version.
6  *
7  *    This program is distributed in the hope that it will be useful,
8  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10  *    GNU General Public License for more details.
11  *
12  *    You should have received a copy of the GNU General Public License
13  *    along with this program; if not, write to the Free Software
14  *    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15  */
16
17 /*
18  *    Author : Maxim Mamontov <faust@stargazer.dp.ua>
19  */
20
21 #include "stg_client.h"
22
23 #include "stg/common.h"
24
25 #include <stdexcept>
26
27 namespace {
28
29 STG_CLIENT* stgClient = NULL;
30
31 unsigned fromType(STG_CLIENT::TYPE type)
32 {
33     return static_cast<unsigned>(type);
34 }
35
36 STG::SGCP::TransportType toTransport(const std::string& value)
37 {
38     std::string type = ToLower(value);
39     if (type == "unix") return STG::SGCP::UNIX;
40     else if (type == "udp") return STG::SGCP::UDP;
41     else if (type == "tcp") return STG::SGCP::TCP;
42     throw ChannelConfig::Error("Invalid transport type. Should be 'unix', 'udp' or 'tcp'.");
43 }
44
45 }
46
47 ChannelConfig::ChannelConfig(std::string addr)
48     : transport(STG::SGCP::TCP)
49 {
50     // unix:pass@/var/run/stg.sock
51     // tcp:secret@192.168.0.1:12345
52     // udp:key@isp.com.ua:54321
53
54     size_t pos = addr.find_first_of(':');
55     if (pos == std::string::npos)
56         throw Error("Missing transport name.");
57     transport = toTransport(addr.substr(0, pos));
58     addr = addr.substr(pos + 1);
59     if (addr.empty())
60         throw Error("Missing address to connect to.");
61     pos = addr.find_first_of('@');
62     if (pos != std::string::npos) {
63         key = addr.substr(0, pos);
64         addr = addr.substr(pos + 1);
65         if (addr.empty())
66             throw Error("Missing address to connect to.");
67     }
68     if (transport == STG::SGCP::UNIX)
69     {
70         address = addr;
71         return;
72     }
73     pos = addr.find_first_of(':');
74     if (pos == std::string::npos)
75         throw Error("Missing port.");
76     address = addr.substr(0, pos);
77     if (str2x(addr.substr(pos + 1), port))
78         throw Error("Invalid port value.");
79 }
80
81 STG_CLIENT::STG_CLIENT(const std::string& address)
82     : m_config(address),
83       m_proto(m_config.transport, m_config.key)
84 {
85     try {
86         m_proto.connect(m_config.address, m_config.port);
87     } catch (const STG::SGCP::Proto::Error& ex) {
88         throw Error(ex.what());
89     }
90 }
91
92 STG_CLIENT::~STG_CLIENT()
93 {
94 }
95
96 RESULT STG_CLIENT::request(TYPE type, const std::string& userName, const std::string& password, const PAIRS& pairs)
97 {
98     m_writeHeader(type, userName, password);
99     m_writePairBlock(pairs);
100     RESULT result;
101     result.modify = m_readPairBlock();
102     result.reply = m_readPairBlock();
103     return result;
104 }
105
106 STG_CLIENT* STG_CLIENT::get()
107 {
108     return stgClient;
109 }
110
111 bool STG_CLIENT::configure(const std::string& address)
112 {
113     if ( stgClient != NULL )
114         delete stgClient;
115     try {
116         stgClient = new STG_CLIENT(address);
117         return true;
118     } catch (const ChannelConfig::Error& ex) {
119         // TODO: Log it
120     }
121     return false;
122 }
123
124 void STG_CLIENT::m_writeHeader(TYPE type, const std::string& userName, const std::string& password)
125 {
126     try {
127         m_proto.writeAll<uint64_t>(fromType(type));
128         m_proto.writeAll(userName);
129         m_proto.writeAll(password);
130     } catch (const STG::SGCP::Proto::Error& ex) {
131         throw Error(ex.what());
132     }
133 }
134
135 void STG_CLIENT::m_writePairBlock(const PAIRS& pairs)
136 {
137     try {
138         m_proto.writeAll<uint64_t>(pairs.size());
139         for (size_t i = 0; i < pairs.size(); ++i) {
140             m_proto.writeAll(pairs[i].first);
141             m_proto.writeAll(pairs[i].second);
142         }
143     } catch (const STG::SGCP::Proto::Error& ex) {
144         throw Error(ex.what());
145     }
146 }
147
148 PAIRS STG_CLIENT::m_readPairBlock()
149 {
150     try {
151         size_t count = m_proto.readAll<uint64_t>();
152         if (count == 0)
153             return PAIRS();
154         PAIRS res(count);
155         for (size_t i = 0; i < count; ++i) {
156             res[i].first = m_proto.readAll<std::string>();
157             res[i].second = m_proto.readAll<std::string>();
158         }
159         return res;
160     } catch (const STG::SGCP::Proto::Error& ex) {
161         throw Error(ex.what());
162     }
163 }