]> git.stg.codes - stg.git/blobdiff - projects/stargazer/plugins/store/mysql/mysql_store.cpp
Complete replacement notifiers with subscriptions.
[stg.git] / projects / stargazer / plugins / store / mysql / mysql_store.cpp
index 53de158aa660986ab6f09e83477f3277bfa0b4bd..a4ee430930b95b9b4a9602ea3747edcb7ce12252 100644 (file)
@@ -1,19 +1,21 @@
-#include <sys/time.h>
-#include <cerrno>
-#include <cstdio>
-#include <cstdlib>
-#include <algorithm>
-
-#include <mysql.h>
-#include <errmsg.h>
+#include "mysql_store.h"
 
+#include "stg/common.h"
 #include "stg/user_ips.h"
 #include "stg/user_conf.h"
 #include "stg/user_stat.h"
+#include "stg/admin_conf.h"
+#include "stg/tariff_conf.h"
 #include "stg/blowfish.h"
-#include "stg/plugin_creator.h"
 #include "stg/logger.h"
-#include "mysql_store.h"
+
+#include <algorithm>
+#include <sys/time.h>
+#include <cerrno>
+#include <cstdio>
+#include <cstdlib>
+
+#include <mysql/errmsg.h>
 
 #define adm_enc_passwd "cjeifY8m3"
 
@@ -73,7 +75,7 @@ int GetTime(const std::string & str, time_t * val, time_t defaultVal)
 }
 
 //-----------------------------------------------------------------------------
-std::string ReplaceStr(std::string source, const std::string symlist, const char chgsym)
+std::string ReplaceStr(std::string source, const std::string symlist, const char chgsym)
 {
     std::string::size_type pos=0;
 
@@ -98,47 +100,39 @@ int GetULongLongInt(const std::string & str, uint64_t * val, uint64_t defaultVal
     return 0;
 } 
 
-PLUGIN_CREATOR<MYSQL_STORE> msc;
 }
 
-extern "C" STORE * GetStore();
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
-//-----------------------------------------------------------------------------
-STORE * GetStore()
+extern "C" STG::Store* GetStore()
 {
-return msc.GetPlugin();
+    static MYSQL_STORE plugin;
+    return &plugin;
 }
 //-----------------------------------------------------------------------------
 MYSQL_STORE_SETTINGS::MYSQL_STORE_SETTINGS()
