]> git.stg.codes - stg.git/blob - include/stg/raw_ip_packet.h
Fix divert_cap in FreeBSD.
[stg.git] / include / stg / raw_ip_packet.h
1 #pragma once
2
3 #include <cstring>
4
5 #if defined(FREE_BSD)
6 #include <sys/types.h> // u_char, u_int32_t, etc.
7 #include <netinet/in_systm.h> // n_long in netinet/ip.h
8 #endif
9
10 #include <netinet/in.h> // for htons
11 #include <netinet/ip.h> // for struct ip
12
13 #define IPv4 (2)
14
15 namespace STG
16 {
17
18 enum { packetSize = 68 }; //60(max) ip + 8 udp or tcp (part of tcp or udp header to ports)
19 //-----------------------------------------------------------------------------
20 struct RawPacket
21 {
22     RawPacket()
23         : dataLen(-1)
24     {
25         memset(rawPacket.data, 0, packetSize);
26     }
27
28     RawPacket(const RawPacket& rhs) noexcept
29     {
30         memcpy(rawPacket.data, rhs.rawPacket.data, packetSize);
31     }
32     RawPacket& operator=(const RawPacket& rhs) noexcept
33     {
34         memcpy(rawPacket.data, rhs.rawPacket.data, packetSize);
35         return *this;
36     }
37     RawPacket(RawPacket&& rhs) noexcept
38     {
39         memcpy(rawPacket.data, rhs.rawPacket.data, packetSize);
40     }
41     RawPacket& operator=(RawPacket&& rhs) noexcept
42     {
43         memcpy(rawPacket.data, rhs.rawPacket.data, packetSize);
44         return *this;
45     }
46
47     uint16_t GetIPVersion() const noexcept;
48     uint8_t  GetHeaderLen() const noexcept;
49     uint8_t  GetProto() const noexcept;
50     uint32_t GetLen() const noexcept;
51     uint32_t GetSrcIP() const noexcept;
52     uint32_t GetDstIP() const noexcept;
53     uint16_t GetSrcPort() const noexcept;
54     uint16_t GetDstPort() const noexcept;
55
56     bool     operator==(const RawPacket& rhs) const noexcept;
57     bool     operator!=(const RawPacket& rhs) const noexcept { return !(*this == rhs); }
58     bool     operator<(const RawPacket& rhs) const noexcept;
59
60     union
61     {
62         uint8_t data[packetSize]; // Packet header as a raw data
63         struct
64         {
65             struct ip   ipHeader;
66             // Only for packets without options field
67             uint16_t    sPort;
68             uint16_t    dPort;
69         } header;
70     } rawPacket;
71     int32_t dataLen; // IP packet length. Set to -1 to use length field from the header
72 };
73 //-----------------------------------------------------------------------------
74 inline uint16_t RawPacket::GetIPVersion() const noexcept
75 {
76     return rawPacket.header.ipHeader.ip_v;
77 }
78 //-----------------------------------------------------------------------------
79 inline uint8_t RawPacket::GetHeaderLen() const noexcept
80 {
81     return rawPacket.header.ipHeader.ip_hl * 4;
82 }
83 //-----------------------------------------------------------------------------
84 inline uint8_t RawPacket::GetProto() const noexcept
85 {
86     return rawPacket.header.ipHeader.ip_p;
87 }
88 //-----------------------------------------------------------------------------
89 inline uint32_t RawPacket::GetLen() const noexcept
90 {
91     if (dataLen != -1)
92         return dataLen;
93     return ntohs(rawPacket.header.ipHeader.ip_len);
94 }
95 //-----------------------------------------------------------------------------
96 inline uint32_t RawPacket::GetSrcIP() const noexcept
97 {
98     return rawPacket.header.ipHeader.ip_src.s_addr;
99 }
100 //-----------------------------------------------------------------------------
101 inline uint32_t RawPacket::GetDstIP() const noexcept
102 {
103     return rawPacket.header.ipHeader.ip_dst.s_addr;
104 }
105 //-----------------------------------------------------------------------------
106 inline uint16_t RawPacket::GetSrcPort() const noexcept
107 {
108     if (rawPacket.header.ipHeader.ip_p == 1) // for icmp proto return port 0
109         return 0;
110     const uint8_t* pos = rawPacket.data + rawPacket.header.ipHeader.ip_hl * 4;
111     return ntohs(*reinterpret_cast<const uint16_t *>(pos));
112 }
113 //-----------------------------------------------------------------------------
114 inline uint16_t RawPacket::GetDstPort() const noexcept
115 {
116     if (rawPacket.header.ipHeader.ip_p == 1) // for icmp proto return port 0
117         return 0;
118     const uint8_t * pos = rawPacket.data + rawPacket.header.ipHeader.ip_hl * 4 + 2;
119     return ntohs(*reinterpret_cast<const uint16_t *>(pos));
120 }
121 //-----------------------------------------------------------------------------
122 inline bool RawPacket::operator==(const RawPacket& rhs) const noexcept
123 {
124     if (rawPacket.header.ipHeader.ip_src.s_addr != rhs.rawPacket.header.ipHeader.ip_src.s_addr)
125         return false;
126
127     if (rawPacket.header.ipHeader.ip_dst.s_addr != rhs.rawPacket.header.ipHeader.ip_dst.s_addr)
128         return false;
129
130     if (rawPacket.header.ipHeader.ip_p != 1 && rhs.rawPacket.header.ipHeader.ip_p != 1)
131     {
132         const uint8_t * pos = rawPacket.data + rawPacket.header.ipHeader.ip_hl * 4;
133         const uint8_t * rpos = rhs.rawPacket.data + rhs.rawPacket.header.ipHeader.ip_hl * 4;
134         if (*reinterpret_cast<const uint16_t *>(pos) != *reinterpret_cast<const uint16_t *>(rpos))
135             return false;
136
137         pos += 2;
138         rpos += 2;
139         if (*reinterpret_cast<const uint16_t *>(pos) != *reinterpret_cast<const uint16_t *>(rpos))
140             return false;
141     }
142
143     if (rawPacket.header.ipHeader.ip_p != rhs.rawPacket.header.ipHeader.ip_p)
144         return false;
145
146     return true;
147 }
148 //-----------------------------------------------------------------------------
149 inline bool RawPacket::operator<(const RawPacket& rhs) const noexcept
150 {
151     if (rawPacket.header.ipHeader.ip_src.s_addr < rhs.rawPacket.header.ipHeader.ip_src.s_addr)
152         return true;
153     if (rawPacket.header.ipHeader.ip_src.s_addr > rhs.rawPacket.header.ipHeader.ip_src.s_addr)
154         return false;
155
156     if (rawPacket.header.ipHeader.ip_dst.s_addr < rhs.rawPacket.header.ipHeader.ip_dst.s_addr)
157         return true;
158     if (rawPacket.header.ipHeader.ip_dst.s_addr > rhs.rawPacket.header.ipHeader.ip_dst.s_addr)
159         return false;
160
161     if (rawPacket.header.ipHeader.ip_p != 1 && rhs.rawPacket.header.ipHeader.ip_p != 1)
162     {
163         const uint8_t * pos = rawPacket.data + rawPacket.header.ipHeader.ip_hl * 4;
164         const uint8_t * rpos = rhs.rawPacket.data + rhs.rawPacket.header.ipHeader.ip_hl * 4;
165         if (*reinterpret_cast<const uint16_t *>(pos) < *reinterpret_cast<const uint16_t *>(rpos))
166             return true;
167         if (*reinterpret_cast<const uint16_t *>(pos) > *reinterpret_cast<const uint16_t *>(rpos))
168             return false;
169
170         pos += 2;
171         rpos += 2;
172         if (*reinterpret_cast<const uint16_t *>(pos) < *reinterpret_cast<const uint16_t *>(rpos))
173             return true;
174         if (*reinterpret_cast<const uint16_t *>(pos) > *reinterpret_cast<const uint16_t *>(rpos))
175             return false;
176     }
177
178     if (rawPacket.header.ipHeader.ip_p < rhs.rawPacket.header.ipHeader.ip_p)
179         return true;
180
181     return false;
182 }
183 //-----------------------------------------------------------------------------
184
185 }