rumqttc/mqttbytes/
mod.rs

1//! # mqttbytes
2//!
3//! This module contains the low level struct definitions required to assemble and disassemble MQTT 3.1.1 packets in rumqttc.
4//! The [`bytes`](https://docs.rs/bytes) crate is used internally.
5
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use core::fmt;
8use std::slice::Iter;
9
10mod topic;
11pub mod v4;
12
13pub use topic::*;
14
15/// Error during serialization and deserialization
16#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
17pub enum Error {
18    #[error("Expected Connect, received: {0:?}")]
19    NotConnect(PacketType),
20    #[error("Unexpected Connect")]
21    UnexpectedConnect,
22    #[error("Invalid Connect return code: {0}")]
23    InvalidConnectReturnCode(u8),
24    #[error("Invalid protocol")]
25    InvalidProtocol,
26    #[error("Invalid protocol level: {0}")]
27    InvalidProtocolLevel(u8),
28    #[error("Incorrect packet format")]
29    IncorrectPacketFormat,
30    #[error("Invalid packet type: {0}")]
31    InvalidPacketType(u8),
32    #[error("Invalid property type: {0}")]
33    InvalidPropertyType(u8),
34    #[error("Invalid QoS level: {0}")]
35    InvalidQoS(u8),
36    #[error("Invalid subscribe reason code: {0}")]
37    InvalidSubscribeReasonCode(u8),
38    #[error("Packet id Zero")]
39    PacketIdZero,
40    #[error("Payload size is incorrect")]
41    PayloadSizeIncorrect,
42    #[error("payload is too long")]
43    PayloadTooLong,
44    #[error("payload size limit exceeded: {0}")]
45    PayloadSizeLimitExceeded(usize),
46    #[error("Payload required")]
47    PayloadRequired,
48    #[error("Topic is not UTF-8")]
49    TopicNotUtf8,
50    #[error("Promised boundary crossed: {0}")]
51    BoundaryCrossed(usize),
52    #[error("Malformed packet")]
53    MalformedPacket,
54    #[error("Malformed remaining length")]
55    MalformedRemainingLength,
56    #[error("A Subscribe packet must contain atleast one filter")]
57    EmptySubscription,
58    /// More bytes required to frame packet. Argument
59    /// implies minimum additional bytes required to
60    /// proceed further
61    #[error("At least {0} more bytes required to frame packet")]
62    InsufficientBytes(usize),
63}
64
/// MQTT control packet type.
///
/// Each discriminant is the value carried in the high nibble of the fixed
/// header's first byte, so `packet as u8` yields the on-wire number (1..=14).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
84
/// MQTT protocol version spoken on a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// MQTT 3.1.1 (protocol level 4).
    V4,
    /// MQTT 5 (protocol level 5).
    V5,
}
91
/// Quality of service level for message delivery.
///
/// The discriminant is the on-wire QoS number, and the derived ordering runs
/// from the weakest guarantee to the strongest.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum QoS {
    /// QoS 0.
    AtMostOnce = 0,
    /// QoS 1.
    AtLeastOnce = 1,
    /// QoS 2.
    ExactlyOnce = 2,
}
100
/// Fixed header found at the start of every MQTT control packet.
///
/// ```text
///          7                          3                          0
///          +--------------------------+--------------------------+
/// byte 1   | MQTT Control Packet Type | Flags for each type      |
///          +--------------------------+--------------------------+
///          |         Remaining Bytes Len  (1/2/3/4 bytes)        |
///          +-----------------------------------------------------+
/// ```
///
/// See the MQTT 3.1.1 specification, section "Fixed header":
/// <https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349207>
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct FixedHeader {
    /// First byte of the packet: the control packet type sits in the high
    /// nibble and the type-specific flags in the low nibble.
    byte1: u8,
    /// Size of the whole fixed header in bytes: `byte1` plus the 1..=4 bytes
    /// that variable-length encode the remaining length (2..=5 bytes total).
    fixed_header_len: usize,
    /// Number of bytes after the fixed header, i.e. the variable header plus
    /// the payload.
    remaining_len: usize,
}
126
127impl FixedHeader {
128    pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader {
129        FixedHeader {
130            byte1,
131            fixed_header_len: remaining_len_len + 1,
132            remaining_len,
133        }
134    }
135
136    pub fn packet_type(&self) -> Result<PacketType, Error> {
137        let num = self.byte1 >> 4;
138        match num {
139            1 => Ok(PacketType::Connect),
140            2 => Ok(PacketType::ConnAck),
141            3 => Ok(PacketType::Publish),
142            4 => Ok(PacketType::PubAck),
143            5 => Ok(PacketType::PubRec),
144            6 => Ok(PacketType::PubRel),
145            7 => Ok(PacketType::PubComp),
146            8 => Ok(PacketType::Subscribe),
147            9 => Ok(PacketType::SubAck),
148            10 => Ok(PacketType::Unsubscribe),
149            11 => Ok(PacketType::UnsubAck),
150            12 => Ok(PacketType::PingReq),
151            13 => Ok(PacketType::PingResp),
152            14 => Ok(PacketType::Disconnect),
153            _ => Err(Error::InvalidPacketType(num)),
154        }
155    }
156
157    /// Returns the size of full packet (fixed header + variable header + payload)
158    /// Fixed header is enough to get the size of a frame in the stream
159    pub fn frame_length(&self) -> usize {
160        self.fixed_header_len + self.remaining_len
161    }
162}
163
164/// Checks if the stream has enough bytes to frame a packet and returns fixed header
165/// only if a packet can be framed with existing bytes in the `stream`.
166/// The passed stream doesn't modify parent stream's cursor. If this function
167/// returned an error, next `check` on the same parent stream is forced start
168/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally)
169pub fn check(stream: Iter<u8>, max_packet_size: usize) -> Result<FixedHeader, Error> {
170    // Create fixed header if there are enough bytes in the stream
171    // to frame full packet
172    let stream_len = stream.len();
173    let fixed_header = parse_fixed_header(stream)?;
174
175    // Don't let rogue connections attack with huge payloads.
176    // Disconnect them before reading all that data
177    if fixed_header.remaining_len > max_packet_size {
178        return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len));
179    }
180
181    // If the current call fails due to insufficient bytes in the stream,
182    // after calculating remaining length, we extend the stream
183    let frame_length = fixed_header.frame_length();
184    if stream_len < frame_length {
185        return Err(Error::InsufficientBytes(frame_length - stream_len));
186    }
187
188    Ok(fixed_header)
189}
190
191/// Parses fixed header
192fn parse_fixed_header(mut stream: Iter<u8>) -> Result<FixedHeader, Error> {
193    // At least 2 bytes are necessary to frame a packet
194    let stream_len = stream.len();
195    if stream_len < 2 {
196        return Err(Error::InsufficientBytes(2 - stream_len));
197    }
198
199    let byte1 = stream.next().unwrap();
200    let (len_len, len) = length(stream)?;
201
202    Ok(FixedHeader::new(*byte1, len_len, len))
203}
204
205/// Parses variable byte integer in the stream and returns the length
206/// and number of bytes that make it. Used for remaining length calculation
207/// as well as for calculating property lengths
208fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
209    let mut len: usize = 0;
210    let mut len_len = 0;
211    let mut done = false;
212    let mut shift = 0;
213
214    // Use continuation bit at position 7 to continue reading next
215    // byte to frame 'length'.
216    // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will
217    // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx
218    for byte in stream {
219        len_len += 1;
220        let byte = *byte as usize;
221        len += (byte & 0x7F) << shift;
222
223        // stop when continue bit is 0
224        done = (byte & 0x80) == 0;
225        if done {
226            break;
227        }
228
229        shift += 7;
230
231        // Only a max of 4 bytes allowed for remaining length
232        // more than 4 shifts (0, 7, 14, 21) implies bad length
233        if shift > 21 {
234            return Err(Error::MalformedRemainingLength);
235        }
236    }
237
238    // Not enough bytes to frame remaining length. wait for
239    // one more byte
240    if !done {
241        return Err(Error::InsufficientBytes(1));
242    }
243
244    Ok((len_len, len))
245}
246
247/// Reads a series of bytes with a length from a byte stream
248fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
249    let len = read_u16(stream)? as usize;
250
251    // Prevent attacks with wrong remaining length. This method is used in
252    // `packet.assembly()` with (enough) bytes to frame packet. Ensures that
253    // reading variable len string or bytes doesn't cross promised boundary
254    // with `read_fixed_header()`
255    if len > stream.len() {
256        return Err(Error::BoundaryCrossed(len));
257    }
258
259    Ok(stream.split_to(len))
260}
261
262/// Reads a string from bytes stream
263fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
264    let s = read_mqtt_bytes(stream)?;
265    match String::from_utf8(s.to_vec()) {
266        Ok(v) => Ok(v),
267        Err(_e) => Err(Error::TopicNotUtf8),
268    }
269}
270
271/// Serializes bytes to stream (including length)
272fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
273    stream.put_u16(bytes.len() as u16);
274    stream.extend_from_slice(bytes);
275}
276
277/// Serializes a string to stream
278fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
279    write_mqtt_bytes(stream, string.as_bytes());
280}
281
282/// Writes remaining length to stream and returns number of bytes for remaining length
283fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
284    if len > 268_435_455 {
285        return Err(Error::PayloadTooLong);
286    }
287
288    let mut done = false;
289    let mut x = len;
290    let mut count = 0;
291
292    while !done {
293        let mut byte = (x % 128) as u8;
294        x /= 128;
295        if x > 0 {
296            byte |= 128;
297        }
298
299        stream.put_u8(byte);
300        count += 1;
301        done = x == 0;
302    }
303
304    Ok(count)
305}
306
307/// Maps a number to QoS
308pub fn qos(num: u8) -> Result<QoS, Error> {
309    match num {
310        0 => Ok(QoS::AtMostOnce),
311        1 => Ok(QoS::AtLeastOnce),
312        2 => Ok(QoS::ExactlyOnce),
313        qos => Err(Error::InvalidQoS(qos)),
314    }
315}
316
317/// After collecting enough bytes to frame a packet (packet's frame())
318/// , It's possible that content itself in the stream is wrong. Like expected
319/// packet id or qos not being present. In cases where `read_mqtt_string` or
320/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to
321/// parse qos next, these pre checks will prevent `bytes` crashes
322fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
323    if stream.len() < 2 {
324        return Err(Error::MalformedPacket);
325    }
326
327    Ok(stream.get_u16())
328}
329
330fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
331    if stream.is_empty() {
332        return Err(Error::MalformedPacket);
333    }
334
335    Ok(stream.get_u8())
336}