void Put(const void * data, size_t size, bool last)
{
size_t dataSize = m_ptr - m_buffer;
- if (dataSize + size > sizeof(m_buffer))
+ while (dataSize + size > sizeof(m_buffer))
{
memcpy(m_ptr, data, sizeof(m_buffer) - dataSize); // Fill buffer
size -= sizeof(m_buffer) - dataSize; // Adjust size
m_proc(m_buffer, m_buffer, sizeof(m_buffer), &m_ctx); // Process
m_ok = m_ok && m_callback(m_buffer, sizeof(m_buffer), m_data); // Consume
m_ptr = m_buffer;
+ dataSize = 0;
}
if (!m_ok)
return;
dataSize += 8;
remainder = 0;
}
- if (dataSize == 0)
+ if (!last && dataSize == 0) // Allow to call callback with 0 data on last call.
return;
m_proc(m_buffer, m_buffer, dataSize, &m_ctx);
m_ok = m_ok && m_callback(m_buffer, dataSize, m_data);
#include "tut/tut.hpp"
+#include "longstring.h"
+
#include "stg/bfstream.h"
#include "stg/os_int.h"
{
m_lastBlock = new char[size];
memcpy(m_lastBlock, block, size);
+ m_result.append(m_lastBlock, size);
}
else
m_lastBlock = NULL;
size_t CallCount() const { return m_callCount; }
const void * LastBlock() const { return m_lastBlock; }
+ const std::string& Result() const { return m_result; }
+
private:
size_t m_lastSize;
size_t m_callCount;
char * m_lastBlock;
+
+ std::string m_result;
+};
+
+bool DecryptCallback(const void * block, size_t size, void * data);
+
+// Test helper: wraps an STG::DECRYPT_STREAM and accumulates the plaintext it emits.
+// Ciphertext is fed in through Call() (used by EncryptCallback); decrypted data
+// comes back through Put() (via DecryptCallback) and is collected in Result().
+class Decryptor
+{
+    public:
+        // `key` is handed to the decrypt stream; `this` is registered as the
+        // callback data so DecryptCallback can route plaintext back here.
+        Decryptor(const std::string & key)
+            : m_stream(key, DecryptCallback, this)
+        {}
+
+        // Feed a chunk of ciphertext into the decrypt stream (not marked as last).
+        // Always reports success to the caller.
+        bool Call(const void * block, size_t size)
+        {
+            m_stream.Put(block, size);
+            return true;
+        }
+
+        // Receive decrypted bytes and append them to the result, stopping at the
+        // first NUL. The encrypted source carries its terminating NUL, and anything
+        // after it (presumably block padding) must not end up in the result.
+        bool Put(const void * block, size_t size)
+        {
+            const char * data = static_cast<const char *>(block);
+            size = strnlen(data, size);
+            m_result.append(data, size);
+            return true;
+        }
+
+        // Finish decryption: an empty put with `last == true` makes the stream
+        // process whatever ciphertext is still sitting in its buffer. Without the
+        // flag the stream only buffers, so the final partial chunk would never be
+        // decrypted.
+        void Flush()
+        {
+            m_stream.Put(NULL, 0, true);
+        }
+
+        // Plaintext collected so far.
+        const std::string & Result() const { return m_result; }
+
+    private:
+        // Non-copyable: m_stream was built with `this` as its callback data, so a
+        // copied stream would presumably keep delivering into the original object.
+        Decryptor(const Decryptor &);
+        Decryptor & operator=(const Decryptor &);
+
+        STG::DECRYPT_STREAM m_stream;
+        std::string m_result;
};
+// ENCRYPT_STREAM callback: `data` is the Decryptor registered with the stream.
+// Each freshly encrypted chunk is forwarded straight into its decrypt stream.
+bool EncryptCallback(const void * block, size_t size, void * data)
+{
+    return static_cast<Decryptor *>(data)->Call(block, size);
+}
+
+// DECRYPT_STREAM callback: `data` is the Decryptor that owns the stream.
+// The decrypted chunk is handed back to it for accumulation.
+bool DecryptCallback(const void * block, size_t size, void * data)
+{
+    Decryptor * const owner = static_cast<Decryptor *>(data);
+    return owner->Put(block, size);
+}
+
bool Callback(const void * block, size_t size, void * data)
{
TRACKER & tracker = *static_cast<TRACKER *>(data);
ensure_equals("Decrypt(Encrypt(source)) == source", std::string(buffer), source);
}
+    // Round-trips a long string through ENCRYPT_STREAM -> DECRYPT_STREAM and
+    // checks the recovered plaintext equals the original.
+    template<>
+    template<>
+    void testobject::test<4>()
+    {
+        set_test_name("Check bfstream very long string processing");
+
+        // EncryptCallback pipes every encrypted chunk directly into the decryptor.
+        Decryptor decryptor("pr7Hhen");
+        STG::ENCRYPT_STREAM estream("pr7Hhen", EncryptCallback, &decryptor);
+
+        // Include the terminating NUL so Decryptor::Put can tell where the text ends.
+        estream.Put(longString.c_str(), longString.length() + 1, true);
+        // The decrypt stream only processes data once its buffer fills; flush it so
+        // the trailing partial chunk is decrypted too.
+        decryptor.Flush();
+
+        ensure_equals("Decrypt(Encrypt(source)) == source", decryptor.Result(), longString);
+    }
+
+    // Verifies when the encrypt stream invokes its callback: nothing until a full
+    // block is buffered, once when the block completes, and once more for an empty
+    // put marked as last.
+    template<>
+    template<>
+    void testobject::test<5>()
+    {
+        set_test_name("Check bfstream mechanics");
+
+        TRACKER tracker;
+        STG::ENCRYPT_STREAM stream("pr7Hhen", Callback, &tracker);
+        ensure_equals("CallCount() == 0 after construction", tracker.CallCount(), 0);
+
+        // Two 32-bit words: the callback is expected to fire only once both are in.
+        uint32_t words[2] = {0x12345678, 0x87654321};
+
+        stream.Put(words, sizeof(words[0]));        // first word: buffered only
+        ensure_equals("CallCount() == 0 after first put", tracker.CallCount(), 0);
+
+        stream.Put(words + 1, sizeof(words[1]));    // second word completes the block
+        ensure_equals("CallCount() == 1 after second put", tracker.CallCount(), 1);
+
+        stream.Put(words, 0, true);                 // empty last put still reaches the callback
+        ensure_equals("CallCount() == 2 after third (null) put", tracker.CallCount(), 2);
+    }
+
}