]> git.stg.codes - stg.git/blob - include/stg/raw_ip_packet.h
Merge remote-tracking branch 'github/master'
[stg.git] / include / stg / raw_ip_packet.h
1 #pragma once
2
3 #include <cstring>
4
5 #if defined(FREE_BSD)
6 #include <netinet/in_systm.h> // n_long in netinet/ip.h
7 #endif
8
9 #include <netinet/in.h> // for htons
10 #include <netinet/ip.h> // for struct ip
11
12 #define IPv4 (2)
13
14 namespace STG
15 {
16
17 enum { packetSize = 68 }; //60(max) ip + 8 udp or tcp (part of tcp or udp header to ports)
18 //-----------------------------------------------------------------------------
19 struct RawPacket
20 {
21     RawPacket()
22         : dataLen(-1)
23     {
24         memset(rawPacket.data, 0, packetSize);
25     }
26
27     RawPacket(const RawPacket& rhs) noexcept
28     {
29         memcpy(rawPacket.data, rhs.rawPacket.data, packetSize);
30     }
31     RawPacket& operator=(const RawPacket& rhs) noexcept
32     {
33         memcpy(rawPacket.data, rhs.rawPacket.data, packetSize);
34         return *this;
35     }
36     RawPacket(RawPacket&& rhs) noexcept
37     {
38         memcpy(rawPacket.data, rhs.rawPacket.data, packetSize);
39     }
40     RawPacket& operator=(RawPacket&& rhs) noexcept
41     {
42         memcpy(rawPacket.data, rhs.rawPacket.data, packetSize);
43         return *this;
44     }
45
46     uint16_t GetIPVersion() const noexcept;
47     uint8_t  GetHeaderLen() const noexcept;
48     uint8_t  GetProto() const noexcept;
49     uint32_t GetLen() const noexcept;
50     uint32_t GetSrcIP() const noexcept;
51     uint32_t GetDstIP() const noexcept;
52     uint16_t GetSrcPort() const noexcept;
53     uint16_t GetDstPort() const noexcept;
54
55     bool     operator==(const RawPacket& rhs) const noexcept;
56     bool     operator!=(const RawPacket& rhs) const noexcept { return !(*this == rhs); }
57     bool     operator<(const RawPacket& rhs) const noexcept;
58
59     union
60     {
61         uint8_t data[packetSize]; // Packet header as a raw data
62         struct
63         {
64             struct ip   ipHeader;
65             // Only for packets without options field
66             uint16_t    sPort;
67             uint16_t    dPort;
68         } header;
69     } rawPacket;
70     int32_t dataLen; // IP packet length. Set to -1 to use length field from the header
71 };
72 //-----------------------------------------------------------------------------
73 inline uint16_t RawPacket::GetIPVersion() const noexcept
74 {
75     return rawPacket.header.ipHeader.ip_v;
76 }
77 //-----------------------------------------------------------------------------
78 inline uint8_t RawPacket::GetHeaderLen() const noexcept
79 {
80     return rawPacket.header.ipHeader.ip_hl * 4;
81 }
82 //-----------------------------------------------------------------------------
83 inline uint8_t RawPacket::GetProto() const noexcept
84 {
85     return rawPacket.header.ipHeader.ip_p;
86 }
87 //-----------------------------------------------------------------------------
88 inline uint32_t RawPacket::GetLen() const noexcept
89 {
90     if (dataLen != -1)
91         return dataLen;
92     return ntohs(rawPacket.header.ipHeader.ip_len);
93 }
94 //-----------------------------------------------------------------------------
95 inline uint32_t RawPacket::GetSrcIP() const noexcept
96 {
97     return rawPacket.header.ipHeader.ip_src.s_addr;
98 }
99 //-----------------------------------------------------------------------------
100 inline uint32_t RawPacket::GetDstIP() const noexcept
101 {
102     return rawPacket.header.ipHeader.ip_dst.s_addr;
103 }
104 //-----------------------------------------------------------------------------
105 inline uint16_t RawPacket::GetSrcPort() const noexcept
106 {
107     if (rawPacket.header.ipHeader.ip_p == 1) // for icmp proto return port 0
108         return 0;
109     const uint8_t* pos = rawPacket.data + rawPacket.header.ipHeader.ip_hl * 4;
110     return ntohs(*reinterpret_cast<const uint16_t *>(pos));
111 }
112 //-----------------------------------------------------------------------------
113 inline uint16_t RawPacket::GetDstPort() const noexcept
114 {
115     if (rawPacket.header.ipHeader.ip_p == 1) // for icmp proto return port 0
116         return 0;
117     const uint8_t * pos = rawPacket.data + rawPacket.header.ipHeader.ip_hl * 4 + 2;
118     return ntohs(*reinterpret_cast<const uint16_t *>(pos));
119 }
120 //-----------------------------------------------------------------------------
121 inline bool RawPacket::operator==(const RawPacket& rhs) const noexcept
122 {
123     if (rawPacket.header.ipHeader.ip_src.s_addr != rhs.rawPacket.header.ipHeader.ip_src.s_addr)
124         return false;
125
126     if (rawPacket.header.ipHeader.ip_dst.s_addr != rhs.rawPacket.header.ipHeader.ip_dst.s_addr)
127         return false;
128
129     if (rawPacket.header.ipHeader.ip_p != 1 && rhs.rawPacket.header.ipHeader.ip_p != 1)
130     {
131         const uint8_t * pos = rawPacket.data + rawPacket.header.ipHeader.ip_hl * 4;
132         const uint8_t * rpos = rhs.rawPacket.data + rhs.rawPacket.header.ipHeader.ip_hl * 4;
133         if (*reinterpret_cast<const uint16_t *>(pos) != *reinterpret_cast<const uint16_t *>(rpos))
134             return false;
135
136         pos += 2;
137         rpos += 2;
138         if (*reinterpret_cast<const uint16_t *>(pos) != *reinterpret_cast<const uint16_t *>(rpos))
139             return false;
140     }
141
142     if (rawPacket.header.ipHeader.ip_p != rhs.rawPacket.header.ipHeader.ip_p)
143         return false;
144
145     return true;
146 }
147 //-----------------------------------------------------------------------------
148 inline bool RawPacket::operator<(const RawPacket& rhs) const noexcept
149 {
150     if (rawPacket.header.ipHeader.ip_src.s_addr < rhs.rawPacket.header.ipHeader.ip_src.s_addr)
151         return true;
152     if (rawPacket.header.ipHeader.ip_src.s_addr > rhs.rawPacket.header.ipHeader.ip_src.s_addr)
153         return false;
154
155     if (rawPacket.header.ipHeader.ip_dst.s_addr < rhs.rawPacket.header.ipHeader.ip_dst.s_addr)
156         return true;
157     if (rawPacket.header.ipHeader.ip_dst.s_addr > rhs.rawPacket.header.ipHeader.ip_dst.s_addr)
158         return false;
159
160     if (rawPacket.header.ipHeader.ip_p != 1 && rhs.rawPacket.header.ipHeader.ip_p != 1)
161     {
162         const uint8_t * pos = rawPacket.data + rawPacket.header.ipHeader.ip_hl * 4;
163         const uint8_t * rpos = rhs.rawPacket.data + rhs.rawPacket.header.ipHeader.ip_hl * 4;
164         if (*reinterpret_cast<const uint16_t *>(pos) < *reinterpret_cast<const uint16_t *>(rpos))
165             return true;
166         if (*reinterpret_cast<const uint16_t *>(pos) > *reinterpret_cast<const uint16_t *>(rpos))
167             return false;
168
169         pos += 2;
170         rpos += 2;
171         if (*reinterpret_cast<const uint16_t *>(pos) < *reinterpret_cast<const uint16_t *>(rpos))
172             return true;
173         if (*reinterpret_cast<const uint16_t *>(pos) > *reinterpret_cast<const uint16_t *>(rpos))
174             return false;
175     }
176
177     if (rawPacket.header.ipHeader.ip_p < rhs.rawPacket.header.ipHeader.ip_p)
178         return true;
179
180     return false;
181 }
182 //-----------------------------------------------------------------------------
183
184 }