X-Git-Url: https://git.stg.codes/stg.git/blobdiff_plain/3ca84ef40f3f45d4e13e36651bde13090dc4051a..5d5457b414d2dd5ee77d2a17427e2c03d30ac176:/tests/test_bfstream.cpp diff --git a/tests/test_bfstream.cpp b/tests/test_bfstream.cpp index 595301b2..892d5759 100644 --- a/tests/test_bfstream.cpp +++ b/tests/test_bfstream.cpp @@ -1,5 +1,7 @@ #include "tut/tut.hpp" +#include "longstring.h" + #include "stg/bfstream.h" #include "stg/os_int.h" @@ -15,33 +17,86 @@ class TRACKER public: TRACKER() : m_lastSize(0), m_callCount(0), m_lastBlock(NULL) {} ~TRACKER() { delete[] m_lastBlock; } - void Call(const void * block, size_t size) + bool Call(const void * block, size_t size) { delete[] m_lastBlock; if (size > 0) { m_lastBlock = new char[size]; memcpy(m_lastBlock, block, size); + m_result.append(m_lastBlock, size); } else m_lastBlock = NULL; m_lastSize = size; ++m_callCount; + return true; } size_t LastSize() const { return m_lastSize; } size_t CallCount() const { return m_callCount; } const void * LastBlock() const { return m_lastBlock; } + const std::string& Result() const { return m_result; } + private: size_t m_lastSize; size_t m_callCount; char * m_lastBlock; + + std::string m_result; +}; + +bool DecryptCallback(const void * block, size_t size, void * data); + +class Decryptor +{ + public: + Decryptor(const std::string & key) + : m_stream(key, DecryptCallback, this) + {} + + bool Call(const void * block, size_t size) + { + m_stream.Put(block, size); + return true; + } + + bool Put(const void * block, size_t size) + { + const char * data = static_cast(block); + size = strnlen(data, size); + m_result.append(data, size); + return true; + } + + void Flush() + { + m_stream.Put(NULL, 0); + } + + const std::string & Result() const { return m_result; } + + private: + STG::DECRYPT_STREAM m_stream; + std::string m_result; }; -void Callback(const void * block, size_t size, void * data) +bool EncryptCallback(const void * block, size_t size, void * data) +{ +Decryptor & decryptor = *static_cast(data); +return decryptor.Call(block, size); +} + +bool DecryptCallback(const void * block, size_t size, void * data) +{ +Decryptor & decryptor = *static_cast(data); +return decryptor.Put(block, size); +} + +bool Callback(const void * block, size_t size, void * data) { TRACKER & tracker = *static_cast(data); -tracker.Call(block, size); +return tracker.Call(block, size); } } @@ -138,4 +193,47 @@ namespace tut ensure_equals("Decrypt(Encrypt(source)) == source", std::string(buffer), source); } + template<> + template<> + void testobject::test<4>() + { + set_test_name("Check bfstream very long string processing"); + + Decryptor decryptor("pr7Hhen"); + STG::ENCRYPT_STREAM estream("pr7Hhen", EncryptCallback, &decryptor); + //char buffer[source.length() + 9]; + //memset(buffer, 0, sizeof(buffer)); + + estream.Put(longString.c_str(), longString.length() + 1, true); + //ensure("Encryption long string LastSize()", tracker.LastSize() >= source.length() + 1); + //ensure("Encryption long string LastBlock() != NULL", tracker.LastBlock() != NULL); + //memcpy(buffer, tracker.LastBlock(), std::min(tracker.LastSize(), sizeof(buffer))); + + //dstream.Put(buffer, sizeof(buffer), true); + //ensure("Decryption long string LastSize() decryption", tracker.LastSize() >= sizeof(buffer)); + //ensure("Decryption long string LastBlock() != NULL", tracker.LastBlock() != NULL); + //memcpy(buffer, tracker.LastBlock(), std::min(tracker.LastSize(), sizeof(buffer))); + + ensure_equals("Decrypt(Encrypt(source)) == source", decryptor.Result(), longString); + } + + template<> + template<> + void testobject::test<5>() + { + set_test_name("Check bfstream mechanics"); + + TRACKER tracker; + STG::ENCRYPT_STREAM stream("pr7Hhen", Callback, &tracker); + ensure_equals("CallCount() == 0 after construction", tracker.CallCount(), 0); + + uint32_t block[2] = {0x12345678, 0x87654321}; + stream.Put(&block[0], sizeof(block[0])); + ensure_equals("CallCount() == 0 after first put", tracker.CallCount(), 0); + stream.Put(&block[1], sizeof(block[1])); + ensure_equals("CallCount() == 1 after second put", tracker.CallCount(), 1); + stream.Put(&block[0], 0, true); // Check last callback + ensure_equals("CallCount() == 2 after third (null) put", tracker.CallCount(), 2); + } + }