1use bytes::{Buf, BufMut, Bytes, BytesMut};
7use core::fmt;
8use std::slice::Iter;
9
10mod topic;
11pub mod v4;
12
13pub use topic::*;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
17pub enum Error {
18 #[error("Expected Connect, received: {0:?}")]
19 NotConnect(PacketType),
20 #[error("Unexpected Connect")]
21 UnexpectedConnect,
22 #[error("Invalid Connect return code: {0}")]
23 InvalidConnectReturnCode(u8),
24 #[error("Invalid protocol")]
25 InvalidProtocol,
26 #[error("Invalid protocol level: {0}")]
27 InvalidProtocolLevel(u8),
28 #[error("Incorrect packet format")]
29 IncorrectPacketFormat,
30 #[error("Invalid packet type: {0}")]
31 InvalidPacketType(u8),
32 #[error("Invalid property type: {0}")]
33 InvalidPropertyType(u8),
34 #[error("Invalid QoS level: {0}")]
35 InvalidQoS(u8),
36 #[error("Invalid subscribe reason code: {0}")]
37 InvalidSubscribeReasonCode(u8),
38 #[error("Packet id Zero")]
39 PacketIdZero,
40 #[error("Payload size is incorrect")]
41 PayloadSizeIncorrect,
42 #[error("payload is too long")]
43 PayloadTooLong,
44 #[error("payload size limit exceeded: {0}")]
45 PayloadSizeLimitExceeded(usize),
46 #[error("Payload required")]
47 PayloadRequired,
48 #[error("Topic is not UTF-8")]
49 TopicNotUtf8,
50 #[error("Promised boundary crossed: {0}")]
51 BoundaryCrossed(usize),
52 #[error("Malformed packet")]
53 MalformedPacket,
54 #[error("Malformed remaining length")]
55 MalformedRemainingLength,
56 #[error("A Subscribe packet must contain atleast one filter")]
57 EmptySubscription,
58 #[error("At least {0} more bytes required to frame packet")]
62 InsufficientBytes(usize),
63}
64
/// MQTT control packet type.
///
/// With `#[repr(u8)]` and `Connect = 1`, each variant's discriminant equals
/// the value of the upper four bits of a fixed header's first byte as
/// decoded by `FixedHeader::packet_type` (1 = Connect ... 14 = Disconnect).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck,
    Publish,
    PubAck,
    PubRec,
    PubRel,
    PubComp,
    Subscribe,
    SubAck,
    Unsubscribe,
    UnsubAck,
    PingReq,
    PingResp,
    Disconnect,
}
84
/// Protocol version selector.
///
/// NOTE(review): presumably `V4` is MQTT 3.1.1 (protocol level 4, matching
/// the `v4` module) and `V5` is MQTT 5 — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    V4,
    V5,
}
91
/// Quality-of-service level, with discriminants equal to the on-wire values
/// accepted by `qos` (0, 1, 2).
///
/// The derived `PartialOrd` orders by declaration order, so
/// `AtMostOnce < AtLeastOnce < ExactlyOnce`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum QoS {
    AtMostOnce = 0,
    AtLeastOnce = 1,
    ExactlyOnce = 2,
}
100
/// Decoded fixed header of an MQTT packet: the first byte plus the
/// variable-length remaining-length field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct FixedHeader {
    // First byte of the packet. Its upper nibble is the packet type
    // (see `packet_type`); the lower nibble is not decoded here —
    // presumably packet-specific flags, confirm with callers.
    byte1: u8,
    // Size of the fixed header in bytes: `byte1` plus the bytes used by the
    // remaining-length field (set to `remaining_len_len + 1` in `new`).
    fixed_header_len: usize,
    // Decoded remaining length: the number of bytes that follow the fixed
    // header (added to `fixed_header_len` in `frame_length`).
    remaining_len: usize,
}
126
127impl FixedHeader {
128 pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader {
129 FixedHeader {
130 byte1,
131 fixed_header_len: remaining_len_len + 1,
132 remaining_len,
133 }
134 }
135
136 pub fn packet_type(&self) -> Result<PacketType, Error> {
137 let num = self.byte1 >> 4;
138 match num {
139 1 => Ok(PacketType::Connect),
140 2 => Ok(PacketType::ConnAck),
141 3 => Ok(PacketType::Publish),
142 4 => Ok(PacketType::PubAck),
143 5 => Ok(PacketType::PubRec),
144 6 => Ok(PacketType::PubRel),
145 7 => Ok(PacketType::PubComp),
146 8 => Ok(PacketType::Subscribe),
147 9 => Ok(PacketType::SubAck),
148 10 => Ok(PacketType::Unsubscribe),
149 11 => Ok(PacketType::UnsubAck),
150 12 => Ok(PacketType::PingReq),
151 13 => Ok(PacketType::PingResp),
152 14 => Ok(PacketType::Disconnect),
153 _ => Err(Error::InvalidPacketType(num)),
154 }
155 }
156
157 pub fn frame_length(&self) -> usize {
160 self.fixed_header_len + self.remaining_len
161 }
162}
163
164pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
170 let stream_len = stream.len();
173 let fixed_header = parse_fixed_header(stream)?;
174
175 if fixed_header.remaining_len > max_packet_size {
178 return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
179 }
180
181 let frame_length = fixed_header.frame_length();
184 if stream_len < frame_length {
185 return Err(Error::InsufficientBytes(frame_length - stream_len));
186 }
187
188 Ok(fixed_header)
189}
190
191fn parse_fixed_header(mut stream: Iter<u8>) -> Result<FixedHeader, Error> {
193 let stream_len = stream.len();
195 if stream_len < 2 {
196 return Err(Error::InsufficientBytes(2 - stream_len));
197 }
198
199 let byte1 = stream.next().unwrap();
200 let (len_len, len) = length(stream)?;
201
202 Ok(FixedHeader::new(*byte1, len_len, len))
203}
204
205fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
209 let mut len: usize = 0;
210 let mut len_len = 0;
211 let mut done = false;
212 let mut shift = 0;
213
214 for byte in stream {
219 len_len += 1;
220 let byte = *byte as usize;
221 len += (byte & 0x7F) << shift;
222
223 done = (byte & 0x80) == 0;
225 if done {
226 break;
227 }
228
229 shift += 7;
230
231 if shift > 21 {
234 return Err(Error::MalformedRemainingLength);
235 }
236 }
237
238 if !done {
241 return Err(Error::InsufficientBytes(1));
242 }
243
244 Ok((len_len, len))
245}
246
247fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
249 let len = read_u16(stream)? as usize;
250
251 if len > stream.len() {
256 return Err(Error::BoundaryCrossed(len));
257 }
258
259 Ok(stream.split_to(len))
260}
261
262fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
264 let s = read_mqtt_bytes(stream)?;
265 match String::from_utf8(s.to_vec()) {
266 Ok(v) => Ok(v),
267 Err(_e) => Err(Error::TopicNotUtf8),
268 }
269}
270
271fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
273 stream.put_u16(bytes.len() as u16);
274 stream.extend_from_slice(bytes);
275}
276
277fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
279 write_mqtt_bytes(stream, string.as_bytes());
280}
281
282fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
284 if len > 268_435_455 {
285 return Err(Error::PayloadTooLong);
286 }
287
288 let mut done = false;
289 let mut x = len;
290 let mut count = 0;
291
292 while !done {
293 let mut byte = (x % 128) as u8;
294 x /= 128;
295 if x > 0 {
296 byte |= 128;
297 }
298
299 stream.put_u8(byte);
300 count += 1;
301 done = x == 0;
302 }
303
304 Ok(count)
305}
306
307pub fn qos(num: u8) -> Result<QoS, Error> {
309 match num {
310 0 => Ok(QoS::AtMostOnce),
311 1 => Ok(QoS::AtLeastOnce),
312 2 => Ok(QoS::ExactlyOnce),
313 qos => Err(Error::InvalidQoS(qos)),
314 }
315}
316
317fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
323 if stream.len() < 2 {
324 return Err(Error::MalformedPacket);
325 }
326
327 Ok(stream.get_u16())
328}
329
330fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
331 if stream.is_empty() {
332 return Err(Error::MalformedPacket);
333 }
334
335 Ok(stream.get_u8())
336}