-    : settings(NULL),
-      errorStr(),
-      dbUser(),
-      dbPass(),
-      dbName(),
-      dbHost()
+    : settings(NULL)
+    , dbPort(0)
 {
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE_SETTINGS::ParseParam(const std::vector<PARAM_VALUE> & moduleParams, 
-                        const std::string & name, std::string & result)
+int MYSQL_STORE_SETTINGS::ParseParam(const std::vector<STG::ParamValue> & moduleParams,
+                                     const std::string & name, std::string & result)
 {
-PARAM_VALUE pv;
+STG::ParamValue pv;
 pv.param = name;
-std::vector<PARAM_VALUE>::const_iterator pvi;
+std::vector<STG::ParamValue>::const_iterator pvi;
 pvi = find(moduleParams.begin(), moduleParams.end(), pv);
-if (pvi == moduleParams.end())
+if (pvi == moduleParams.end() || pvi->value.empty())
     {
     errorStr = "Parameter \'" + name + "\' not found.";
     return -1;
     }
-    
+
 result = pvi->value[0];
 
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE_SETTINGS::ParseSettings(const MODULE_SETTINGS & s)
+int MYSQL_STORE_SETTINGS::ParseSettings(const STG::ModuleSettings & s)
 {
 if (ParseParam(s.moduleParams, "user", dbUser) < 0 &&
     ParseParam(s.moduleParams, "dbuser", dbUser) < 0)
@@ -153,17 +147,27 @@ if (ParseParam(s.moduleParams, "server", dbHost) < 0 &&
     ParseParam(s.moduleParams, "dbhost", dbHost) < 0)
     return -1;
 
+// not required
+std::string dbPortAsString;
+if (ParseParam(s.moduleParams, "port", dbPortAsString) == 0 ||
+    ParseParam(s.moduleParams, "dbport", dbPortAsString) == 0)
+{
+    if (GetInt<unsigned int>(dbPortAsString, &dbPort, 0) != 0)
+    {
+        errorStr = "Can't parse db port from string: \"" + dbPortAsString + "\"\n";
+        return -1;
+    }
+}
+
 return 0;
 }
 //-----------------------------------------------------------------------------
 //-----------------------------------------------------------------------------
 //-----------------------------------------------------------------------------
 MYSQL_STORE::MYSQL_STORE()
-    : errorStr(),
-      version("mysql_store v.0.67"),
-      storeSettings(),
-      settings(),
-      logger(GetPluginLogger(GetStgLogger(), "store_mysql"))
+    : version("mysql_store v.0.68"),
+      schemaVersion(0),
+      logger(STG::PluginLogger::get("store_mysql"))
 {
 }
 //-----------------------------------------------------------------------------
@@ -191,7 +195,6 @@ int MYSQL_STORE::ParseSettings()
 {
 int ret = storeSettings.ParseSettings(settings);
 MYSQL mysql;
-MYSQL * sock;
 mysql_init(&mysql);
 if (ret)
     errorStr = storeSettings.GetStrError();
@@ -202,10 +205,10 @@ else
         errorStr = "Database password must be not empty. Please read Manual.";
         return -1;
     }
-    
+    MYSQL * sock;
     if (!(sock = mysql_real_connect(&mysql,storeSettings.GetDBHost().c_str(),
             storeSettings.GetDBUser().c_str(),storeSettings.GetDBPassword().c_str(),
-            0,0,NULL,0)))
+            0,storeSettings.GetDBPort(),NULL,0)))
         {
             errorStr = "Couldn't connect to mysql engine! With error:\n";
             errorStr += mysql_error(&mysql);
@@ -229,12 +232,13 @@ else
             {
                  if(mysql_select_db(sock, storeSettings.GetDBName().c_str()))
                  {
-                    errorStr = "Couldn't select database! With error:\n";
-                    errorStr += mysql_error(sock);
-                    mysql_close(sock);
-                    ret = -1;
+                     errorStr = "Couldn't select database! With error:\n";
+                     errorStr += mysql_error(sock);
+                     mysql_close(sock);
+                     ret = -1;
                  }
-                 ret = CheckAllTables(sock);
+                 else
+                     ret = CheckAllTables(sock);
             }
         }
         else
@@ -261,7 +265,7 @@ if (!(result=mysql_list_tables(sock,str.c_str() )))
     errorStr = "Couldn't get tables list With error:\n";
     errorStr += mysql_error(sock);
     mysql_close(sock);
-    return -1;
+    return false;
 }
 
 my_ulonglong num_rows =  mysql_num_rows(result);
@@ -376,7 +380,9 @@ if(!IsTablePresent("tariffs",sock))
     
     res += "PassiveCost DOUBLE DEFAULT 0.0, Fee DOUBLE DEFAULT 0.0,"
         "Free DOUBLE DEFAULT 0.0, TraffType VARCHAR(10) DEFAULT '',"
-        "period VARCHAR(32) NOT NULL DEFAULT 'month')";
+        "period VARCHAR(32) NOT NULL DEFAULT 'month',"
+        "change_policy VARCHAR(32) NOT NULL DEFAULT 'allow',"
+        "change_policy_timeout TIMESTAMP NOT NULL DEFAULT 0)";
     
     if(MysqlQuery(res.c_str(),sock))
     {
@@ -430,7 +436,8 @@ if(!IsTablePresent("tariffs",sock))
     
     res += "PassiveCost=0.0, Fee=10.0, Free=0,"\
         "SinglePrice0=1, SinglePrice1=1,PriceDayA1=0.75,PriceDayB1=0.75,"\
-        "PriceNightA0=1.0,PriceNightB0=1.0,TraffType='up+down',period='month'";
+        "PriceNightA0=1.0,PriceNightB0=1.0,TraffType='up+down',period='month',"\
+        "change_policy='allow', change_policy_timeout=0";
     
     if(MysqlQuery(res.c_str(),sock))
     {
@@ -449,7 +456,7 @@ if(!IsTablePresent("tariffs",sock))
         mysql_close(sock);
         return -1;
     }
-    schemaVersion = 1;
+    schemaVersion = 2;
 }
 
 //users-----------------------------------------------------------------------
