2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 * GNU General Public License for more details.
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18 * Author : Maxim Mamontov <faust@stargazer.dp.ua>
26 #include "stg/json_parser.h"
27 #include "stg/json_generator.h"
28 #include "stg/users.h"
30 #include "stg/logger.h"
31 #include "stg/common.h"
33 #include <yajl/yajl_gen.h>
41 #include <sys/types.h>
42 #include <sys/socket.h>
// MSG_NOSIGNAL stops send() from raising SIGPIPE when the peer has gone away.
// On OSX this flag does not exist, so only there fall back to "no flags";
// everywhere else the platform definition must be kept.
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif
51 using STG::JSON::Parser;
52 using STG::JSON::PairsParser;
53 using STG::JSON::EnumParser;
54 using STG::JSON::NodeParser;
56 using STG::JSON::MapGen;
57 using STG::JSON::StringGen;
// Seconds without inbound activity (data read or pong received) after which
// tick() reports the connection as timed out.
double CONN_TIMEOUT = 60.0;
// Seconds since the last packet we sent after which tick() wants to ping.
double PING_TIMEOUT = 10.0;
// String -> Packet lookup table shared by every PacketParser instance; the
// PacketParser constructor fills it with the known packet keywords.
81 std::map<std::string, Packet> packetCodes;
// String -> Stage lookup table shared by every StageParser instance; the
// StageParser constructor fills it with the known stage keywords.
82 std::map<std::string, Stage> stageCodes;
84 class PacketParser : public EnumParser<Packet>
87 PacketParser(NodeParser* next, Packet& packet, std::string& packetStr)
88 : EnumParser(next, packet, packetStr, packetCodes)
90 if (!packetCodes.empty())
92 packetCodes["ping"] = PING;
93 packetCodes["pong"] = PONG;
94 packetCodes["data"] = DATA;
98 class StageParser : public EnumParser<Stage>
101 StageParser(NodeParser* next, Stage& stage, std::string& stageStr)
102 : EnumParser(next, stage, stageStr, stageCodes)
104 if (!stageCodes.empty())
106 stageCodes["authorize"] = AUTHORIZE;
107 stageCodes["authenticate"] = AUTHENTICATE;
108 stageCodes["preacct"] = PREACCT;
109 stageCodes["accounting"] = ACCOUNTING;
110 stageCodes["postauth"] = POSTAUTH;
// Root node of the JSON protocol parser. Top-level map keys are lower-cased
// (ToLower) and dispatched to the packet / stage / pairs sub-parsers; the
// parsed values are kept in the members below and exposed via accessors.
114 class TopParser : public NodeParser
// Callback invoked from parseEndMap() with the opaque pointer passed to the constructor.
117 typedef void (*Callback) (void* /*data*/);
118 TopParser(Callback callback, void* data)
// Each sub-parser gets `this` as its "next" node and writes its result into a
// member of this object.
119 : m_packetParser(this, m_packet, m_packetStr),
120 m_stageParser(this, m_stage, m_stageStr),
121 m_pairsParser(this, m_data),
122 m_callback(callback), m_callbackData(data)
125 virtual NodeParser* parseStartMap() { return this; }
// Chooses the sub-parser for the value that follows the given key; keys are
// compared case-insensitively.
126 virtual NodeParser* parseMapKey(const std::string& value)
128 std::string key = ToLower(value);
131 return &m_packetParser;
132 else if (key == "stage")
133 return &m_stageParser;
134 else if (key == "pairs")
135 return &m_pairsParser;
// Presumably reached when the top-level map closes: hands the collected message
// to the callback.
// NOTE(review): m_packet/m_stage/m_data are not visibly reset between
// messages; confirm a message lacking a key cannot inherit the previous value.
139 virtual NodeParser* parseEndMap() { m_callback(m_callbackData); return this; }
// Accessors for the values collected from the current message.
141 const std::string& packetStr() const { return m_packetStr; }
142 Packet packet() const { return m_packet; }
143 const std::string& stageStr() const { return m_stageStr; }
144 Stage stage() const { return m_stage; }
145 const Config::Pairs& data() const { return m_data; }
// Raw strings as received (useful for error messages when they don't map to an enum).
148 std::string m_packetStr;
150 std::string m_stageStr;
// Key/value pairs from the "pairs" entry.
152 Config::Pairs m_data;
154 PacketParser m_packetParser;
155 StageParser m_stageParser;
156 PairsParser m_pairsParser;
159 void* m_callbackData;
// Parser facade for the connection protocol: owns the TopParser and exposes
// the fields it collected for the current message.
162 class ProtoParser : public Parser
165 ProtoParser(TopParser::Callback callback, void* data)
// NOTE(review): the Parser base is constructed before the m_topParser member,
// so it receives the address of a not-yet-constructed object. This is only
// safe if Parser's constructor merely stores the pointer — confirm.
166 : Parser( &m_topParser ),
167 m_topParser(callback, data)
// Read-only accessors forwarding to the TopParser.
170 const std::string& packetStr() const { return m_topParser.packetStr(); }
171 Packet packet() const { return m_topParser.packet(); }
172 const std::string& stageStr() const { return m_topParser.stageStr(); }
173 Stage stage() const { return m_topParser.stage(); }
174 const Config::Pairs& data() const { return m_topParser.data(); }
177 TopParser m_topParser;
// JSON generator for one outgoing protocol packet: a map whose "packet" entry
// carries the packet type, plus any further entries added through add().
180 class PacketGen : public Gen
// NOTE(review): single-argument constructor is not explicit, so a std::string
// converts implicitly to PacketGen.
183 PacketGen(const std::string& type)
186 m_gen.add("packet", m_type);
// Presumably serializes the whole map through the given yajl generator handle.
188 void run(yajl_gen_t* handle) const
// Adds a string-valued entry. The StringGen is heap-allocated here and handed
// to m_gen — presumably m_gen takes ownership; confirm in MapGen.
192 PacketGen& add(const std::string& key, const std::string& value)
194 m_gen.add(key, new StringGen(value));
// Overloads adding a nested map entry (by pointer / by reference). All add()
// overloads return PacketGen& so calls can be chained.
197 PacketGen& add(const std::string& key, MapGen* map)
202 PacketGen& add(const std::string& key, MapGen& map)
212 std::string toString(Config::ReturnCode code)
216 case Config::REJECT: return "reject";
217 case Config::FAIL: return "fail";
218 case Config::OK: return "ok";
219 case Config::HANDLED: return "handled";
220 case Config::INVALID: return "invalid";
221 case Config::USERLOCK: return "userlock";
222 case Config::NOTFOUND: return "notfound";
223 case Config::NOOP: return "noop";
224 case Config::UPDATED: return "noop";
// Private implementation of Conn: per-connection state (socket, protocol
// parser, logins authorized through this connection) and packet handling.
234 Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote);
237 int sock() const { return m_sock; }
242 bool isOk() const { return m_ok; }
246 PLUGIN_LOGGER& m_logger;
248 const Config& m_config;
// Peer description used in log messages.
250 std::string m_remote;
// Time of the last inbound activity (data read or pong received); checked
// against CONN_TIMEOUT in tick().
253 time_t m_lastActivity;
// Streaming JSON parser; it calls process(this) from its end-of-message callback.
254 ProtoParser m_parser;
// Logins authorized via m_plugin on behalf of this connection, so they can be
// unauthorized again when the connection to the RADIUS server is lost.
255 std::set<std::string> m_authorized;
// Returns the requested member of the config section that matches the stage
// of the current message. Throws std::runtime_error for any other stage.
257 template <typename T>
258 const T& stageMember(T Config::Section::* member) const
260 switch (m_parser.stage())
262 case AUTHORIZE: return m_config.autz.*member;
263 case AUTHENTICATE: return m_config.auth.*member;
264 case POSTAUTH: return m_config.postauth.*member;
265 case PREACCT: return m_config.preacct.*member;
266 case ACCOUNTING: return m_config.acct.*member;
268 throw std::runtime_error("Invalid stage: '" + m_parser.stageStr() + "'.");
// Stage-specific configuration accessors (each may throw, see stageMember).
271 const Config::Pairs& match() const { return stageMember(&Config::Section::match); }
272 const Config::Pairs& modify() const { return stageMember(&Config::Section::modify); }
273 const Config::Pairs& reply() const { return stageMember(&Config::Section::reply); }
274 Config::ReturnCode returnCode() const { return stageMember(&Config::Section::returnCode); }
275 const Config::Authorize& authorize() const { return stageMember(&Config::Section::authorize); }
// Parser callback; `data` is the Impl* registered with m_parser.
277 static void process(void* data);
281 bool answer(const USER& user);
// Output sink passed to generate(); `data` is the Impl* whose socket is written.
286 static bool write(void* data, const char* buf, size_t size);
289 Conn::Conn(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
290 : m_impl(new Impl(users, logger, plugin, config, fd, remote))
298 int Conn::sock() const
300 return m_impl->sock();
// Conn::read(): forwards to the implementation object's read().
305 return m_impl->read();
// Conn::tick(): forwards to the implementation object's periodic tick().
310 return m_impl->tick();
313 bool Conn::isOk() const
315 return m_impl->isOk();
// Initializes per-connection state. Both timestamps start at "now", so the
// ping and inactivity timers both start counting at construction time.
318 Conn::Impl::Impl(USERS& users, PLUGIN_LOGGER& logger, RADIUS& plugin, const Config& config, int fd, const std::string& remote)
326 m_lastPing(time(NULL)),
// NOTE(review): reads m_lastPing, so m_lastPing must be declared before
// m_lastActivity in the class (members initialize in declaration order).
327 m_lastActivity(m_lastPing),
// The parser calls Conn::Impl::process with this object as its data pointer.
328 m_parser(&Conn::Impl::process, this)
// Cleanup: every login this connection authorized is unauthorized again, with
// a reason naming the remote RADIUS server. (Enclosing function header is not
// present in this listing; presumably the destructor.)
336 std::set<std::string>::const_iterator it = m_authorized.begin();
337 for (; it != m_authorized.end(); ++it)
338 m_plugin.unauthorize(*it, "Lost connection to RADIUS server " + m_remote + ".");
// Reads up to buffer.size() bytes from the socket and feeds them to the
// streaming JSON parser; returns the parser's append() result.
341 bool Conn::Impl::read()
// NOTE(review): function-local static, i.e. one buffer shared by all
// connections. Unsafe if read() can run concurrently for different
// connections; confirm single-threaded use or make it a member.
343 static std::vector<char> buffer(1024);
344 ssize_t res = ::read(m_sock, buffer.data(), buffer.size());
347 m_logger("Failed to read data from '" + m_remote + "': " + strerror(errno));
// NOTE(review): `res` is ssize_t but formatted with %d; use %zd or cast to int.
351 printfd(__FILE__, "Read %d bytes.\n%s\n", res, std::string(buffer.data(), res).c_str());
// Received data counts as activity for the CONN_TIMEOUT check in tick().
352 m_lastActivity = time(NULL);
358 return m_parser.append(buffer.data(), res);
// Periodic housekeeping: reports the connection as timed out after
// CONN_TIMEOUT seconds of inbound silence, and (per the debug message) sends
// a ping when nothing has been sent for PING_TIMEOUT seconds.
361 bool Conn::Impl::tick()
363 time_t now = time(NULL);
// Inactivity check against the last inbound data / pong.
364 if (difftime(now, m_lastActivity) > CONN_TIMEOUT)
366 int delta = difftime(now, m_lastActivity);
367 printfd(__FILE__, "Connection to '%s' timed out: %d sec.\n", m_remote.c_str(), delta);
368 m_logger("Connection to " + m_remote + " timed out.");
// Keep-alive check against the last packet we sent.
372 if (difftime(now, m_lastPing) > PING_TIMEOUT)
374 int delta = difftime(now, m_lastPing);
375 printfd(__FILE__, "Ping timeout: %d sec. Sending ping...\n", delta);
// Parser callback: `data` is the owning Impl. Dispatches the completed
// message by packet type; exceptions thrown while handling it are logged.
381 void Conn::Impl::process(void* data)
383 Impl& impl = *static_cast<Impl*>(data);
386 switch (impl.m_parser.packet())
399 catch (const std::exception& ex)
// NOTE(review): this printfd lacks a trailing "\n", unlike the others.
401 printfd(__FILE__, "Processing error. %s", ex.what());
402 impl.m_logger("Processing error. %s", ex.what());
// Reached for unknown packet types.
// NOTE(review): control appears to fall through here after the catch block
// too, so a processing error is additionally logged as "invalid packet type";
// consider returning from the handler.
404 printfd(__FILE__, "Received invalid packet type: '%s'.\n", impl.m_parser.packetStr().c_str());
405 impl.m_logger("Received invalid packet type: " + impl.m_parser.packetStr());
// Handles an incoming ping; per the debug message, replies with a pong.
408 void Conn::Impl::processPing()
410 printfd(__FILE__, "Got ping. Sending pong...\n");
414 void Conn::Impl::processPong()
416 printfd(__FILE__, "Got pong.\n");
417 m_lastActivity = time(NULL);
// Handles a "data" packet: scans all users for one whose parameters match the
// current stage's `match` pairs against the received pairs.
420 void Conn::Impl::processData()
422 printfd(__FILE__, "Got data.\n");
423 int handle = m_users.OpenSearch();
425 USER_PTR user = NULL;
426 bool matched = false;
427 while (m_users.SearchNext(handle, &user) == 0)
// Each match pair is (received attribute name -> user parameter name); the
// user is rejected when the attribute is missing or its value differs.
// NOTE(review): match() may throw std::runtime_error (unknown stage); if it
// does, CloseSearch below is skipped and the search handle leaks.
433 for (Config::Pairs::const_iterator it = match().begin(); it != match().end(); ++it)
435 Config::Pairs::const_iterator pos = m_parser.data().find(it->first);
436 if (pos == m_parser.data().end())
441 if (user->GetParamValue(it->second) != pos->second)
// If the stage's authorize rule accepts the user, authorize it via the plugin
// and remember the login so it can be unauthorized when the connection is lost.
450 if (authorize().check(*user, m_parser.data()))
452 m_plugin.authorize(*user);
453 m_authorized.insert(user->GetLogin());
461 m_users.CloseSearch(handle);
// Sends a positive "data" reply for `user`. The stage's `reply` and `modify`
// pairs map an outgoing attribute name to a user parameter name; the values
// sent are that user's parameter values.
464 bool Conn::Impl::answer(const USER& user)
466 printfd(__FILE__, "Got match. Sending answer...\n");
468 for (Config::Pairs::const_iterator it = reply().begin(); it != reply().end(); ++it)
469 replyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
472 for (Config::Pairs::const_iterator it = modify().begin(); it != modify().end(); ++it)
473 modifyData.add(it->first, new StringGen(user.GetParamValue(it->second)));
475 PacketGen gen("data");
476 gen.add("result", "ok")
477 .add("reply", replyData)
478 .add("modify", modifyData);
// Any outgoing packet resets the PING_TIMEOUT timer.
480 m_lastPing = time(NULL);
482 return generate(gen, &Conn::Impl::write, this);
485 bool Conn::Impl::answerNo()
487 printfd(__FILE__, "No match. Sending answer...\n");
488 PacketGen gen("data");
489 gen.add("result", "no");
490 gen.add("return_code", toString(returnCode()));
492 m_lastPing = time(NULL);
494 return generate(gen, &Conn::Impl::write, this);
497 bool Conn::Impl::sendPing()
499 PacketGen gen("ping");
501 m_lastPing = time(NULL);
503 return generate(gen, &Conn::Impl::write, this);
506 bool Conn::Impl::sendPong()
508 PacketGen gen("pong");
510 m_lastPing = time(NULL);
512 return generate(gen, &Conn::Impl::write, this);
515 bool Conn::Impl::write(void* data, const char* buf, size_t size)
517 std::string json(buf, size);
518 printfd(__FILE__, "Writing JSON:\n%s\n", json.c_str());
519 Conn::Impl& conn = *static_cast<Conn::Impl*>(data);
522 ssize_t res = ::send(conn.m_sock, buf, size, MSG_NOSIGNAL);
525 conn.m_logger("Failed to write pong to '" + conn.m_remote + "': " + strerror(errno));