rumqttc/v5/
state.rs

1use super::mqttbytes::v5::{
2    ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck,
3    PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish,
4    SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe,
5};
6use super::mqttbytes::{self, QoS};
7
8use super::{Event, Incoming, Outgoing, Request};
9
10use bytes::{Bytes, BytesMut};
11use std::collections::{HashMap, VecDeque};
12use std::convert::TryInto;
13use std::{io, time::Instant};
14
15/// Errors during state handling
16#[derive(Debug, thiserror::Error)]
17pub enum StateError {
18    /// Io Error while state is passed to network
19    #[error("Io error: {0:?}")]
20    Io(#[from] io::Error),
21    #[error("Conversion error {0:?}")]
22    Coversion(#[from] core::num::TryFromIntError),
23    /// Invalid state for a given operation
24    #[error("Invalid state for a given operation")]
25    InvalidState,
26    /// Received a packet (ack) which isn't asked for
27    #[error("Received unsolicited ack pkid: {0}")]
28    Unsolicited(u16),
29    /// Last pingreq isn't acked
30    #[error("Last pingreq isn't acked")]
31    AwaitPingResp,
32    /// Received a wrong packet while waiting for another packet
33    #[error("Received a wrong packet while waiting for another packet")]
34    WrongPacket,
35    #[error("Timeout while waiting to resolve collision")]
36    CollisionTimeout,
37    #[error("A Subscribe packet must contain atleast one filter")]
38    EmptySubscription,
39    #[error("Mqtt serialization/deserialization error: {0}")]
40    Deserialization(#[from] mqttbytes::Error),
41    #[error(
42        "Cannot use topic alias '{alias:?}'. It's greater than the broker's maximum of '{max:?}'."
43    )]
44    InvalidAlias { alias: u16, max: u16 },
45    #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")]
46    OutgoingPacketTooLarge { pkt_size: u32, max: u32 },
47    #[error("Cannot receive packet of size '{pkt_size:?}'. It's greater than the client's maximum packet size of: '{max:?}'")]
48    IncomingPacketTooLarge { pkt_size: usize, max: usize },
49    #[error("Server sent disconnect with reason `{reason_string:?}` and code '{reason_code:?}' ")]
50    ServerDisconnect {
51        reason_code: DisconnectReasonCode,
52        reason_string: Option<String>,
53    },
54    #[error("Unsubscribe failed with reason '{reason:?}' ")]
55    UnsubFail { reason: UnsubAckReason },
56    #[error("Subscribe failed with reason '{reason:?}' ")]
57    SubFail { reason: SubscribeReasonCode },
58    #[error("Publish acknowledgement failed with reason '{reason:?}' ")]
59    PubAckFail { reason: PubAckReason },
60    #[error("Publish receive failed with reason '{reason:?}' ")]
61    PubRecFail { reason: PubRecReason },
62    #[error("Publish release failed with reason '{reason:?}' ")]
63    PubRelFail { reason: PubRelReason },
64    #[error("Publish completion failed with reason '{reason:?}' ")]
65    PubCompFail { reason: PubCompReason },
66    #[error("Connection failed with reason '{reason:?}' ")]
67    ConnFail { reason: ConnectReturnCode },
68}
69
/// State of the mqtt connection.
// Design: Methods will just modify the state of the object without doing any network operations;
// serialized packets are appended to `write` for the eventloop to flush.
// Design: All inflight queues are maintained in a pre initialized vec with index as packet id.
// This is done for 2 reasons:
// - Bad acks or out of order acks aren't O(n), which would cause cpu spikes
// - Any missing acks from the broker are detected during the next recycled use of packet ids
#[derive(Debug, Clone)]
pub struct MqttState {
    /// `true` while a PINGREQ is outstanding; sending another ping in that
    /// state fails with `StateError::AwaitPingResp`
    pub await_pingresp: bool,
    /// Number of pings sent while a packet id collision is pending. Collisions
    /// stop user requests which in turn trigger pings. Multiple pings without
    /// resolving the collision result in `StateError::CollisionTimeout`
    pub collision_ping_count: usize,
    /// Last incoming packet time
    last_incoming: Instant,
    /// Last outgoing packet time
    last_outgoing: Instant,
    /// Packet id most recently handed out by `next_pkid` (0 right after wrapping)
    pub(crate) last_pkid: u16,
    /// Number of outgoing QoS 1 and QoS 2 flows that have not completed yet
    pub(crate) inflight: u16,
    /// Outgoing QoS 1, 2 publishes not yet acked by PUBACK/PUBREC, indexed by packet id
    pub(crate) outgoing_pub: Vec<Option<Publish>>,
    /// Packet ids of outgoing QoS 2 publishes whose PUBREL was sent and whose
    /// PUBCOMP is still awaited, indexed by packet id
    pub(crate) outgoing_rel: Vec<Option<u16>>,
    /// Packet ids of incoming QoS 2 publishes still waiting for a PUBREL, indexed by packet id
    pub(crate) incoming_pub: Vec<Option<u16>>,
    /// Publish parked because its packet id was still in use (broker not acking
    /// in order); it is sent once that packet id is acked
    pub collision: Option<Publish>,
    /// Buffered events (incoming packets and notifications about outgoing ones)
    pub events: VecDeque<Event>,
    /// Write buffer holding serialized outgoing packets
    pub write: BytesMut,
    /// When `true`, incoming QoS 1/2 publishes are NOT acknowledged automatically;
    /// the user is responsible for sending the PUBACK/PUBREC
    pub manual_acks: bool,
    /// Map of alias_id->topic for aliases set up by incoming publishes
    topic_alises: HashMap<u16, Bytes>,
    /// `topic_alias_maximum` RECEIVED via connack packet
    pub broker_topic_alias_max: u16,
    /// The broker's `max_packet_size` received via connack
    pub max_outgoing_packet_size: Option<u32>,
    /// Maximum number of allowed inflight QoS1 & QoS2 requests (the broker's
    /// Receive Maximum from CONNACK, capped by `max_outgoing_inflight_upper_limit`)
    pub(crate) max_outgoing_inflight: u16,
    /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests
    /// (the `max_inflight` passed to `MqttState::new`)
    max_outgoing_inflight_upper_limit: u16,
}
117
118impl MqttState {
119    /// Creates new mqtt state. Same state should be used during a
120    /// connection for persistent sessions while new state should
121    /// instantiated for clean sessions
122    pub fn new(max_inflight: u16, manual_acks: bool) -> Self {
123        MqttState {
124            await_pingresp: false,
125            collision_ping_count: 0,
126            last_incoming: Instant::now(),
127            last_outgoing: Instant::now(),
128            last_pkid: 0,
129            inflight: 0,
130            // index 0 is wasted as 0 is not a valid packet id
131            outgoing_pub: vec![None; max_inflight as usize + 1],
132            outgoing_rel: vec![None; max_inflight as usize + 1],
133            incoming_pub: vec![None; std::u16::MAX as usize + 1],
134            collision: None,
135            // TODO: Optimize these sizes later
136            events: VecDeque::with_capacity(100),
137            write: BytesMut::with_capacity(10 * 1024),
138            manual_acks,
139            topic_alises: HashMap::new(),
140            // Set via CONNACK
141            broker_topic_alias_max: 0,
142            max_outgoing_packet_size: None,
143            max_outgoing_inflight: max_inflight,
144            max_outgoing_inflight_upper_limit: max_inflight,
145        }
146    }
147
148    /// Returns inflight outgoing packets and clears internal queues
149    pub fn clean(&mut self) -> Vec<Request> {
150        let mut pending = Vec::with_capacity(100);
151        // remove and collect pending publishes
152        for publish in self.outgoing_pub.iter_mut() {
153            if let Some(publish) = publish.take() {
154                let request = Request::Publish(publish);
155                pending.push(request);
156            }
157        }
158
159        // remove and collect pending releases
160        for rel in self.outgoing_rel.iter_mut() {
161            if let Some(pkid) = rel.take() {
162                let request = Request::PubRel(PubRel::new(pkid, None));
163                pending.push(request);
164            }
165        }
166
167        // remove packed ids of incoming qos2 publishes
168        for id in self.incoming_pub.iter_mut() {
169            id.take();
170        }
171
172        self.await_pingresp = false;
173        self.collision_ping_count = 0;
174        self.inflight = 0;
175        pending
176    }
177
178    pub fn inflight(&self) -> u16 {
179        self.inflight
180    }
181
182    /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should
183    /// be put on to the network by the eventloop
184    pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> {
185        match request {
186            Request::Publish(publish) => {
187                self.check_size(publish.size())?;
188                self.outgoing_publish(publish)?
189            }
190            Request::PubRel(pubrel) => {
191                self.check_size(pubrel.size())?;
192                self.outgoing_pubrel(pubrel)?
193            }
194            Request::Subscribe(subscribe) => {
195                self.check_size(subscribe.size())?;
196                self.outgoing_subscribe(subscribe)?
197            }
198            Request::Unsubscribe(unsubscribe) => {
199                self.check_size(unsubscribe.size())?;
200                self.outgoing_unsubscribe(unsubscribe)?
201            }
202            Request::PingReq => self.outgoing_ping()?,
203            Request::Disconnect => {
204                self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)?
205            }
206            Request::PubAck(puback) => {
207                self.check_size(puback.size())?;
208                self.outgoing_puback(puback)?
209            }
210            Request::PubRec(pubrec) => {
211                self.check_size(pubrec.size())?;
212                self.outgoing_pubrec(pubrec)?
213            }
214            _ => unimplemented!(),
215        };
216
217        self.last_outgoing = Instant::now();
218        Ok(())
219    }
220
221    /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the
222    /// user to consume and `Packet` which for the eventloop to put on the network
223    /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will
224    /// be forwarded to user and Pubck packet will be written to network
225    pub fn handle_incoming_packet(&mut self, mut packet: Incoming) -> Result<(), StateError> {
226        let out = match &mut packet {
227            Incoming::PingResp(_) => self.handle_incoming_pingresp(),
228            Incoming::Publish(publish) => self.handle_incoming_publish(publish),
229            Incoming::SubAck(suback) => self.handle_incoming_suback(suback),
230            Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback),
231            Incoming::PubAck(puback) => self.handle_incoming_puback(puback),
232            Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec),
233            Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel),
234            Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp),
235            Incoming::ConnAck(connack) => self.handle_incoming_connack(connack),
236            Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn),
237            _ => {
238                error!("Invalid incoming packet = {:?}", packet);
239                return Err(StateError::WrongPacket);
240            }
241        };
242
243        out?;
244        self.events.push_back(Event::Incoming(packet));
245        self.last_incoming = Instant::now();
246        Ok(())
247    }
248
249    pub fn handle_protocol_error(&mut self) -> Result<(), StateError> {
250        // send DISCONNECT packet with REASON_CODE 0x82
251        self.outgoing_disconnect(DisconnectReasonCode::ProtocolError)
252    }
253
254    fn handle_incoming_suback(&mut self, suback: &mut SubAck) -> Result<(), StateError> {
255        for reason in suback.return_codes.iter() {
256            match reason {
257                SubscribeReasonCode::Success(qos) => {
258                    debug!("SubAck Pkid = {:?}, QoS = {:?}", suback.pkid, qos);
259                }
260                _ => return Err(StateError::SubFail { reason: *reason }),
261            }
262        }
263        Ok(())
264    }
265
266    fn handle_incoming_unsuback(&mut self, unsuback: &mut UnsubAck) -> Result<(), StateError> {
267        for reason in unsuback.reasons.iter() {
268            if reason != &UnsubAckReason::Success {
269                return Err(StateError::UnsubFail { reason: *reason });
270            }
271        }
272        Ok(())
273    }
274
275    fn handle_incoming_connack(&mut self, connack: &mut ConnAck) -> Result<(), StateError> {
276        if connack.code != ConnectReturnCode::Success {
277            return Err(StateError::ConnFail {
278                reason: connack.code,
279            });
280        }
281
282        if let Some(props) = &connack.properties {
283            if let Some(topic_alias_max) = props.topic_alias_max {
284                self.broker_topic_alias_max = topic_alias_max
285            }
286
287            if let Some(max_inflight) = props.receive_max {
288                self.max_outgoing_inflight =
289                    max_inflight.min(self.max_outgoing_inflight_upper_limit);
290                // FIXME: Maybe resize the pubrec and pubrel queues here
291                // to save some space.
292            }
293
294            self.max_outgoing_packet_size = props.max_packet_size;
295        }
296        Ok(())
297    }
298
299    fn handle_incoming_disconn(&mut self, disconn: &mut Disconnect) -> Result<(), StateError> {
300        let reason_code = disconn.reason_code;
301        let reason_string = if let Some(props) = &disconn.properties {
302            props.reason_string.clone()
303        } else {
304            None
305        };
306        Err(StateError::ServerDisconnect {
307            reason_code,
308            reason_string,
309        })
310    }
311
312    /// Results in a publish notification in all the QoS cases. Replys with an ack
313    /// in case of QoS1 and Replys rec in case of QoS while also storing the message
314    fn handle_incoming_publish(&mut self, publish: &mut Publish) -> Result<(), StateError> {
315        let qos = publish.qos;
316
317        let topic_alias = match &publish.properties {
318            Some(props) => props.topic_alias,
319            None => None,
320        };
321
322        if !publish.topic.is_empty() {
323            if let Some(alias) = topic_alias {
324                self.topic_alises.insert(alias, publish.topic.clone());
325            }
326        } else if let Some(alias) = topic_alias {
327            if let Some(topic) = self.topic_alises.get(&alias) {
328                publish.topic = topic.to_owned();
329            } else {
330                self.handle_protocol_error()?;
331            };
332        }
333
334        match qos {
335            QoS::AtMostOnce => Ok(()),
336            QoS::AtLeastOnce => {
337                if !self.manual_acks {
338                    let puback = PubAck::new(publish.pkid, None);
339                    self.outgoing_puback(puback)?;
340                }
341                Ok(())
342            }
343            QoS::ExactlyOnce => {
344                let pkid = publish.pkid;
345                self.incoming_pub[pkid as usize] = Some(pkid);
346
347                if !self.manual_acks {
348                    let pubrec = PubRec::new(pkid, None);
349                    self.outgoing_pubrec(pubrec)?;
350                }
351                Ok(())
352            }
353        }
354    }
355
356    fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> {
357        let publish = self
358            .outgoing_pub
359            .get_mut(puback.pkid as usize)
360            .ok_or(StateError::Unsolicited(puback.pkid))?;
361        let v = match publish.take() {
362            Some(_) => {
363                self.inflight -= 1;
364                Ok(())
365            }
366            None => {
367                error!("Unsolicited puback packet: {:?}", puback.pkid);
368                Err(StateError::Unsolicited(puback.pkid))
369            }
370        };
371
372        if puback.reason != PubAckReason::Success
373            && puback.reason != PubAckReason::NoMatchingSubscribers
374        {
375            return Err(StateError::PubAckFail {
376                reason: puback.reason,
377            });
378        }
379
380        if let Some(publish) = self.check_collision(puback.pkid) {
381            self.outgoing_pub[publish.pkid as usize] = Some(publish.clone());
382            self.inflight += 1;
383
384            let pkid = publish.pkid;
385            Packet::Publish(publish).write(&mut self.write)?;
386            let event = Event::Outgoing(Outgoing::Publish(pkid));
387            self.events.push_back(event);
388            self.collision_ping_count = 0;
389        }
390
391        v
392    }
393
394    fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> {
395        let publish = self
396            .outgoing_pub
397            .get_mut(pubrec.pkid as usize)
398            .ok_or(StateError::Unsolicited(pubrec.pkid))?;
399        match publish.take() {
400            Some(_) => {
401                if pubrec.reason != PubRecReason::Success
402                    && pubrec.reason != PubRecReason::NoMatchingSubscribers
403                {
404                    return Err(StateError::PubRecFail {
405                        reason: pubrec.reason,
406                    });
407                }
408
409                // NOTE: Inflight - 1 for qos2 in comp
410                self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid);
411                Packet::PubRel(PubRel::new(pubrec.pkid, None)).write(&mut self.write)?;
412
413                let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid));
414                self.events.push_back(event);
415                Ok(())
416            }
417            None => {
418                error!("Unsolicited pubrec packet: {:?}", pubrec.pkid);
419                Err(StateError::Unsolicited(pubrec.pkid))
420            }
421        }
422    }
423
424    fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> {
425        let publish = self
426            .incoming_pub
427            .get_mut(pubrel.pkid as usize)
428            .ok_or(StateError::Unsolicited(pubrel.pkid))?;
429        match publish.take() {
430            Some(_) => {
431                if pubrel.reason != PubRelReason::Success {
432                    return Err(StateError::PubRelFail {
433                        reason: pubrel.reason,
434                    });
435                }
436
437                Packet::PubComp(PubComp::new(pubrel.pkid, None)).write(&mut self.write)?;
438                let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid));
439                self.events.push_back(event);
440                Ok(())
441            }
442            None => {
443                error!("Unsolicited pubrel packet: {:?}", pubrel.pkid);
444                Err(StateError::Unsolicited(pubrel.pkid))
445            }
446        }
447    }
448
449    fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> {
450        if let Some(publish) = self.check_collision(pubcomp.pkid) {
451            let pkid = publish.pkid;
452            Packet::Publish(publish).write(&mut self.write)?;
453            let event = Event::Outgoing(Outgoing::Publish(pkid));
454            self.events.push_back(event);
455            self.collision_ping_count = 0;
456        }
457
458        let pubrel = self
459            .outgoing_rel
460            .get_mut(pubcomp.pkid as usize)
461            .ok_or(StateError::Unsolicited(pubcomp.pkid))?;
462        match pubrel.take() {
463            Some(_) => {
464                if pubcomp.reason != PubCompReason::Success {
465                    return Err(StateError::PubCompFail {
466                        reason: pubcomp.reason,
467                    });
468                }
469
470                self.inflight -= 1;
471                Ok(())
472            }
473            None => {
474                error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid);
475                Err(StateError::Unsolicited(pubcomp.pkid))
476            }
477        }
478    }
479
480    fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> {
481        self.await_pingresp = false;
482        Ok(())
483    }
484
485    /// Adds next packet identifier to QoS 1 and 2 publish packets and returns
486    /// it buy wrapping publish in packet
487    fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> {
488        if publish.qos != QoS::AtMostOnce {
489            if publish.pkid == 0 {
490                publish.pkid = self.next_pkid();
491            }
492
493            let pkid = publish.pkid;
494            if self
495                .outgoing_pub
496                .get(publish.pkid as usize)
497                .ok_or(StateError::Unsolicited(publish.pkid))?
498                .is_some()
499            {
500                info!("Collision on packet id = {:?}", publish.pkid);
501                self.collision = Some(publish);
502                let event = Event::Outgoing(Outgoing::AwaitAck(pkid));
503                self.events.push_back(event);
504                return Ok(());
505            }
506
507            // if there is an existing publish at this pkid, this implies that broker hasn't acked this
508            // packet yet. This error is possible only when broker isn't acking sequentially
509            self.outgoing_pub[pkid as usize] = Some(publish.clone());
510            self.inflight += 1;
511        };
512
513        debug!(
514            "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}",
515            String::from_utf8(publish.topic.to_vec()).unwrap(),
516            publish.pkid,
517            publish.payload.len()
518        );
519
520        let pkid = publish.pkid;
521
522        if let Some(props) = &publish.properties {
523            if let Some(alias) = props.topic_alias {
524                if alias > self.broker_topic_alias_max {
525                    // We MUST NOT send a Topic Alias that is greater than the
526                    // broker's Topic Alias Maximum.
527                    return Err(StateError::InvalidAlias {
528                        alias,
529                        max: self.broker_topic_alias_max,
530                    });
531                }
532            }
533        };
534
535        Packet::Publish(publish).write(&mut self.write)?;
536        let event = Event::Outgoing(Outgoing::Publish(pkid));
537        self.events.push_back(event);
538        Ok(())
539    }
540
541    fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> {
542        let pubrel = self.save_pubrel(pubrel)?;
543
544        debug!("Pubrel. Pkid = {}", pubrel.pkid);
545        Packet::PubRel(PubRel::new(pubrel.pkid, None)).write(&mut self.write)?;
546
547        let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid));
548        self.events.push_back(event);
549        Ok(())
550    }
551
552    fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> {
553        let pkid = puback.pkid;
554        Packet::PubAck(puback).write(&mut self.write)?;
555        let event = Event::Outgoing(Outgoing::PubAck(pkid));
556        self.events.push_back(event);
557        Ok(())
558    }
559
560    fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> {
561        let pkid = pubrec.pkid;
562        Packet::PubRec(pubrec).write(&mut self.write)?;
563        let event = Event::Outgoing(Outgoing::PubRec(pkid));
564        self.events.push_back(event);
565        Ok(())
566    }
567
568    /// check when the last control packet/pingreq packet is received and return
569    /// the status which tells if keep alive time has exceeded
570    /// NOTE: status will be checked for zero keepalive times also
571    fn outgoing_ping(&mut self) -> Result<(), StateError> {
572        let elapsed_in = self.last_incoming.elapsed();
573        let elapsed_out = self.last_outgoing.elapsed();
574
575        if self.collision.is_some() {
576            self.collision_ping_count += 1;
577            if self.collision_ping_count >= 2 {
578                return Err(StateError::CollisionTimeout);
579            }
580        }
581
582        // raise error if last ping didn't receive ack
583        if self.await_pingresp {
584            return Err(StateError::AwaitPingResp);
585        }
586
587        self.await_pingresp = true;
588
589        debug!(
590            "Pingreq, last incoming packet before {:?}, last outgoing request before {:?}",
591            elapsed_in, elapsed_out,
592        );
593
594        Packet::PingReq(PingReq).write(&mut self.write)?;
595        let event = Event::Outgoing(Outgoing::PingReq);
596        self.events.push_back(event);
597        Ok(())
598    }
599
600    fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> {
601        if subscription.filters.is_empty() {
602            return Err(StateError::EmptySubscription);
603        }
604
605        let pkid = self.next_pkid();
606        subscription.pkid = pkid;
607
608        debug!(
609            "Subscribe. Topics = {:?}, Pkid = {:?}",
610            subscription.filters, subscription.pkid
611        );
612
613        let pkid = subscription.pkid;
614        Packet::Subscribe(subscription).write(&mut self.write)?;
615        let event = Event::Outgoing(Outgoing::Subscribe(pkid));
616        self.events.push_back(event);
617        Ok(())
618    }
619
620    fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> {
621        let pkid = self.next_pkid();
622        unsub.pkid = pkid;
623
624        debug!(
625            "Unsubscribe. Topics = {:?}, Pkid = {:?}",
626            unsub.filters, unsub.pkid
627        );
628
629        let pkid = unsub.pkid;
630        Packet::Unsubscribe(unsub).write(&mut self.write)?;
631        let event = Event::Outgoing(Outgoing::Unsubscribe(pkid));
632        self.events.push_back(event);
633        Ok(())
634    }
635
636    fn outgoing_disconnect(&mut self, reason: DisconnectReasonCode) -> Result<(), StateError> {
637        debug!("Disconnect with {:?}", reason);
638
639        Packet::Disconnect(Disconnect::new(reason)).write(&mut self.write)?;
640        let event = Event::Outgoing(Outgoing::Disconnect);
641        self.events.push_back(event);
642        Ok(())
643    }
644
645    fn check_collision(&mut self, pkid: u16) -> Option<Publish> {
646        if let Some(publish) = &self.collision {
647            if publish.pkid == pkid {
648                return self.collision.take();
649            }
650        }
651
652        None
653    }
654
655    fn check_size(&self, pkt_size: usize) -> Result<(), StateError> {
656        let pkt_size = pkt_size.try_into()?;
657
658        match self.max_outgoing_packet_size {
659            Some(max_size) if pkt_size > max_size => Err(StateError::OutgoingPacketTooLarge {
660                pkt_size,
661                max: max_size,
662            }),
663            _ => Ok(()),
664        }
665    }
666
667    fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result<PubRel, StateError> {
668        let pubrel = match pubrel.pkid {
669            // consider PacketIdentifier(0) as uninitialized packets
670            0 => {
671                pubrel.pkid = self.next_pkid();
672                pubrel
673            }
674            _ => pubrel,
675        };
676
677        self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid);
678        self.inflight += 1;
679        Ok(pubrel)
680    }
681
682    /// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation
683    /// Packet ids are incremented till maximum set inflight messages and reset to 1 after that.
684    ///
685    fn next_pkid(&mut self) -> u16 {
686        let next_pkid = self.last_pkid + 1;
687
688        // When next packet id is at the edge of inflight queue,
689        // set await flag. This instructs eventloop to stop
690        // processing requests until all the inflight publishes
691        // are acked
692        if next_pkid == self.max_outgoing_inflight {
693            self.last_pkid = 0;
694            return next_pkid;
695        }
696
697        self.last_pkid = next_pkid;
698        next_pkid
699    }
700}
701
702#[cfg(test)]
703mod test {
704    use super::mqttbytes::v5::*;
705    use super::mqttbytes::*;
706    use super::{Event, Incoming, Outgoing, Request};
707    use super::{MqttState, StateError};
708
709    fn build_outgoing_publish(qos: QoS) -> Publish {
710        let topic = "hello/world".to_owned();
711        let payload = vec![1, 2, 3];
712
713        let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload, None);
714        publish.qos = qos;
715        publish
716    }
717
718    fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish {
719        let topic = "hello/world".to_owned();
720        let payload = vec![1, 2, 3];
721
722        let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload, None);
723        publish.pkid = pkid;
724        publish.qos = qos;
725        publish
726    }
727
728    fn build_mqttstate() -> MqttState {
729        MqttState::new(u16::MAX, false)
730    }
731
732    #[test]
733    fn next_pkid_increments_as_expected() {
734        let mut mqtt = build_mqttstate();
735
736        for i in 1..=100 {
737            let pkid = mqtt.next_pkid();
738
739            // loops between 0-99. % 100 == 0 implies border
740            let expected = i % 100;
741            if expected == 0 {
742                break;
743            }
744
745            assert_eq!(expected, pkid);
746        }
747    }
748
749    #[test]
750    fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() {
751        let mut mqtt = build_mqttstate();
752
753        // QoS0 Publish
754        let publish = build_outgoing_publish(QoS::AtMostOnce);
755
756        // QoS 0 publish shouldn't be saved in queue
757        mqtt.outgoing_publish(publish).unwrap();
758        assert_eq!(mqtt.last_pkid, 0);
759        assert_eq!(mqtt.inflight, 0);
760
761        // QoS1 Publish
762        let publish = build_outgoing_publish(QoS::AtLeastOnce);
763
764        // Packet id should be set and publish should be saved in queue
765        mqtt.outgoing_publish(publish.clone()).unwrap();
766        assert_eq!(mqtt.last_pkid, 1);
767        assert_eq!(mqtt.inflight, 1);
768
769        // Packet id should be incremented and publish should be saved in queue
770        mqtt.outgoing_publish(publish).unwrap();
771        assert_eq!(mqtt.last_pkid, 2);
772        assert_eq!(mqtt.inflight, 2);
773
774        // QoS1 Publish
775        let publish = build_outgoing_publish(QoS::ExactlyOnce);
776
777        // Packet id should be set and publish should be saved in queue
778        mqtt.outgoing_publish(publish.clone()).unwrap();
779        assert_eq!(mqtt.last_pkid, 3);
780        assert_eq!(mqtt.inflight, 3);
781
782        // Packet id should be incremented and publish should be saved in queue
783        mqtt.outgoing_publish(publish).unwrap();
784        assert_eq!(mqtt.last_pkid, 4);
785        assert_eq!(mqtt.inflight, 4);
786    }
787
788    #[test]
789    fn outgoing_publish_with_max_inflight_is_ok() {
790        let mut mqtt = MqttState::new(2, false);
791
792        // QoS2 publish
793        let publish = build_outgoing_publish(QoS::ExactlyOnce);
794
795        mqtt.outgoing_publish(publish.clone()).unwrap();
796        assert_eq!(mqtt.last_pkid, 1);
797        assert_eq!(mqtt.inflight, 1);
798
799        // Packet id should be set back down to 0, since we hit the limit
800        mqtt.outgoing_publish(publish.clone()).unwrap();
801        assert_eq!(mqtt.last_pkid, 0);
802        assert_eq!(mqtt.inflight, 2);
803
804        // This should cause a collition
805        mqtt.outgoing_publish(publish.clone()).unwrap();
806        assert_eq!(mqtt.last_pkid, 1);
807        assert_eq!(mqtt.inflight, 2);
808        assert!(mqtt.collision.is_some());
809
810        mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap();
811        mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
812        assert_eq!(mqtt.inflight, 1);
813
814        // Now there should be space in the outgoing queue
815        mqtt.outgoing_publish(publish.clone()).unwrap();
816        assert_eq!(mqtt.last_pkid, 0);
817        assert_eq!(mqtt.inflight, 2);
818    }
819
820    #[test]
821    fn incoming_publish_should_be_added_to_queue_correctly() {
822        let mut mqtt = build_mqttstate();
823
824        // QoS0, 1, 2 Publishes
825        let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
826        let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
827        let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);
828
829        mqtt.handle_incoming_publish(&mut publish1).unwrap();
830        mqtt.handle_incoming_publish(&mut publish2).unwrap();
831        mqtt.handle_incoming_publish(&mut publish3).unwrap();
832
833        let pkid = mqtt.incoming_pub[3].unwrap();
834
835        // only qos2 publish should be add to queue
836        assert_eq!(pkid, 3);
837    }
838
839    #[test]
840    fn incoming_publish_should_be_acked() {
841        let mut mqtt = build_mqttstate();
842
843        // QoS0, 1, 2 Publishes
844        let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
845        let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
846        let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);
847
848        mqtt.handle_incoming_publish(&mut publish1).unwrap();
849        mqtt.handle_incoming_publish(&mut publish2).unwrap();
850        mqtt.handle_incoming_publish(&mut publish3).unwrap();
851
852        if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] {
853            assert_eq!(pkid, 2);
854        } else {
855            panic!("missing puback");
856        }
857
858        if let Event::Outgoing(Outgoing::PubRec(pkid)) = mqtt.events[1] {
859            assert_eq!(pkid, 3);
860        } else {
861            panic!("missing PubRec");
862        }
863    }
864
865    #[test]
866    fn incoming_publish_should_not_be_acked_with_manual_acks() {
867        let mut mqtt = build_mqttstate();
868        mqtt.manual_acks = true;
869
870        // QoS0, 1, 2 Publishes
871        let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
872        let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
873        let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);
874
875        mqtt.handle_incoming_publish(&mut publish1).unwrap();
876        mqtt.handle_incoming_publish(&mut publish2).unwrap();
877        mqtt.handle_incoming_publish(&mut publish3).unwrap();
878
879        let pkid = mqtt.incoming_pub[3].unwrap();
880        assert_eq!(pkid, 3);
881
882        assert!(mqtt.events.is_empty());
883    }
884
885    #[test]
886    fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() {
887        let mut mqtt = build_mqttstate();
888        let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1);
889
890        mqtt.handle_incoming_publish(&mut publish).unwrap();
891        let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap();
892        match packet {
893            Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1),
894            _ => panic!("Invalid network request: {:?}", packet),
895        }
896    }
897
898    #[test]
899    fn incoming_puback_should_remove_correct_publish_from_queue() {
900        let mut mqtt = build_mqttstate();
901
902        let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
903        let publish2 = build_outgoing_publish(QoS::ExactlyOnce);
904
905        mqtt.outgoing_publish(publish1).unwrap();
906        mqtt.outgoing_publish(publish2).unwrap();
907        assert_eq!(mqtt.inflight, 2);
908
909        mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap();
910        assert_eq!(mqtt.inflight, 1);
911
912        mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap();
913        assert_eq!(mqtt.inflight, 0);
914
915        assert!(mqtt.outgoing_pub[1].is_none());
916        assert!(mqtt.outgoing_pub[2].is_none());
917    }
918
919    #[test]
920    fn incoming_puback_with_pkid_greater_than_max_inflight_should_be_handled_gracefully() {
921        let mut mqtt = build_mqttstate();
922
923        let got = mqtt
924            .handle_incoming_puback(&PubAck::new(101, None))
925            .unwrap_err();
926
927        match got {
928            StateError::Unsolicited(pkid) => assert_eq!(pkid, 101),
929            e => panic!("Unexpected error: {}", e),
930        }
931    }
932
933    #[test]
934    fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() {
935        let mut mqtt = build_mqttstate();
936
937        let publish1 = build_outgoing_publish(QoS::AtLeastOnce);
938        let publish2 = build_outgoing_publish(QoS::ExactlyOnce);
939
940        let _publish_out = mqtt.outgoing_publish(publish1);
941        let _publish_out = mqtt.outgoing_publish(publish2);
942
943        mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap();
944        assert_eq!(mqtt.inflight, 2);
945
946        // check if the remaining element's pkid is 1
947        let backup = mqtt.outgoing_pub[1].clone();
948        assert_eq!(backup.unwrap().pkid, 1);
949
950        // check if the qos2 element's release pkid is 2
951        assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2);
952    }
953
954    #[test]
955    fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() {
956        let mut mqtt = build_mqttstate();
957
958        let publish = build_outgoing_publish(QoS::ExactlyOnce);
959        mqtt.outgoing_publish(publish).unwrap();
960        let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap();
961        match packet {
962            Packet::Publish(publish) => assert_eq!(publish.pkid, 1),
963            packet => panic!("Invalid network request: {:?}", packet),
964        }
965
966        mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
967        let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap();
968        match packet {
969            Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1),
970            packet => panic!("Invalid network request: {:?}", packet),
971        }
972    }
973
974    #[test]
975    fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() {
976        let mut mqtt = build_mqttstate();
977        let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1);
978
979        mqtt.handle_incoming_publish(&mut publish).unwrap();
980        let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap();
981        match packet {
982            Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1),
983            packet => panic!("Invalid network request: {:?}", packet),
984        }
985
986        mqtt.handle_incoming_pubrel(&PubRel::new(1, None)).unwrap();
987        let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap();
988        match packet {
989            Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1),
990            packet => panic!("Invalid network request: {:?}", packet),
991        }
992    }
993
994    #[test]
995    fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() {
996        let mut mqtt = build_mqttstate();
997        let publish = build_outgoing_publish(QoS::ExactlyOnce);
998
999        mqtt.outgoing_publish(publish).unwrap();
1000        mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap();
1001
1002        mqtt.handle_incoming_pubcomp(&PubComp::new(1, None))
1003            .unwrap();
1004        assert_eq!(mqtt.inflight, 0);
1005    }
1006
1007    #[test]
1008    fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() {
1009        let mut mqtt = build_mqttstate();
1010        mqtt.outgoing_ping().unwrap();
1011
1012        // network activity other than pingresp
1013        let publish = build_outgoing_publish(QoS::AtLeastOnce);
1014        mqtt.handle_outgoing_packet(Request::Publish(publish))
1015            .unwrap();
1016        mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None)))
1017            .unwrap();
1018
1019        // should throw error because we didn't get pingresp for previous ping
1020        match mqtt.outgoing_ping() {
1021            Ok(_) => panic!("Should throw pingresp await error"),
1022            Err(StateError::AwaitPingResp) => (),
1023            Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e),
1024        }
1025    }
1026
1027    #[test]
1028    fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() {
1029        let mut mqtt = build_mqttstate();
1030
1031        // should ping
1032        mqtt.outgoing_ping().unwrap();
1033        mqtt.handle_incoming_packet(Incoming::PingResp(PingResp))
1034            .unwrap();
1035
1036        // should ping
1037        mqtt.outgoing_ping().unwrap();
1038    }
1039}