@@ -500,8 +507,13 @@ if(!IsTablePresent("users",sock))
     res = "INSERT INTO users SET login='test',Address='',AlwaysOnline=0,"\
         "Credit=0.0,CreditExpire=0,Down=0,Email='',DisabledDetailStat=0,"\
         "StgGroup='',IP='192.168.1.1',Note='',Passive=0,Password='123456',"\
-        "Phone='', RealName='',Tariff='tariff',TariffChange='',Userdata0='',"\
-        "Userdata1='',";
+        "Phone='', RealName='',Tariff='tariff',TariffChange='',NAS='',";
+    
+    for (int i = 0; i < USERDATA_NUM; i++)
+        {
+        strprintf(&param, " Userdata%d='',", i);
+        res += param;
+        }
     
     for (int i = 0; i < DIR_NUM; i++)
         {
@@ -602,6 +614,27 @@ if (schemaVersion  < 1)
     schemaVersion = 1;
     logger("MYSQL_STORE: Updated DB schema to version %d", schemaVersion);
     }
+
+if (schemaVersion  < 2)
+    {
+    if (MysqlQuery("ALTER TABLE tariffs ADD change_policy VARCHAR(32) NOT NULL DEFAULT 'allow'", sock) ||
+        MysqlQuery("ALTER TABLE tariffs ADD change_policy_timeout TIMESTAMP NOT NULL DEFAULT 0", sock))
+        {
+        errorStr = "Couldn't update tariffs table to version 2. With error:\n";
+        errorStr += mysql_error(sock);
+        mysql_close(sock);
+        return -1;
+        }
+    if (MysqlQuery("UPDATE info SET version = 2", sock))
+        {
+        errorStr = "Couldn't update DB schema version to 2. With error:\n";
+        errorStr += mysql_error(sock);
+        mysql_close(sock);
+        return -1;
+        }
+    schemaVersion = 2;
+    logger("MYSQL_STORE: Updated DB schema to version %d", schemaVersion);
+    }
 return 0;
 }
 //-----------------------------------------------------------------------------
