X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/72229403aae25f742c07d07d625bdc1e313b401d..6a79784ca465afb804fd43a233f6a02e4ca894d9:/stglibs/srvconf.lib/netunit.cpp?ds=inline diff --git a/stglibs/srvconf.lib/netunit.cpp b/stglibs/srvconf.lib/netunit.cpp index 1a21295b..baf02004 100644 --- a/stglibs/srvconf.lib/netunit.cpp +++ b/stglibs/srvconf.lib/netunit.cpp @@ -23,12 +23,14 @@ #include "stg/servconf_types.h" #include "stg/common.h" #include "stg/blowfish.h" +#include "stg/bfstream.h" #include // std::min #include #include #include +#include #include #include @@ -45,6 +47,14 @@ namespace const std::string::size_type MAX_XML_CHUNK_LENGTH = 2048; +struct ReadState +{ + bool final; + NETTRANSACT::CALLBACK callback; + void * callbackData; + NETTRANSACT * nt; +}; + } //--------------------------------------------------------------------------- @@ -71,7 +81,7 @@ NETTRANSACT::NETTRANSACT(const std::string & s, uint16_t p, localPort(0), login(l), password(pwd), - outerSocket(-1) + sock(-1) { } //--------------------------------------------------------------------------- @@ -84,14 +94,19 @@ NETTRANSACT::NETTRANSACT(const std::string & s, uint16_t p, localPort(lp), login(l), password(pwd), - outerSocket(-1) + sock(-1) +{ +} +//--------------------------------------------------------------------------- +NETTRANSACT::~NETTRANSACT() { +Disconnect(); } //--------------------------------------------------------------------------- int NETTRANSACT::Connect() { -outerSocket = socket(PF_INET, SOCK_STREAM, 0); -if (outerSocket < 0) +sock = socket(PF_INET, SOCK_STREAM, 0); +if (sock < 0) { errorMsg = CREATE_SOCKET_ERROR; return st_conn_fail; @@ -124,10 +139,9 @@ if (!localAddress.empty()) localAddr.sin_port = htons(localPort); localAddr.sin_addr.s_addr = ip; - if (bind(outerSocket, (struct sockaddr *)&localAddr, sizeof(localAddr)) < 0) + if (bind(sock, (struct sockaddr *)&localAddr, sizeof(localAddr)) < 0) { errorMsg = BIND_FAILED; - close(outerSocket); return st_conn_fail; } } @@ -155,10 +169,9 @@ outerAddr.sin_family = AF_INET; outerAddr.sin_port = htons(port); outerAddr.sin_addr.s_addr = ip; -if (connect(outerSocket, (struct sockaddr *)&outerAddr, sizeof(outerAddr)) < 0) +if (connect(sock, (struct sockaddr *)&outerAddr, sizeof(outerAddr)) < 0) { errorMsg = CONNECT_FAILED; - close(outerSocket); return st_conn_fail; } @@ -167,66 +180,47 @@ return st_ok; //--------------------------------------------------------------------------- void NETTRANSACT::Disconnect() { -close(outerSocket); +if (sock != -1) + { + shutdown(sock, SHUT_RDWR); + close(sock); + sock = -1; + } } //--------------------------------------------------------------------------- int NETTRANSACT::Transact(const std::string & request, CALLBACK callback, void * data) { int ret; if ((ret = TxHeader()) != st_ok) - { - Disconnect(); return ret; - } if ((ret = RxHeaderAnswer()) != st_ok) - { - Disconnect(); return ret; - } if ((ret = TxLogin()) != st_ok) - { - Disconnect(); return ret; - } if ((ret = RxLoginAnswer()) != st_ok) - { - Disconnect(); return ret; - } if ((ret = TxLoginS()) != st_ok) - { - Disconnect(); return ret; - } if ((ret = RxLoginSAnswer()) != st_ok) - { - Disconnect(); return ret; - } if ((ret = TxData(request)) != st_ok) - { - Disconnect(); return ret; - } if ((ret = RxDataAnswer(callback, data)) != st_ok) - { - Disconnect(); return ret; - } return st_ok; } //--------------------------------------------------------------------------- int NETTRANSACT::TxHeader() { -if (send(outerSocket, STG_HEADER, strlen(STG_HEADER), 0) <= 0) +if (!WriteAll(sock, STG_HEADER, strlen(STG_HEADER))) { errorMsg = SEND_HEADER_ERROR; return st_send_fail; @@ -239,7 +233,7 @@ int NETTRANSACT::RxHeaderAnswer() { char buffer[sizeof(STG_HEADER) + 1]; -if (recv(outerSocket, buffer, strlen(OK_HEADER), 0) <= 0) +if (!ReadAll(sock, buffer, strlen(OK_HEADER))) { printf("Receive header answer error: '%s'\n", strerror(errno)); errorMsg = RECV_HEADER_ANSWER_ERROR; @@ -247,31 +241,27 @@ if (recv(outerSocket, buffer, strlen(OK_HEADER), 0) <= 0) } if (strncmp(OK_HEADER, buffer, strlen(OK_HEADER)) == 0) - { return st_ok; + +if (strncmp(ERR_HEADER, buffer, strlen(ERR_HEADER)) == 0) + { + errorMsg = INCORRECT_HEADER; + return st_header_err; } else { - if (strncmp(ERR_HEADER, buffer, strlen(ERR_HEADER)) == 0) - { - errorMsg = INCORRECT_HEADER; - return st_header_err; - } - else - { - errorMsg = UNKNOWN_ERROR; - return st_unknown_err; - } + errorMsg = UNKNOWN_ERROR; + return st_unknown_err; } } //--------------------------------------------------------------------------- int NETTRANSACT::TxLogin() { -char loginZ[ADM_LOGIN_LEN]; -memset(loginZ, 0, ADM_LOGIN_LEN); +char loginZ[ADM_LOGIN_LEN + 1]; +memset(loginZ, 0, ADM_LOGIN_LEN + 1); strncpy(loginZ, login.c_str(), ADM_LOGIN_LEN); -if (send(outerSocket, loginZ, ADM_LOGIN_LEN, 0) <= 0) +if (!WriteAll(sock, loginZ, ADM_LOGIN_LEN)) { errorMsg = SEND_LOGIN_ERROR; return st_send_fail; @@ -284,7 +274,7 @@ int NETTRANSACT::RxLoginAnswer() { char buffer[sizeof(OK_LOGIN) + 1]; -if (recv(outerSocket, buffer, strlen(OK_LOGIN), 0) <= 0) +if (!ReadAll(sock, buffer, strlen(OK_LOGIN))) { printf("Receive login answer error: '%s'\n", strerror(errno)); errorMsg = RECV_LOGIN_ANSWER_ERROR; @@ -292,42 +282,32 @@ if (recv(outerSocket, buffer, strlen(OK_LOGIN), 0) <= 0) } if (strncmp(OK_LOGIN, buffer, strlen(OK_LOGIN)) == 0) - { return st_ok; + +if (strncmp(ERR_LOGIN, buffer, strlen(ERR_LOGIN)) == 0) + { + errorMsg = INCORRECT_LOGIN; + return st_login_err; } else { - if (strncmp(ERR_LOGIN, buffer, strlen(ERR_LOGIN)) == 0) - { - errorMsg = INCORRECT_LOGIN; - return st_login_err; - } - else - { - errorMsg = UNKNOWN_ERROR; - return st_unknown_err; - } + errorMsg = UNKNOWN_ERROR; + return st_unknown_err; } } //--------------------------------------------------------------------------- int NETTRANSACT::TxLoginS() { -char loginZ[ADM_LOGIN_LEN]; -memset(loginZ, 0, ADM_LOGIN_LEN); -strncpy(loginZ, login.c_str(), ADM_LOGIN_LEN); +char loginZ[ADM_LOGIN_LEN + 1]; +memset(loginZ, 0, ADM_LOGIN_LEN + 1); BLOWFISH_CTX ctx; -EnDecodeInit(password.c_str(), PASSWD_LEN, &ctx); - -for (int j = 0; j < ADM_LOGIN_LEN / ENC_MSG_LEN; j++) +InitContext(password.c_str(), PASSWD_LEN, &ctx); +EncryptString(loginZ, login.c_str(), std::min(login.length() + 1, ADM_LOGIN_LEN), &ctx); +if (!WriteAll(sock, loginZ, ADM_LOGIN_LEN)) { - char ct[ENC_MSG_LEN]; - EncodeString(ct, loginZ + j * ENC_MSG_LEN, &ctx); - if (send(outerSocket, ct, ENC_MSG_LEN, 0) <= 0) - { - errorMsg = SEND_LOGIN_ERROR; - return st_send_fail; - } + errorMsg = SEND_LOGIN_ERROR; + return st_send_fail; } return st_ok; @@ -337,7 +317,7 @@ int NETTRANSACT::RxLoginSAnswer() { char buffer[sizeof(OK_LOGINS) + 1]; -if (recv(outerSocket, buffer, strlen(OK_LOGINS), 0) <= 0) +if (!ReadAll(sock, buffer, strlen(OK_LOGINS))) { printf("Receive secret login answer error: '%s'\n", strerror(errno)); errorMsg = RECV_LOGIN_ANSWER_ERROR; @@ -345,44 +325,28 @@ if (recv(outerSocket, buffer, strlen(OK_LOGINS), 0) <= 0) } if (strncmp(OK_LOGINS, buffer, strlen(OK_LOGINS)) == 0) - { return st_ok; + +if (strncmp(ERR_LOGINS, buffer, strlen(ERR_LOGINS)) == 0) + { + errorMsg = INCORRECT_LOGIN; + return st_logins_err; } else { - if (strncmp(ERR_LOGINS, buffer, strlen(ERR_LOGINS)) == 0) - { - errorMsg = INCORRECT_LOGIN; - return st_logins_err; - } - else - { - errorMsg = UNKNOWN_ERROR; - return st_unknown_err; - } + errorMsg = UNKNOWN_ERROR; + return st_unknown_err; } } //--------------------------------------------------------------------------- int NETTRANSACT::TxData(const std::string & text) { -BLOWFISH_CTX ctx; -EnDecodeInit(password.c_str(), PASSWD_LEN, &ctx); - -size_t pos = 0; -while (pos < text.size()) +STG::ENCRYPT_STREAM stream(password, TxCrypto, this); +stream.Put(text.c_str(), text.length() + 1, true); +if (!stream.isOk()) { - char textZ[ENC_MSG_LEN]; - if (text.size() - pos < ENC_MSG_LEN) - memset(textZ, 0, ENC_MSG_LEN); - strncpy(textZ, text.c_str() + pos, std::min(ENC_MSG_LEN, (int)(text.size() - pos))); - char ct[ENC_MSG_LEN]; - EncodeString(ct, textZ, &ctx); - if (send(outerSocket, ct, ENC_MSG_LEN, 0) <= 0) - { - errorMsg = SEND_DATA_ERROR; - return st_send_fail; - } - pos += ENC_MSG_LEN; + errorMsg = SEND_DATA_ERROR; + return st_send_fail; } return st_ok; @@ -390,48 +354,51 @@ return st_ok; //--------------------------------------------------------------------------- int NETTRANSACT::RxDataAnswer(CALLBACK callback, void * data) { -BLOWFISH_CTX ctx; -EnDecodeInit(password.c_str(), PASSWD_LEN, &ctx); - -std::string chunk; -while (true) +ReadState state = {false, callback, data, this}; +STG::DECRYPT_STREAM stream(password, RxCrypto, &state); +while (!state.final) { - char bufferS[ENC_MSG_LEN]; - size_t toRead = ENC_MSG_LEN; - while (toRead > 0) + char buffer[1024]; + ssize_t res = read(sock, buffer, sizeof(buffer)); + if (res < 0) { - int ret = recv(outerSocket, &bufferS[ENC_MSG_LEN - toRead], toRead, 0); - if (ret <= 0) - { - printf("Receive data error: '%s'\n", strerror(errno)); - close(outerSocket); - errorMsg = RECV_DATA_ANSWER_ERROR; - return st_recv_fail; - } - toRead -= ret; + printf("Receive data error: '%s'\n", strerror(errno)); + errorMsg = RECV_DATA_ANSWER_ERROR; + return st_recv_fail; } + stream.Put(buffer, res, res == 0); + if (!stream.isOk()) + return st_xml_parse_error; + } - char buffer[ENC_MSG_LEN]; - DecodeString(buffer, bufferS, &ctx); - - bool final = false; - size_t pos = 0; - for (; pos < ENC_MSG_LEN && buffer[pos] != 0; pos++) ; - if (pos < ENC_MSG_LEN && buffer[pos] == 0) - final = true; - - if (pos > 0) - chunk.append(&buffer[0], &buffer[pos]); +return st_ok; +} +//--------------------------------------------------------------------------- +bool NETTRANSACT::TxCrypto(const void * block, size_t size, void * data) +{ +assert(data != NULL); +NETTRANSACT & nt = *static_cast(data); +if (!WriteAll(nt.sock, block, size)) + return false; +return true; +} +//--------------------------------------------------------------------------- +bool NETTRANSACT::RxCrypto(const void * block, size_t size, void * data) +{ +assert(data != NULL); +ReadState & state = *static_cast(data); - if (chunk.length() > MAX_XML_CHUNK_LENGTH || final) +const char * buffer = static_cast(block); +for (size_t pos = 0; pos < size; ++pos) + if (buffer[pos] == 0) { - if (callback) - if (!callback(chunk, final, data)) - return st_xml_parse_error; - chunk.clear(); + state.final = true; + size = pos; // Adjust string size } - if (final) - return st_ok; - } +if (state.callback) + if (!state.callback(std::string(buffer, size), state.final, state.callbackData)) + return false; + +return true; }