]> git.stg.codes - stg.git/blob - tests/test_bfstream.cpp
Handle string for encryption properly.
[stg.git] / tests / test_bfstream.cpp
1 #include "tut/tut.hpp"
2
3 #include "stg/bfstream.h"
4 #include "stg/os_int.h"
5
6 #include <algorithm>
7 #include <string>
8 #include <cstring>
9
10 namespace
11 {
12
13 class TRACKER
14 {
15     public:
16         TRACKER() : m_lastSize(0), m_callCount(0), m_lastBlock(NULL) {}
17         ~TRACKER() { delete[] m_lastBlock; }
18         bool Call(const void * block, size_t size)
19         {
20         delete[] m_lastBlock;
21         if (size > 0)
22             {
23             m_lastBlock = new char[size];
24             memcpy(m_lastBlock,  block, size);
25             }
26         else
27             m_lastBlock = NULL;
28         m_lastSize = size;
29         ++m_callCount;
30         return true;
31         }
32         size_t LastSize() const { return m_lastSize; }
33         size_t CallCount() const { return m_callCount; }
34         const void * LastBlock() const { return m_lastBlock; }
35
36     private:
37         size_t m_lastSize;
38         size_t m_callCount;
39         char * m_lastBlock;
40 };
41
42 bool Callback(const void * block, size_t size, void * data)
43 {
44 TRACKER & tracker = *static_cast<TRACKER *>(data);
45 return tracker.Call(block, size);
46 }
47
48 }
49
50 namespace tut
51 {
52     struct bfstream_data {
53     };
54
55     typedef test_group<bfstream_data> tg;
56     tg bfstream_test_group("BFStream tests group");
57
58     typedef tg::object testobject;
59
60     template<>
61     template<>
62     void testobject::test<1>()
63     {
64         set_test_name("Check bfstream mechanics");
65
66         TRACKER tracker;
67         STG::ENCRYPT_STREAM stream("pr7Hhen", Callback, &tracker);
68         ensure_equals("CallCount() == 0 after construction", tracker.CallCount(), 0);
69
70         uint32_t block[2] = {0x12345678, 0x87654321};
71         stream.Put(&block[0], sizeof(block[0]));
72         ensure_equals("CallCount() == 0 after first put", tracker.CallCount(), 0);
73         stream.Put(&block[1], sizeof(block[1]));
74         ensure_equals("CallCount() == 1 after second put", tracker.CallCount(), 1);
75
76         uint32_t block2[4] = {0x12345678, 0x87654321, 0x12345678, 0x87654321};
77         stream.Put(&block2[0], sizeof(block2[0]) * 3);
78         ensure_equals("CallCount() == 2 after third put", tracker.CallCount(), 2);
79         stream.Put(&block2[3], sizeof(block2[3]));
80         ensure_equals("CallCount() == 3 after fourth put", tracker.CallCount(), 3);
81     }
82
83     template<>
84     template<>
85     void testobject::test<2>()
86     {
87         set_test_name("Check bfstream encryption");
88
89         TRACKER tracker;
90         STG::ENCRYPT_STREAM stream("pr7Hhen", Callback, &tracker);
91
92         uint32_t block[2] = {0x12345678, 0x87654321};
93         stream.Put(&block[0], sizeof(block[0]));
94         ensure_equals("LastSize() == 0 after first put", tracker.LastSize(), 0);
95         ensure_equals("LastBlock() == NULL after first put", tracker.LastBlock(), static_cast<const void *>(NULL));
96         stream.Put(&block[1], sizeof(block[1]));
97         ensure_equals("LastSize() == 8 after second put", tracker.LastSize(), 8);
98         const uint32_t * ptr = static_cast<const uint32_t *>(tracker.LastBlock());
99         ensure_equals("ptr[0] == 0xd3988cd after second put", ptr[0], 0xd3988cd);
100         ensure_equals("ptr[1] == 0x7996c6d6 after second put", ptr[1], 0x7996c6d6);
101
102         uint32_t block2[4] = {0x12345678, 0x87654321, 0x12345678, 0x87654321};
103         stream.Put(&block2[0], sizeof(block2[0]) * 3);
104         ensure_equals("LastSize() == 8 after third put", tracker.LastSize(), 8);
105         ptr = static_cast<const uint32_t *>(tracker.LastBlock());
106         ensure_equals("ptr[0] == 0xd3988cd after third put", ptr[0], 0xd3988cd);
107         ensure_equals("ptr[1] == 0x7996c6d6 after third put", ptr[1], 0x7996c6d6);
108
109         stream.Put(&block2[3], sizeof(block2[3]));
110         ensure_equals("LastSize() == 8 after fourth put", tracker.LastSize(), 8);
111         ptr = static_cast<const uint32_t *>(tracker.LastBlock());
112         ensure_equals("ptr[0] == 0xd3988cd after fourth put", ptr[0], 0xd3988cd);
113         ensure_equals("ptr[1] == 0x7996c6d6 after fourth put", ptr[1], 0x7996c6d6);
114     }
115
116     template<>
117     template<>
118     void testobject::test<3>()
119     {
120         set_test_name("Check bfstream long string processing");
121
122         TRACKER tracker;
123         STG::ENCRYPT_STREAM estream("pr7Hhen", Callback, &tracker);
124         std::string source = "This is a test long string for checking stream encryption/decryption. \"abcdefghijklmnopqrstuvwxyz 0123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ\"";
125         char buffer[source.length() + 9];
126         memset(buffer, 0, sizeof(buffer));
127
128         estream.Put(source.c_str(), source.length() + 1, true);
129         ensure("Encryption long string LastSize()", tracker.LastSize() >= source.length() + 1);
130         ensure("Encryption long string LastBlock() != NULL", tracker.LastBlock() != NULL);
131         memcpy(buffer, tracker.LastBlock(), std::min(tracker.LastSize(), sizeof(buffer)));
132
133         STG::DECRYPT_STREAM dstream("pr7Hhen", Callback, &tracker);
134         dstream.Put(buffer, sizeof(buffer), true);
135         ensure("Decryption long string LastSize() decryption", tracker.LastSize() >= sizeof(buffer));
136         ensure("Decryption long string LastBlock() != NULL", tracker.LastBlock() != NULL);
137         memcpy(buffer, tracker.LastBlock(), std::min(tracker.LastSize(), sizeof(buffer)));
138
139         ensure_equals("Decrypt(Encrypt(source)) == source", std::string(buffer), source);
140     }
141
142 }