@@ -676,9 +709,12 @@ return 0;
 //-----------------------------------------------------------------------------
 int MYSQL_STORE::AddUser(const std::string & login) const
 {
-sprintf(qbuf,"INSERT INTO users SET login='%s'", login.c_str());
-    
-if(MysqlSetQuery(qbuf))
+std::string query = "INSERT INTO users SET login='" + login + "',Note='',NAS=''";
+
+for (int i = 0; i < USERDATA_NUM; i++)
+    query += ",Userdata" + std::to_string(i) + "=''";
+
+if(MysqlSetQuery(query.c_str()))
 {
     errorStr = "Couldn't add user:\n";
     //errorStr += mysql_error(sock);
@@ -702,7 +738,7 @@ if(MysqlSetQuery(qbuf))
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::RestoreUserConf(USER_CONF * conf, const std::string & login) const
+int MYSQL_STORE::RestoreUserConf(STG::UserConf * conf, const std::string & login) const
 {
 MYSQL_RES *res;
 MYSQL_ROW row;
@@ -749,8 +785,6 @@ if (mysql_num_rows(res) != 1)
 
 row = mysql_fetch_row(res);
 
-std::string param;
-
 conf->password = row[1];
 
 if (conf->password.empty())
@@ -828,10 +862,10 @@ for (int i = 0; i < USERDATA_NUM; i++)
 GetTime(row[15+USERDATA_NUM], &conf->creditExpire, 0);
     
 std::string ipStr = row[16+USERDATA_NUM];
-USER_IPS i;
+STG::UserIPs i;
 try
     {
-    i = StrToIPS(ipStr);
+    i = STG::UserIPs::parse(ipStr);
     }
 catch (const std::string & s)
     {
@@ -848,7 +882,7 @@ mysql_close(sock);
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::RestoreUserStat(USER_STAT * stat, const std::string & login) const
+int MYSQL_STORE::RestoreUserStat(STG::UserStat * stat, const std::string & login) const
 {
 MYSQL_RES *res;
 MYSQL_ROW row;
@@ -971,7 +1005,7 @@ mysql_close(sock);
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::SaveUserConf(const USER_CONF & conf, const std::string & login) const
+int MYSQL_STORE::SaveUserConf(const STG::UserConf & conf, const std::string & login) const
 {
 std::string param;
 std::string res;
@@ -1024,7 +1058,7 @@ if(MysqlSetQuery(res.c_str()))
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::SaveUserStat(const USER_STAT & stat, const std::string & login) const
+int MYSQL_STORE::SaveUserStat(const STG::UserStat & stat, const std::string & login) const
 {
 std::string param;
 std::string res;
@@ -1148,10 +1182,10 @@ return WriteLogString(logStr, login);
 }
 //-----------------------------------------------------------------------------
 int MYSQL_STORE::WriteUserDisconnect(const std::string & login,
-                                     const DIR_TRAFF & up,
-                                     const DIR_TRAFF & down,
-                                     const DIR_TRAFF & sessionUp,
-                                     const DIR_TRAFF & sessionDown,
+                                     const STG::DirTraff & up,
+                                     const STG::DirTraff & down,
+                                     const STG::DirTraff & sessionUp,
+                                     const STG::DirTraff & sessionDown,
                                      double cash,
                                      double /*freeMb*/,
                                      const std::string & /*reason*/) const
@@ -1186,7 +1220,7 @@ logStr += "\'";
 return WriteLogString(logStr, login);
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::SaveMonthStat(const USER_STAT & stat, int month, int year, 
+int MYSQL_STORE::SaveMonthStat(const STG::UserStat & stat, int month, int year, 
                                 const std::string & login) const
 {
 std::string param, res;
@@ -1244,7 +1278,7 @@ if(MysqlSetQuery(qbuf))
 return 0;
 }
 //-----------------------------------------------------------------------------*/
-int MYSQL_STORE::SaveAdmin(const ADMIN_CONF & ac) const
+int MYSQL_STORE::SaveAdmin(const STG::AdminConf & ac) const
 {
 char passwordE[2 * ADM_PASSWD_LEN + 2];
 char pass[ADM_PASSWD_LEN + 1];
@@ -1254,14 +1288,14 @@ memset(pass, 0, sizeof(pass));
 memset(adminPass, 0, sizeof(adminPass));
 
 BLOWFISH_CTX ctx;
-EnDecodeInit(adm_enc_passwd, strlen(adm_enc_passwd), &ctx);
+InitContext(adm_enc_passwd, strlen(adm_enc_passwd), &ctx);
 
 strncpy(adminPass, ac.password.c_str(), ADM_PASSWD_LEN);
 adminPass[ADM_PASSWD_LEN - 1] = 0;
 
 for (int i = 0; i < ADM_PASSWD_LEN/8; i++)
     {
-    EncodeString(pass + 8*i, adminPass + 8*i, &ctx);
+    EncryptBlock(pass + 8*i, adminPass + 8*i, &ctx);
     }
 
 pass[ADM_PASSWD_LEN - 1] = 0;
@@ -1291,16 +1325,14 @@ if(MysqlSetQuery(qbuf))
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::RestoreAdmin(ADMIN_CONF * ac, const std::string & login) const
+int MYSQL_STORE::RestoreAdmin(STG::AdminConf * ac, const std::string & login) const
 {
 char pass[ADM_PASSWD_LEN + 1];
 char password[ADM_PASSWD_LEN + 1];
 char passwordE[2*ADM_PASSWD_LEN + 2];
 BLOWFISH_CTX ctx;
 
-memset(pass, 0, sizeof(pass));
 memset(password, 0, sizeof(password));
-memset(passwordE, 0, sizeof(passwordE));
 
 std::string p;
 MYSQL_RES *res;
@@ -1352,11 +1384,11 @@ memset(pass, 0, sizeof(pass));
 if (passwordE[0] != 0)
     {
     Decode21(pass, passwordE);
-    EnDecodeInit(adm_enc_passwd, strlen(adm_enc_passwd), &ctx);
+    InitContext(adm_enc_passwd, strlen(adm_enc_passwd), &ctx);
 
     for (int i = 0; i < ADM_PASSWD_LEN/8; i++)
         {
-        DecodeString(password + 8*i, pass + 8*i, &ctx);
+        DecryptBlock(password + 8*i, pass + 8*i, &ctx);
         }
     }
 else
@@ -1471,7 +1503,7 @@ if(MysqlSetQuery(qbuf))
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::RestoreTariff(TARIFF_DATA * td, const std::string & tariffName) const
+int MYSQL_STORE::RestoreTariff(STG::TariffData * td, const std::string & tariffName) const
 {
 MYSQL_RES *res;
 MYSQL_ROW row;
@@ -1621,24 +1653,7 @@ if (GetDouble(row[1+8*DIR_NUM], &td->tariffConf.passiveCost, 0.0) < 0)
         return -1;
         }
 
-if (!strcasecmp(str.c_str(), "up"))
-    td->tariffConf.traffType = TRAFF_UP;
-else
-    if (!strcasecmp(str.c_str(), "down"))
-        td->tariffConf.traffType = TRAFF_DOWN;
-    else
-        if (!strcasecmp(str.c_str(), "up+down"))
-            td->tariffConf.traffType = TRAFF_UP_DOWN;
-        else
-            if (!strcasecmp(str.c_str(), "max"))
-                td->tariffConf.traffType = TRAFF_MAX;
-            else
-                {
-                mysql_free_result(res);
-                errorStr = "Cannot read tariff " + tariffName + ". Parameter TraffType incorrect";
-                mysql_close(sock);
-                return -1;
-                }
+td->tariffConf.traffType = STG::Tariff::parseTraffType(str);
 
 if (schemaVersion > 0)
 {
@@ -1653,11 +1668,45 @@ if (schemaVersion > 0)
         return -1;
         }
 
-    td->tariffConf.period = TARIFF::StringToPeriod(str);
+    td->tariffConf.period = STG::Tariff::parsePeriod(str);
+    }
+else
+    {
+    td->tariffConf.period = STG::Tariff::MONTH;
+    }
+
+if (schemaVersion > 1)
+    {
+    str = row[6+8*DIR_NUM];
+    param = "ChangePolicy";
+
+    if (str.length() == 0)
+        {
+        mysql_free_result(res);
+        errorStr = "Cannot read tariff " + tariffName + ". Parameter " + param;
+        mysql_close(sock);
+        return -1;
+        }
+
+    td->tariffConf.changePolicy = STG::Tariff::parseChangePolicy(str);
+
+    str = row[7+8*DIR_NUM];
+    param = "ChangePolicyTimeout";
+
+    if (str.length() == 0)
+        {
+        mysql_free_result(res);
+        errorStr = "Cannot read tariff " + tariffName + ". Parameter " + param;
+        mysql_close(sock);
+        return -1;
+        }
+
+    td->tariffConf.changePolicyTimeout = readTime(str);
     }
 else
     {
-    td->tariffConf.period = TARIFF::MONTH;
+    td->tariffConf.changePolicy = STG::Tariff::ALLOW;
+    td->tariffConf.changePolicyTimeout = 0;
     }
 
 mysql_free_result(res);
@@ -1665,7 +1714,7 @@ mysql_close(sock);
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::SaveTariff(const TARIFF_DATA & td, const std::string & tariffName) const
+int MYSQL_STORE::SaveTariff(const STG::TariffData & td, const std::string & tariffName) const
 {
 std::string param;
 
@@ -1722,24 +1771,14 @@ res += param;
 strprintf(&param, " Free=%f,", td.tariffConf.free);
 res += param;
 
-switch (td.tariffConf.traffType)
-    {
-    case TRAFF_UP:
-        res += " TraffType='up'";
-        break;
-    case TRAFF_DOWN:
-        res += " TraffType='down'";
-        break;
-    case TRAFF_UP_DOWN:
-        res += " TraffType='up+down'";
-        break;
-    case TRAFF_MAX:
-        res += " TraffType='max'";
-        break;
-    }
+res += " TraffType='" + STG::Tariff::toString(td.tariffConf.traffType) + "'";
 
 if (schemaVersion > 0)
-    res += ", Period='" + TARIFF::PeriodToString(td.tariffConf.period) + "'";
+    res += ", Period='" + STG::Tariff::toString(td.tariffConf.period) + "'";
+
+if (schemaVersion > 1)
+    res += ", change_policy='" + STG::Tariff::toString(td.tariffConf.changePolicy) + "'"\
+           ", change_policy_timeout='" + formatTime(td.tariffConf.changePolicy) + "'";
 
 strprintf(&param, " WHERE name='%s' LIMIT 1", tariffName.c_str());
 res += param;
@@ -1754,7 +1793,7 @@ if(MysqlSetQuery(res.c_str()))
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::WriteDetailedStat(const std::map<IP_DIR_PAIR, STAT_NODE> & statTree, 
+int MYSQL_STORE::WriteDetailedStat(const STG::TraffStat & statTree, 
                                    time_t lastStat, 
                                    const std::string & login) const
 {
@@ -1839,7 +1878,7 @@ strprintf(&res,"INSERT INTO detailstat_%02d_%4d SET login='%s',"\
     endTime.c_str()
     );
 
-std::map<IP_DIR_PAIR, STAT_NODE>::const_iterator stIter;
+STG::TraffStat::const_iterator stIter;
 stIter = statTree.begin();
 
 while (stIter != statTree.end())
@@ -1870,7 +1909,7 @@ mysql_close(sock);
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::AddMessage(STG_MSG * msg, const std::string & login) const
+int MYSQL_STORE::AddMessage(STG::Message * msg, const std::string & login) const
 {
 struct timeval tv;
 
@@ -1893,7 +1932,7 @@ if(MysqlSetQuery(qbuf))
 return EditMessage(*msg, login);
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::EditMessage(const STG_MSG & msg, const std::string & login) const
+int MYSQL_STORE::EditMessage(const STG::Message & msg, const std::string & login) const
 {
 std::string res;
 
@@ -1921,7 +1960,7 @@ if(MysqlSetQuery(res.c_str()))
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::GetMessage(uint64_t id, STG_MSG * msg, const std::string & login) const
+int MYSQL_STORE::GetMessage(uint64_t id, STG::Message * msg, const std::string & login) const
 {
 MYSQL_RES *res;
 MYSQL_ROW row;
@@ -2019,7 +2058,7 @@ if(MysqlSetQuery(qbuf))
 return 0;
 }
 //-----------------------------------------------------------------------------
-int MYSQL_STORE::GetMessageHdrs(std::vector<STG_MSG_HDR> * hdrsList, const std::string & login) const
+int MYSQL_STORE::GetMessageHdrs(std::vector<STG::Message::Header> * hdrsList, const std::string & login) const
 {
 MYSQL_RES *res;
 MYSQL_ROW row;
@@ -2052,7 +2091,7 @@ for (i = 0; i < num_rows; i++)
     if (str2x(row[1], id))
         continue;
     
-    STG_MSG_HDR hdr;
+    STG::Message::Header hdr;
     if (row[2]) 
         if(str2x(row[2], hdr.type))
             continue;
@@ -2110,7 +2149,7 @@ MYSQL *  MYSQL_STORE::MysqlConnect() const {
     }
     if (!(sock = mysql_real_connect(sock,storeSettings.GetDBHost().c_str(),
             storeSettings.GetDBUser().c_str(),storeSettings.GetDBPassword().c_str(),
-            0,0,NULL,0)))
+            0,storeSettings.GetDBPort(),NULL,0)))
         {
             errorStr = "Couldn't connect to mysql engine! With error:\n";
             errorStr += mysql_error(sock);