2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 * GNU General Public License for more details.
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18 * Author : Maxim Mamontov <faust@stargazer.dp.ua>
26 #include "stg/json_parser.h"
27 #include "stg/json_generator.h"
28 #include "stg/users.h"
30 #include "stg/logger.h"
31 #include "stg/common.h"
33 #include <yajl/yajl_gen.h>
41 #include <sys/types.h>
42 #include <sys/socket.h>
46 using STG::JSON::Parser;
47 using STG::JSON::PairsParser;
48 using STG::JSON::EnumParser;
49 using STG::JSON::NodeParser;
51 using STG::JSON::MapGen;
52 using STG::JSON::StringGen;
// Seconds without any received data after which tick() reports the
// connection as timed out.
double CONN_TIMEOUT = 60.0;
// Seconds since the last outgoing packet after which tick() sends a ping.
double PING_TIMEOUT = 10.0;
// Name -> enum lookup tables for the "packet" and "stage" fields of incoming
// messages. They are shared by all PacketParser / StageParser instances and
// are filled when those parsers are first constructed.
76 std::map<std::string, Packet> packetCodes;
77 std::map<std::string, Stage> stageCodes;
79 class PacketParser : public EnumParser<Packet>
82 PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
83 : EnumParser(next, packet, packetStr, packetCodes)
85 if (!packetCodes.empty())
87 packetCodes["ping"] = PING;
88 packetCodes["pong"] = PONG;
89 packetCodes["data"] = DATA;
93 class StageParser : public EnumParser<Stage>
96 StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
97 : EnumParser(next, stage, stageStr, stageCodes)
99 if (!stageCodes.empty())
101 stageCodes["authorize"] = AUTHORIZE;
102 stageCodes["authenticate"] = AUTHENTICATE;
103 stageCodes["preacct"] = PREACCT;
104 stageCodes["accounting"] = ACCOUNTING;
105 stageCodes["postauth"] = POSTAUTH;
// Root node parser for one incoming JSON message. Map keys are lowercased
// with ToLower() and "stage" and "pairs" (plus, presumably, "packet") are
// routed to dedicated child parsers, which fill the fields read back through
// the accessors below. The callback given to the constructor is invoked with
// its opaque data pointer whenever this node receives an end-of-map event.
109 class TopParser : public NodeParser
// Completion callback type; receives the data pointer passed to the constructor.
112 typedef void (*Callback) (void* /*data*/);
113 TopParser(Callback callback, void* data)
// NOTE(review): m_packet and m_stage do not appear in this initializer list;
// if a message omits the packet or stage field, packet()/stage() may return
// an indeterminate value — confirm against the full class declaration.
114 : m_packetParser(this, m_packet, m_packetStr),
115 m_stageParser(this, m_stage, m_stageStr),
116 m_pairsParser(this, m_data),
117 m_callback(callback), m_callbackData(data)
120 virtual NodeParser* parseStartMap() { return this; }
// Case-insensitive dispatch on the key name to the matching child parser.
121 virtual NodeParser* parseMapKey(const std::string& value)
123 std::string key = ToLower(value);
126 return &m_packetParser;
127 else if (key == "stage")
128 return &m_stageParser;
129 else if (key == "pairs")
130 return &m_pairsParser;
// End of map: hand control to the owner through the completion callback.
134 virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; }
136 const std::string& packetStr() const { return m_packetStr; }
137 Packet packet() const { return m_packet; }
138 const std::string& stageStr() const { return m_stageStr; }
139 Stage stage() const { return m_stage; }
140 const Config::Pairs& data() const { return m_data; }
// The raw strings are kept next to the decoded enums so that unknown values
// can still be reported verbatim.
143 std::string m_packetStr;
145 std::string m_stageStr;
147 Config::Pairs m_data;
// Child parsers; the constructor hands them references to the fields above.
149 PacketParser m_packetParser;
150 StageParser m_stageParser;
151 PairsParser m_pairsParser;
154 void* m_callbackData;
157 class ProtoParser : public Parser
160 ProtoParser(TopParser::Callback callback, void* data)
161 : Parser( &m_topParser ),
162 m_topParser(callback, data)
165 const std::string& packetStr() const { return m_topParser.packetStr(); }
166 Packet packet() const { return m_topParser.packet(); }
167 const std::string& stageStr() const { return m_topParser.stageStr(); }
168 Stage stage() const { return m_topParser.stage(); }
169 const Config::Pairs& data() const { return m_topParser.data(); }
172 TopParser m_topParser;
// Builder for an outgoing JSON packet. The constructor records the packet
// type under the "packet" key. The add() overloads attach string values or
// nested maps and return a PacketGen&, so calls can be chained.
// run() is the Gen hook that emits the packet; its body is not visible here,
// presumably it delegates to m_gen.
// NOTE(review): the single-argument constructor is not explicit, so a
// std::string converts implicitly to PacketGen.
175 class PacketGen : public Gen
178 PacketGen(const std::string& type)
181 m_gen.add("packet", m_type);
183 void run(yajl_gen_t* handle) const
// Adds a string value; the StringGen is heap-allocated here and handed to m_gen.
187 PacketGen& add(const std::string& key, const std::string& value)
189 m_gen.add(key, new StringGen(value));
192 PacketGen& add(const std::string& key, MapGen* map)
197 PacketGen& add(const std::string& key, MapGen& map)
207 std::string toString(Config::ReturnCode code)
211 case Config::REJECT: return "reject";
212 case Config::FAIL: return "fail";
213 case Config::OK: return "ok";
214 case Config::HANDLED: return "handled";
215 case Config::INVALID: return "invalid";
216 case Config::USERLOCK: return "userlock";
217 case Config::NOTFOUND: return "notfound";
218 case Config::NOOP: return "noop";
219 case Config::UPDATED: return "noop";
// Private implementation behind Conn: per-connection state for one link to a
// RADIUS server (socket, JSON parser, logins authorized over this link).
229 Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
232 int sock() const { return m_sock; }
237 bool isOk() const { return m_ok; }
241 PLUGIN_LOGGER& m_logger;
243 const Config& m_config;
// Peer description used in log messages.
245 std::string m_remote;
// Time of the last received data or pong; compared with CONN_TIMEOUT in tick().
248 time_t m_lastActivity;
249 ProtoParser m_parser;
// Logins authorized through the plugin over this connection.
250 std::set<std::string> m_authorized;
// Returns the member selected by `member` from the Config::Section that
// corresponds to the stage of the message currently held by m_parser
// (autz / auth / postauth / preacct / acct). Throws std::runtime_error
// naming the raw stage string if the stage is none of these.
252 template <typename T>
253 const T& stageMember(T Config::Section::* member) const
255 switch (m_parser.stage())
257 case AUTHORIZE: return m_config.autz.*member;
258 case AUTHENTICATE: return m_config.auth.*member;
259 case POSTAUTH: return m_config.postauth.*member;
260 case PREACCT: return m_config.preacct.*member;
261 case ACCOUNTING: return m_config.acct.*member;
263 throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
// Per-stage configuration accessors; each may throw like stageMember().
266 const Config::Pairs& match() const { return stageMember(&Config::Section::match); }
267 const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
268 const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
269 Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
270 const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); }
// Parser completion callback; `data` is the owning Impl.
272 static void process(void* data);
276 bool answer(const USER& user);
// Output callback passed to generate(); `data` is the owning Impl.
281 static bool write(void* data, const char* buf, size_t size);
284 Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
285 : m_impl(new Impl(users, logger, plugin, config, fd, remote))
293 int Conn::sock() const
295 return m_impl->sock();
// Forwards to the implementation's read().
300 return m_impl->read();
// Forwards to the implementation's tick().
305 return m_impl->tick();
308 bool Conn::isOk() const
310 return m_impl->isOk();
// Constructs the per-connection state. Both timestamps start at "now", and
// the parser is wired to call the static process() with this Impl as its data.
// NOTE(review): m_lastActivity is initialised from m_lastPing. That is only
// valid if m_lastPing is declared before m_lastActivity in the class, because
// members are initialised in declaration order — confirm.
313 Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
321 m_lastPing(time(NULL)),
322 m_lastActivity(m_lastPing),
323 m_parser(&Conn::Impl::process, this)
// Presumably part of the destructor (its signature is not visible here).
// Every login authorized over this connection is unauthorized again, with
// the lost RADIUS connection given as the reason.
331 std::set<std::string>::const_iterator it = m_authorized.begin();
332 for (; it != m_authorized.end(); ++it)
333 m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + ".");
// Reads up to buffer.size() (1024) bytes from the socket. A failure branch
// logs strerror(errno); the rest of the error/EOF handling is not visible
// here. On the success path shown, it records the activity time, feeds the
// bytes to the JSON parser and returns m_parser.append()'s result.
// NOTE(review): `buffer` is function-static and therefore shared by every
// connection. That is safe only if all Conn::Impl::read() calls happen on
// one thread — confirm.
336 bool Conn::Impl::read()
338 static std::vector<char> buffer(1024);
339 ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
342 m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
// NOTE(review): `res` is ssize_t but is passed for %d. Assuming printfd is
// printf-style, this is an argument mismatch on LP64 targets — cast to int
// or use %zd.
346 printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
347 m_lastActivity = time(NULL);
353 return m_parser.append(buffer.data(), res);
// Periodic check.
// - If nothing has been received for more than CONN_TIMEOUT seconds, the
//   timeout is logged (the handling that follows is not visible here).
// - If more than PING_TIMEOUT seconds have passed since m_lastPing (refreshed
//   by every outgoing packet), the debug output announces that a ping is sent.
356 bool Conn::Impl::tick()
358 time_t now = time(NULL);
359 if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
// Truncated to whole seconds for the debug message only.
361 int delta = difftime(now, m_lastActivity);
362 printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta);
363 m_logger("Connection to " + m_remote + " timed out.");
367 if (difftime(now, m_lastPing) > PING_TIMEOUT)
369 int delta = difftime(now, m_lastPing);
370 printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta);
// Static parser callback: the Impl constructor registers &Conn::Impl::process
// with `this` as data. It recovers the Impl and dispatches on the decoded
// packet type (the case bodies are not visible here). A std::exception thrown
// while processing is logged as a processing error. The trailing lines log an
// unrecognised packet type, presumably reached when the switch does not
// handle the value.
376 void Conn::Impl::process(void* data)
378 Impl& impl = *static_cast<Impl*>(data);
381 switch (impl.m_parser.packet())
394 catch (const std::exception& ex)
// NOTE(review): unlike the other printfd calls here, this format has no
// trailing "\n".
396 printfd(__FILE__, "Processing error. %s", ex.what());
// NOTE(review): relies on PLUGIN_LOGGER accepting printf-style arguments;
// elsewhere it is called with a single std::string — confirm the overload.
397 impl.m_logger("Processing error. %s", ex.what());
399 printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
400 impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
// Handles an incoming ping; per its debug message, the reply is a pong.
403 void Conn::Impl::processPing()
405 printfd(__FILE__, "Got ping. Sending pong...\n");
409 void Conn::Impl::processPong()
411 printfd(__FILE__, "Got pong.\n");
412 m_lastActivity = time(NULL);
// Handles a "data" packet by iterating over all users.
// - Each configured match() pair (incoming attribute name -> user parameter
//   name) is looked up in m_parser.data() and compared with the user's
//   parameter value.
// - How a missing attribute or a mismatch is handled, and how
//   answer()/answerNo() are chosen, is in lines not visible here.
// - When authorize().check() accepts a user, that user is authorized through
//   m_plugin and its login is remembered in m_authorized.
// NOTE(review): match()/authorize() throw std::runtime_error for an unknown
// stage. As visible, nothing guarantees CloseSearch(handle) runs in that case,
// so the search handle may leak — consider closing it on the exception path.
415 void Conn::Impl::processData()
417 printfd(__FILE__, "Got data.\n");
418 int handle = m_users.OpenSearch();
420 USER_PTR user = NULL;
421 bool matched = false;
422 while (m_users.SearchNext(handle, &user) == 0)
428 for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
430 Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
431 if (pos == m_parser.data().end())
436 if (user->GetParamValue(it->second) != pos->second)
445 if (authorize().check(*user, m_parser.data()))
447 m_plugin.authorize(*user);
448 m_authorized.insert(user->GetLogin());
456 m_users.CloseSearch(handle);
459 bool Conn::Impl::answer(const USER& user)
461 printfd(__FILE__, "Got match. Sending answer...\n");
463 for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
464 replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
467 for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
468 modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
470 PacketGen gen("data");
471 gen.add("result", "ok")
472 .add("reply", replyData)
473 .add("modify", modifyData);
475 m_lastPing = time(NULL);
477 return generate(gen, &Conn::Impl::write, this);
480 bool Conn::Impl::answerNo()
482 printfd(__FILE__, "No match. Sending answer...\n");
483 PacketGen gen("data");
484 gen.add("result", "no");
485 gen.add("return_code", toString(returnCode()));
487 m_lastPing = time(NULL);
489 return generate(gen, &Conn::Impl::write, this);
492 bool Conn::Impl::sendPing()
494 PacketGen gen("ping");
496 m_lastPing = time(NULL);
498 return generate(gen, &Conn::Impl::write, this);
501 bool Conn::Impl::sendPong()
503 PacketGen gen("pong");
505 m_lastPing = time(NULL);
507 return generate(gen, &Conn::Impl::write, this);
510 bool Conn::Impl::write(void* data, const char* buf, size_t size)
512 std::string json(buf, size);
513 printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str());
514 Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
517 ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL);
520 conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));