rumqttc/v5/mqttbytes/v5/mod.rs

1use std::slice::Iter;
2
3pub use self::{
4    connack::{ConnAck, ConnAckProperties, ConnectReturnCode},
5    connect::{Connect, ConnectProperties, LastWill, LastWillProperties, Login},
6    disconnect::{Disconnect, DisconnectReasonCode},
7    ping::{PingReq, PingResp},
8    puback::{PubAck, PubAckProperties, PubAckReason},
9    pubcomp::{PubComp, PubCompProperties, PubCompReason},
10    publish::{Publish, PublishProperties},
11    pubrec::{PubRec, PubRecProperties, PubRecReason},
12    pubrel::{PubRel, PubRelProperties, PubRelReason},
13    suback::{SubAck, SubAckProperties, SubscribeReasonCode},
14    subscribe::{Filter, RetainForwardRule, Subscribe, SubscribeProperties},
15    unsuback::{UnsubAck, UnsubAckProperties, UnsubAckReason},
16    unsubscribe::{Unsubscribe, UnsubscribeProperties},
17};
18
19use super::*;
20use bytes::{Buf, BufMut, Bytes, BytesMut};
21
22mod connack;
23mod connect;
24mod disconnect;
25mod ping;
26mod puback;
27mod pubcomp;
28mod publish;
29mod pubrec;
30mod pubrel;
31mod suback;
32mod subscribe;
33mod unsuback;
34mod unsubscribe;
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37pub enum Packet {
38    Connect(Connect, Option<LastWill>, Option<Login>),
39    ConnAck(ConnAck),
40    Publish(Publish),
41    PubAck(PubAck),
42    PingReq(PingReq),
43    PingResp(PingResp),
44    Subscribe(Subscribe),
45    SubAck(SubAck),
46    PubRec(PubRec),
47    PubRel(PubRel),
48    PubComp(PubComp),
49    Unsubscribe(Unsubscribe),
50    UnsubAck(UnsubAck),
51    Disconnect(Disconnect),
52}
53
54impl Packet {
55    /// Reads a stream of bytes and extracts next MQTT packet out of it
56    pub fn read(stream: &mut BytesMut, max_size: Option<usize>) -> Result<Packet, Error> {
57        let fixed_header = check(stream.iter(), max_size)?;
58
59        // Test with a stream with exactly the size to check border panics
60        let packet = stream.split_to(fixed_header.frame_length());
61        let packet_type = fixed_header.packet_type()?;
62
63        if fixed_header.remaining_len == 0 {
64            // no payload packets, Disconnect still has a bit more info
65            return match packet_type {
66                PacketType::PingReq => Ok(Packet::PingReq(PingReq)),
67                PacketType::PingResp => Ok(Packet::PingResp(PingResp)),
68                _ => Err(Error::PayloadRequired),
69            };
70        }
71
72        let packet = packet.freeze();
73        let packet = match packet_type {
74            PacketType::Connect => {
75                let (connect, will, login) = Connect::read(fixed_header, packet)?;
76                Packet::Connect(connect, will, login)
77            }
78            PacketType::Publish => {
79                let publish = Publish::read(fixed_header, packet)?;
80                Packet::Publish(publish)
81            }
82            PacketType::Subscribe => {
83                let subscribe = Subscribe::read(fixed_header, packet)?;
84                Packet::Subscribe(subscribe)
85            }
86            PacketType::Unsubscribe => {
87                let unsubscribe = Unsubscribe::read(fixed_header, packet)?;
88                Packet::Unsubscribe(unsubscribe)
89            }
90            PacketType::ConnAck => {
91                let connack = ConnAck::read(fixed_header, packet)?;
92                Packet::ConnAck(connack)
93            }
94            PacketType::PubAck => {
95                let puback = PubAck::read(fixed_header, packet)?;
96                Packet::PubAck(puback)
97            }
98            PacketType::PubRec => {
99                let pubrec = PubRec::read(fixed_header, packet)?;
100                Packet::PubRec(pubrec)
101            }
102            PacketType::PubRel => {
103                let pubrel = PubRel::read(fixed_header, packet)?;
104                Packet::PubRel(pubrel)
105            }
106            PacketType::PubComp => {
107                let pubcomp = PubComp::read(fixed_header, packet)?;
108                Packet::PubComp(pubcomp)
109            }
110            PacketType::SubAck => {
111                let suback = SubAck::read(fixed_header, packet)?;
112                Packet::SubAck(suback)
113            }
114            PacketType::UnsubAck => {
115                let unsuback = UnsubAck::read(fixed_header, packet)?;
116                Packet::UnsubAck(unsuback)
117            }
118            PacketType::PingReq => Packet::PingReq(PingReq),
119            PacketType::PingResp => Packet::PingResp(PingResp),
120            PacketType::Disconnect => {
121                let disconnect = Disconnect::read(fixed_header, packet)?;
122                Packet::Disconnect(disconnect)
123            }
124        };
125
126        Ok(packet)
127    }
128
129    pub fn write(&self, write: &mut BytesMut) -> Result<usize, Error> {
130        match self {
131            Self::Publish(publish) => publish.write(write),
132            Self::Subscribe(subscription) => subscription.write(write),
133            Self::Unsubscribe(unsubscribe) => unsubscribe.write(write),
134            Self::ConnAck(ack) => ack.write(write),
135            Self::PubAck(ack) => ack.write(write),
136            Self::SubAck(ack) => ack.write(write),
137            Self::UnsubAck(unsuback) => unsuback.write(write),
138            Self::PubRec(pubrec) => pubrec.write(write),
139            Self::PubRel(pubrel) => pubrel.write(write),
140            Self::PubComp(pubcomp) => pubcomp.write(write),
141            Self::Connect(connect, will, login) => connect.write(will, login, write),
142            Self::PingReq(_) => PingReq::write(write),
143            Self::PingResp(_) => PingResp::write(write),
144            Self::Disconnect(disconnect) => disconnect.write(write),
145        }
146    }
147}
148
/// MQTT control packet type, carried in the upper four bits of the fixed header's first byte.
///
/// Each discriminant is the on-wire value of that nibble. Nibble 0 has no variant, and
/// neither does nibble 15.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    Connect = 1,
    ConnAck = 2,
    Publish = 3,
    PubAck = 4,
    PubRec = 5,
    PubRel = 6,
    PubComp = 7,
    Subscribe = 8,
    SubAck = 9,
    Unsubscribe = 10,
    UnsubAck = 11,
    PingReq = 12,
    PingResp = 13,
    Disconnect = 14,
}
168
/// MQTT 5 property identifiers.
///
/// Each discriminant is the identifier byte that precedes the property on the wire. They are
/// written in hexadecimal, as in the specification's property table. Decoding from a raw byte
/// is done by `property`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PropertyType {
    PayloadFormatIndicator = 0x01,
    MessageExpiryInterval = 0x02,
    ContentType = 0x03,
    ResponseTopic = 0x08,
    CorrelationData = 0x09,
    SubscriptionIdentifier = 0x0B,
    SessionExpiryInterval = 0x11,
    AssignedClientIdentifier = 0x12,
    ServerKeepAlive = 0x13,
    AuthenticationMethod = 0x15,
    AuthenticationData = 0x16,
    RequestProblemInformation = 0x17,
    WillDelayInterval = 0x18,
    RequestResponseInformation = 0x19,
    ResponseInformation = 0x1A,
    ServerReference = 0x1C,
    ReasonString = 0x1F,
    ReceiveMaximum = 0x21,
    TopicAliasMaximum = 0x22,
    TopicAlias = 0x23,
    MaximumQos = 0x24,
    RetainAvailable = 0x25,
    UserProperty = 0x26,
    MaximumPacketSize = 0x27,
    WildcardSubscriptionAvailable = 0x28,
    SubscriptionIdentifierAvailable = 0x29,
    SharedSubscriptionAvailable = 0x2A,
}
200
/// Fixed header found at the start of every MQTT control packet.
///
/// ```ignore
///          7                          3                          0
///          +--------------------------+--------------------------+
/// byte 1   | MQTT Control Packet Type | Flags for each type      |
///          +--------------------------+--------------------------+
///          |         Remaining Bytes Len  (1/2/3/4 bytes)        |
///          +-----------------------------------------------------+
/// ```
///
/// See section 2.1.1 "Fixed Header" of the MQTT 5.0 specification:
/// <https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html>
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct FixedHeader {
    /// Raw first byte of the packet. The upper nibble is the packet type; the lower nibble
    /// holds that type's flags.
    byte1: u8,
    /// Number of bytes taken by the fixed header itself: the first byte plus the 1 to 4
    /// variable-length-encoded bytes of the remaining length, so 2 to 5 in total.
    fixed_header_len: usize,
    /// Size of the variable header plus payload. The fixed header bytes are not included.
    remaining_len: usize,
}
226
227impl FixedHeader {
228    pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader {
229        FixedHeader {
230            byte1,
231            fixed_header_len: remaining_len_len + 1,
232            remaining_len,
233        }
234    }
235
236    pub fn packet_type(&self) -> Result<PacketType, Error> {
237        let num = self.byte1 >> 4;
238        match num {
239            1 => Ok(PacketType::Connect),
240            2 => Ok(PacketType::ConnAck),
241            3 => Ok(PacketType::Publish),
242            4 => Ok(PacketType::PubAck),
243            5 => Ok(PacketType::PubRec),
244            6 => Ok(PacketType::PubRel),
245            7 => Ok(PacketType::PubComp),
246            8 => Ok(PacketType::Subscribe),
247            9 => Ok(PacketType::SubAck),
248            10 => Ok(PacketType::Unsubscribe),
249            11 => Ok(PacketType::UnsubAck),
250            12 => Ok(PacketType::PingReq),
251            13 => Ok(PacketType::PingResp),
252            14 => Ok(PacketType::Disconnect),
253            _ => Err(Error::InvalidPacketType(num)),
254        }
255    }
256
257    /// Returns the size of full packet (fixed header + variable header + payload)
258    /// Fixed header is enough to get the size of a frame in the stream
259    pub fn frame_length(&self) -> usize {
260        self.fixed_header_len + self.remaining_len
261    }
262}
263
264fn property(num: u8) -> Result<PropertyType, Error> {
265    let property = match num {
266        1 => PropertyType::PayloadFormatIndicator,
267        2 => PropertyType::MessageExpiryInterval,
268        3 => PropertyType::ContentType,
269        8 => PropertyType::ResponseTopic,
270        9 => PropertyType::CorrelationData,
271        11 => PropertyType::SubscriptionIdentifier,
272        17 => PropertyType::SessionExpiryInterval,
273        18 => PropertyType::AssignedClientIdentifier,
274        19 => PropertyType::ServerKeepAlive,
275        21 => PropertyType::AuthenticationMethod,
276        22 => PropertyType::AuthenticationData,
277        23 => PropertyType::RequestProblemInformation,
278        24 => PropertyType::WillDelayInterval,
279        25 => PropertyType::RequestResponseInformation,
280        26 => PropertyType::ResponseInformation,
281        28 => PropertyType::ServerReference,
282        31 => PropertyType::ReasonString,
283        33 => PropertyType::ReceiveMaximum,
284        34 => PropertyType::TopicAliasMaximum,
285        35 => PropertyType::TopicAlias,
286        36 => PropertyType::MaximumQos,
287        37 => PropertyType::RetainAvailable,
288        38 => PropertyType::UserProperty,
289        39 => PropertyType::MaximumPacketSize,
290        40 => PropertyType::WildcardSubscriptionAvailable,
291        41 => PropertyType::SubscriptionIdentifierAvailable,
292        42 => PropertyType::SharedSubscriptionAvailable,
293        num => return Err(Error::InvalidPropertyType(num)),
294    };
295
296    Ok(property)
297}
298
299/// Checks if the stream has enough bytes to frame a packet and returns fixed header
300/// only if a packet can be framed with existing bytes in the `stream`.
301/// The passed stream doesn't modify parent stream's cursor. If this function
302/// returned an error, next `check` on the same parent stream is forced start
303/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally)
304pub fn check(stream: Iter<u8>, max_packet_size: Option<usize>) -> Result<FixedHeader, Error> {
305    // Create fixed header if there are enough bytes in the stream
306    // to frame full packet
307    let stream_len = stream.len();
308    let fixed_header = parse_fixed_header(stream)?;
309
310    // Don't let rogue connections attack with huge payloads.
311    // Disconnect them before reading all that data
312    if let Some(max_size) = max_packet_size {
313        if fixed_header.remaining_len > max_size {
314            return Err(Error::PayloadSizeLimitExceeded {
315                pkt_size: fixed_header.remaining_len,
316                max: max_size,
317            });
318        }
319    }
320
321    // If the current call fails due to insufficient bytes in the stream,
322    // after calculating remaining length, we extend the stream
323    let frame_length = fixed_header.frame_length();
324    if stream_len < frame_length {
325        return Err(Error::InsufficientBytes(frame_length - stream_len));
326    }
327
328    Ok(fixed_header)
329}
330
331/// Parses fixed header
332fn parse_fixed_header(mut stream: Iter<u8>) -> Result<FixedHeader, Error> {
333    // At least 2 bytes are necessary to frame a packet
334    let stream_len = stream.len();
335    if stream_len < 2 {
336        return Err(Error::InsufficientBytes(2 - stream_len));
337    }
338
339    let byte1 = stream.next().unwrap();
340    let (len_len, len) = length(stream)?;
341
342    Ok(FixedHeader::new(*byte1, len_len, len))
343}
344
345/// Parses variable byte integer in the stream and returns the length
346/// and number of bytes that make it. Used for remaining length calculation
347/// as well as for calculating property lengths
348fn length(stream: Iter<u8>) -> Result<(usize, usize), Error> {
349    let mut len: usize = 0;
350    let mut len_len = 0;
351    let mut done = false;
352    let mut shift = 0;
353
354    // Use continuation bit at position 7 to continue reading next
355    // byte to frame 'length'.
356    // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will
357    // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx
358    for byte in stream {
359        len_len += 1;
360        let byte = *byte as usize;
361        len += (byte & 0x7F) << shift;
362
363        // stop when continue bit is 0
364        done = (byte & 0x80) == 0;
365        if done {
366            break;
367        }
368
369        shift += 7;
370
371        // Only a max of 4 bytes allowed for remaining length
372        // more than 4 shifts (0, 7, 14, 21) implies bad length
373        if shift > 21 {
374            return Err(Error::MalformedRemainingLength);
375        }
376    }
377
378    // Not enough bytes to frame remaining length. wait for
379    // one more byte
380    if !done {
381        return Err(Error::InsufficientBytes(1));
382    }
383
384    Ok((len_len, len))
385}
386
387/// Reads a series of bytes with a length from a byte stream
388fn read_mqtt_bytes(stream: &mut Bytes) -> Result<Bytes, Error> {
389    let len = read_u16(stream)? as usize;
390
391    // Prevent attacks with wrong remaining length. This method is used in
392    // `packet.assembly()` with (enough) bytes to frame packet. Ensures that
393    // reading variable len string or bytes doesn't cross promised boundary
394    // with `read_fixed_header()`
395    if len > stream.len() {
396        return Err(Error::BoundaryCrossed(len));
397    }
398
399    Ok(stream.split_to(len))
400}
401
402/// Reads a string from bytes stream
403fn read_mqtt_string(stream: &mut Bytes) -> Result<String, Error> {
404    let s = read_mqtt_bytes(stream)?;
405    match String::from_utf8(s.to_vec()) {
406        Ok(v) => Ok(v),
407        Err(_e) => Err(Error::TopicNotUtf8),
408    }
409}
410
411/// Serializes bytes to stream (including length)
412fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) {
413    stream.put_u16(bytes.len() as u16);
414    stream.extend_from_slice(bytes);
415}
416
417/// Serializes a string to stream
418fn write_mqtt_string(stream: &mut BytesMut, string: &str) {
419    write_mqtt_bytes(stream, string.as_bytes());
420}
421
422/// Writes remaining length to stream and returns number of bytes for remaining length
423fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result<usize, Error> {
424    if len > 268_435_455 {
425        return Err(Error::PayloadTooLong);
426    }
427
428    let mut done = false;
429    let mut x = len;
430    let mut count = 0;
431
432    while !done {
433        let mut byte = (x % 128) as u8;
434        x /= 128;
435        if x > 0 {
436            byte |= 128;
437        }
438
439        stream.put_u8(byte);
440        count += 1;
441        done = x == 0;
442    }
443
444    Ok(count)
445}
446
/// Returns how many bytes the variable byte integer encoding of `len` occupies.
///
/// Values of 268_435_456 and above also report 4, even though `write_remaining_length`
/// rejects them.
fn len_len(len: usize) -> usize {
    match len {
        0..=127 => 1,
        128..=16_383 => 2,
        16_384..=2_097_151 => 3,
        _ => 4,
    }
}
459
460/// After collecting enough bytes to frame a packet (packet's frame())
461/// , It's possible that content itself in the stream is wrong. Like expected
462/// packet id or qos not being present. In cases where `read_mqtt_string` or
463/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to
464/// parse qos next, these pre checks will prevent `bytes` crashes
465fn read_u16(stream: &mut Bytes) -> Result<u16, Error> {
466    if stream.len() < 2 {
467        return Err(Error::MalformedPacket);
468    }
469
470    Ok(stream.get_u16())
471}
472
473fn read_u8(stream: &mut Bytes) -> Result<u8, Error> {
474    if stream.is_empty() {
475        return Err(Error::MalformedPacket);
476    }
477
478    Ok(stream.get_u8())
479}
480
481fn read_u32(stream: &mut Bytes) -> Result<u32, Error> {
482    if stream.len() < 4 {
483        return Err(Error::MalformedPacket);
484    }
485
486    Ok(stream.get_u32())
487}
488
/// Fixtures shared by the packet modules' unit tests.
///
/// Compiled only for test builds, so these constants no longer ship in the library.
#[cfg(test)]
mod test {
    // Kept so a test build where a constant goes unused does not warn.
    #[allow(dead_code)]
    pub const USER_PROP_KEY: &str = "property";
    #[allow(dead_code)]
    pub const USER_PROP_VAL: &str = "a value thats really long............................................................................................................";
}