rumqttc/v5/
eventloop.rs

1use super::framed::Network;
2use super::mqttbytes::v5::*;
3use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport};
4use crate::eventloop::socket_connect;
5use crate::framed::N;
6
7use flume::{bounded, Receiver, Sender};
8use tokio::select;
9use tokio::time::{self, error::Elapsed, Instant, Sleep};
10
11use std::collections::VecDeque;
12use std::convert::TryInto;
13use std::io;
14use std::pin::Pin;
15use std::time::Duration;
16
17use super::mqttbytes::v5::ConnectReturnCode;
18
19#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
20use crate::tls;
21
22#[cfg(unix)]
23use {std::path::Path, tokio::net::UnixStream};
24
25#[cfg(feature = "websocket")]
26use {
27    crate::websockets::{split_url, validate_response_headers, UrlError},
28    async_tungstenite::tungstenite::client::IntoClientRequest,
29    ws_stream_tungstenite::WsStream,
30};
31
32#[cfg(feature = "proxy")]
33use crate::proxy::ProxyError;
34
35/// Critical errors during eventloop polling
36#[derive(Debug, thiserror::Error)]
37pub enum ConnectionError {
38    #[error("Mqtt state: {0}")]
39    MqttState(#[from] StateError),
40    #[error("Timeout")]
41    Timeout(#[from] Elapsed),
42    #[cfg(feature = "websocket")]
43    #[error("Websocket: {0}")]
44    Websocket(#[from] async_tungstenite::tungstenite::error::Error),
45    #[cfg(feature = "websocket")]
46    #[error("Websocket Connect: {0}")]
47    WsConnect(#[from] http::Error),
48    #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
49    #[error("TLS: {0}")]
50    Tls(#[from] tls::Error),
51    #[error("I/O: {0}")]
52    Io(#[from] io::Error),
53    #[error("Connection refused, return code: `{0:?}`")]
54    ConnectionRefused(ConnectReturnCode),
55    #[error("Expected ConnAck packet, received: {0:?}")]
56    NotConnAck(Box<Packet>),
57    #[error("Requests done")]
58    RequestsDone,
59    #[cfg(feature = "websocket")]
60    #[error("Invalid Url: {0}")]
61    InvalidUrl(#[from] UrlError),
62    #[cfg(feature = "proxy")]
63    #[error("Proxy Connect: {0}")]
64    Proxy(#[from] ProxyError),
65    #[cfg(feature = "websocket")]
66    #[error("Websocket response validation error: ")]
67    ResponseValidation(#[from] crate::websockets::ValidationError),
68}
69
/// Eventloop with all the state of a connection
///
/// Bundles the connection options, the MQTT session state, both ends of the request channel,
/// the queue of requests to replay, the (optional) live network connection and the keep alive
/// timer.
pub struct EventLoop {
    /// Options of the current mqtt connection. `keep_alive` may be overwritten with the broker's
    /// server keep alive during the MQTT handshake.
    pub options: MqttOptions,
    /// Current state of the connection
    pub state: MqttState,
    /// Receiving end of the bounded request channel created in `EventLoop::new`
    requests_rx: Receiver<Request>,
    /// Sending end of the same request channel, used to hand requests to this eventloop
    pub(crate) requests_tx: Sender<Request>,
    /// Requests to (re)send, taken from the previous session's state and from the request channel
    /// by `clean()`. They are served before new channel requests, throttled by
    /// `options.pending_throttle`.
    pub pending: VecDeque<Request>,
    /// Network connection to the broker; `None` before the first connect and after `clean()`
    network: Option<Network>,
    /// Keep alive timer (not a duration). Created after a successful connect when absent, reset
    /// every time it fires and a PingReq is handled, and cleared by `clean()`.
    keepalive_timeout: Option<Pin<Box<Sleep>>>,
}
87
/// Events which can be yielded by the event loop
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// An incoming packet, of the same type as the packets read from the broker connection.
    Incoming(Incoming),
    /// An outgoing event recorded by the state, e.g. after a user request, a replayed pending
    /// request or a keep alive PingReq has been handled.
    /// NOTE(review): the exact set of producers lives in `MqttState`; confirm there.
    Outgoing(Outgoing),
}
94
95impl EventLoop {
96    /// New MQTT `EventLoop`
97    ///
98    /// When connection encounters critical errors (like auth failure), user has a choice to
99    /// access and update `options`, `state` and `requests`.
100    pub fn new(options: MqttOptions, cap: usize) -> EventLoop {
101        let (requests_tx, requests_rx) = bounded(cap);
102        let pending = VecDeque::new();
103        let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX);
104        let manual_acks = options.manual_acks;
105
106        EventLoop {
107            options,
108            state: MqttState::new(inflight_limit, manual_acks),
109            requests_tx,
110            requests_rx,
111            pending,
112            network: None,
113            keepalive_timeout: None,
114        }
115    }
116
117    /// Last session might contain packets which aren't acked. MQTT says these packets should be
118    /// republished in the next session. Move pending messages from state to eventloop, drops the
119    /// underlying network connection and clears the keepalive timeout if any.
120    ///
121    /// > NOTE: Use only when EventLoop is blocked on network and unable to immediately handle disconnect.
122    /// > Also, while this helps prevent data loss, the pending list length should be managed properly.
123    /// > For this reason we recommend setting [`AsycClient`](super::AsyncClient)'s channel capacity to `0`.
124    pub fn clean(&mut self) {
125        self.network = None;
126        self.keepalive_timeout = None;
127        self.pending.extend(self.state.clean());
128
129        // drain requests from channel which weren't yet received
130        let requests_in_channel = self.requests_rx.drain();
131        self.pending.extend(requests_in_channel);
132    }
133
134    /// Yields Next notification or outgoing request and periodically pings
135    /// the broker. Continuing to poll will reconnect to the broker if there is
136    /// a disconnection.
137    /// **NOTE** Don't block this while iterating
138    pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
139        if self.network.is_none() {
140            let (network, connack) = time::timeout(
141                Duration::from_secs(self.options.connection_timeout()),
142                connect(&mut self.options),
143            )
144            .await??;
145            self.network = Some(network);
146
147            if self.keepalive_timeout.is_none() {
148                self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive)));
149            }
150
151            self.state.handle_incoming_packet(connack)?;
152        }
153
154        match self.select().await {
155            Ok(v) => Ok(v),
156            Err(e) => {
157                self.clean();
158                Err(e)
159            }
160        }
161    }
162
163    /// Select on network and requests and generate keepalive pings when necessary
164    async fn select(&mut self) -> Result<Event, ConnectionError> {
165        let network = self.network.as_mut().unwrap();
166        // let await_acks = self.state.await_acks;
167
168        let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight;
169        let collision = self.state.collision.is_some();
170
171        // Read buffered events from previous polls before calling a new poll
172        if let Some(event) = self.state.events.pop_front() {
173            return Ok(event);
174        }
175
176        // this loop is necessary since self.incoming.pop_front() might return None. In that case,
177        // instead of returning a None event, we try again.
178        select! {
179            // Handles pending and new requests.
180            // If available, prioritises pending requests from previous session.
181            // Else, pulls next request from user requests channel.
182            // If conditions in the below branch are for flow control.
183            // The branch is disabled if there's no pending messages and new user requests
184            // cannot be serviced due flow control.
185            // We read next user user request only when inflight messages are < configured inflight
186            // and there are no collisions while handling previous outgoing requests.
187            //
188            // Flow control is based on ack count. If inflight packet count in the buffer is
189            // less than max_inflight setting, next outgoing request will progress. For this
190            // to work correctly, broker should ack in sequence (a lot of brokers won't)
191            //
192            // E.g If max inflight = 5, user requests will be blocked when inflight queue
193            // looks like this                 -> [1, 2, 3, 4, 5].
194            // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5].
195            // This pulls next user request. But because max packet id = max_inflight, next
196            // user request's packet id will roll to 1. This replaces existing packet id 1.
197            // Resulting in a collision
198            //
199            // Eventloop can stop receiving outgoing user requests when previous outgoing
200            // request collided. I.e collision state. Collision state will be cleared only
201            // when correct ack is received
202            // Full inflight queue will look like -> [1a, 2, 3, 4, 5].
203            // If 3 is acked instead of 1 first   -> [1a, 2, x, 4, 5].
204            // After collision with pkid 1        -> [1b ,2, x, 4, 5].
205            // 1a is saved to state and event loop is set to collision mode stopping new
206            // outgoing requests (along with 1b).
207            o = Self::next_request(
208                &mut self.pending,
209                &self.requests_rx,
210                self.options.pending_throttle
211            ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
212                Ok(request) => {
213                    self.state.handle_outgoing_packet(request)?;
214                    network.flush(&mut self.state.write).await?;
215                    Ok(self.state.events.pop_front().unwrap())
216                }
217                Err(_) => Err(ConnectionError::RequestsDone),
218            },
219            // Pull a bunch of packets from network, reply in bunch and yield the first item
220            o = network.readb(&mut self.state) => {
221                o?;
222                // flush all the acks and return first incoming packet
223                network.flush(&mut self.state.write).await?;
224                Ok(self.state.events.pop_front().unwrap())
225            },
226            // We generate pings irrespective of network activity. This keeps the ping logic
227            // simple. We can change this behavior in future if necessary (to prevent extra pings)
228            _ = self.keepalive_timeout.as_mut().unwrap() => {
229                let timeout = self.keepalive_timeout.as_mut().unwrap();
230                timeout.as_mut().reset(Instant::now() + self.options.keep_alive);
231
232                self.state.handle_outgoing_packet(Request::PingReq)?;
233                network.flush(&mut self.state.write).await?;
234                Ok(self.state.events.pop_front().unwrap())
235            }
236        }
237    }
238
239    async fn next_request(
240        pending: &mut VecDeque<Request>,
241        rx: &Receiver<Request>,
242        pending_throttle: Duration,
243    ) -> Result<Request, ConnectionError> {
244        if !pending.is_empty() {
245            time::sleep(pending_throttle).await;
246            // We must call .next() AFTER sleep() otherwise .next() would
247            // advance the iterator but the future might be canceled before return
248            Ok(pending.pop_front().unwrap())
249        } else {
250            match rx.recv_async().await {
251                Ok(r) => Ok(r),
252                Err(_) => Err(ConnectionError::RequestsDone),
253            }
254        }
255    }
256}
257
258/// This stream internally processes requests from the request stream provided to the eventloop
259/// while also consuming byte stream from the network and yielding mqtt packets as the output of
260/// the stream.
261/// This function (for convenience) includes internal delays for users to perform internal sleeps
262/// between re-connections so that cancel semantics can be used during this sleep
263async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), ConnectionError> {
264    // connect to the broker
265    let mut network = network_connect(options).await?;
266
267    // make MQTT connection request (which internally awaits for ack)
268    let packet = mqtt_connect(options, &mut network).await?;
269
270    // Last session might contain packets which aren't acked. MQTT says these packets should be
271    // republished in the next session
272    // move pending messages from state to eventloop
273    // let pending = self.state.clean();
274    // self.pending = pending.into_iter();
275    Ok((network, packet))
276}
277
/// Opens the transport selected by `options.transport()` and wraps it in a `Network`.
///
/// The maximum incoming packet size starts as `options.default_max_incoming_size`. It is
/// replaced by `max_packet_size` from `options.connect_properties` when that is set.
///
/// Transports:
/// - Unix (unix targets only): connects to the socket path in `broker_addr`; no proxy.
/// - Tcp: plain stream from `socket_connect`, or through `options.proxy()` when the `proxy`
///   feature is enabled and a proxy is configured.
/// - Tls: TLS on top of that stream, using `broker_addr` and `port`.
/// - Ws / Wss: websocket handshake on top of that stream. Domain and port come from parsing
///   `broker_addr` as a URL.
///
/// # Errors
///
/// Returns `StateError::Coversion` (converted into `ConnectionError`) if converting
/// `max_packet_size` fails. Also returns any I/O, proxy, TLS, URL or websocket error raised
/// while connecting.
async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
    let mut max_incoming_pkt_size = Some(options.default_max_incoming_size);

    // Override default value if max_packet_size is set on `connect_properties`
    if let Some(connect_props) = &options.connect_properties {
        if let Some(max_size) = connect_props.max_packet_size {
            let max_size = max_size.try_into().map_err(StateError::Coversion)?;
            max_incoming_pkt_size = Some(max_size);
        }
    }

    // Process Unix files early, as proxy is not supported for them.
    #[cfg(unix)]
    if matches!(options.transport(), Transport::Unix) {
        let file = options.broker_addr.as_str();
        let socket = UnixStream::connect(Path::new(file)).await?;
        let network = Network::new(socket, max_incoming_pkt_size);
        return Ok(network);
    }

    // For websockets domain and port are taken directly from `broker_addr` (which is a url).
    let (domain, port) = match options.transport() {
        #[cfg(feature = "websocket")]
        Transport::Ws => split_url(&options.broker_addr)?,
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(_) => split_url(&options.broker_addr)?,
        _ => options.broker_address(),
    };

    // Underlying byte stream shared by the Tcp, Tls and websocket transports below. Exactly one
    // of the two cfg'd expressions survives, so it is the tail of this block.
    let tcp_stream: Box<dyn N> = {
        #[cfg(feature = "proxy")]
        match options.proxy() {
            Some(proxy) => {
                proxy
                    .connect(&domain, port, options.network_options())
                    .await?
            }
            None => {
                let addr = format!("{domain}:{port}");
                let tcp = socket_connect(addr, options.network_options()).await?;
                Box::new(tcp)
            }
        }
        #[cfg(not(feature = "proxy"))]
        {
            let addr = format!("{domain}:{port}");
            let tcp = socket_connect(addr, options.network_options()).await?;
            Box::new(tcp)
        }
    };

    let network = match options.transport() {
        Transport::Tcp => Network::new(tcp_stream, max_incoming_pkt_size),
        // TLS is given `broker_addr` and `port` from the options, not the `domain`/`port`
        // computed above.
        #[cfg(any(feature = "use-native-tls", feature = "use-rustls"))]
        Transport::Tls(tls_config) => {
            let socket =
                tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
                    .await?;
            Network::new(socket, max_incoming_pkt_size)
        }
        // Already handled by the early return above.
        #[cfg(unix)]
        Transport::Unix => unreachable!(),
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            // Request the `mqtt` websocket subprotocol; the static value is expected to parse.
            let mut request = options.broker_addr.as_str().into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            // Let the user adjust the request (e.g. extra headers) before the handshake.
            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tcp_stream).await?;
            validate_response_headers(response)?;

            Network::new(WsStream::new(socket), max_incoming_pkt_size)
        }
        // Same request setup as `Ws`, but the handshake runs over a rustls connector.
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(tls_config) => {
            let mut request = options.broker_addr.as_str().into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            let connector = tls::rustls_connector(&tls_config).await?;

            let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
                request,
                tcp_stream,
                Some(connector),
            )
            .await?;
            validate_response_headers(response)?;

            Network::new(WsStream::new(socket), max_incoming_pkt_size)
        }
    };

    Ok(network)
}
384
385async fn mqtt_connect(
386    options: &mut MqttOptions,
387    network: &mut Network,
388) -> Result<Incoming, ConnectionError> {
389    let keep_alive = options.keep_alive().as_secs() as u16;
390    let clean_start = options.clean_start();
391    let client_id = options.client_id();
392    let properties = options.connect_properties();
393
394    let connect = Connect {
395        keep_alive,
396        client_id,
397        clean_start,
398        properties,
399    };
400
401    // send mqtt connect packet
402    network.connect(connect, options).await?;
403
404    // validate connack
405    match network.read().await? {
406        Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
407            // Override local keep_alive value if set by server.
408            if let Some(props) = &connack.properties {
409                if let Some(keep_alive) = props.server_keep_alive {
410                    options.keep_alive = Duration::from_secs(keep_alive as u64);
411                }
412            }
413            Ok(Packet::ConnAck(connack))
414        }
415        Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
416        packet => Err(ConnectionError::NotConnAck(Box::new(packet))),
417    }
418}