rumqttc/
eventloop.rs

1use crate::{framed::Network, Transport};
2use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError};
3use crate::{MqttOptions, Outgoing};
4
5use crate::framed::N;
6use crate::mqttbytes::v4::*;
7use flume::{bounded, Receiver, Sender};
8use tokio::net::{lookup_host, TcpSocket, TcpStream};
9use tokio::select;
10use tokio::time::{self, Instant, Sleep};
11
12use std::collections::VecDeque;
13use std::io;
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::time::Duration;
17
18#[cfg(unix)]
19use {std::path::Path, tokio::net::UnixStream};
20
21#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
22use crate::tls;
23
24#[cfg(feature = "websocket")]
25use {
26    crate::websockets::{split_url, validate_response_headers, UrlError},
27    async_tungstenite::tungstenite::client::IntoClientRequest,
28    ws_stream_tungstenite::WsStream,
29};
30
31#[cfg(feature = "proxy")]
32use crate::proxy::ProxyError;
33
34/// Critical errors during eventloop polling
35#[derive(Debug, thiserror::Error)]
36pub enum ConnectionError {
37    #[error("Mqtt state: {0}")]
38    MqttState(#[from] StateError),
39    #[error("Network timeout")]
40    NetworkTimeout,
41    #[error("Flush timeout")]
42    FlushTimeout,
43    #[cfg(feature = "websocket")]
44    #[error("Websocket: {0}")]
45    Websocket(#[from] async_tungstenite::tungstenite::error::Error),
46    #[cfg(feature = "websocket")]
47    #[error("Websocket Connect: {0}")]
48    WsConnect(#[from] http::Error),
49    #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
50    #[error("TLS: {0}")]
51    Tls(#[from] tls::Error),
52    #[error("I/O: {0}")]
53    Io(#[from] io::Error),
54    #[error("Connection refused, return code: `{0:?}`")]
55    ConnectionRefused(ConnectReturnCode),
56    #[error("Expected ConnAck packet, received: {0:?}")]
57    NotConnAck(Packet),
58    #[error("Requests done")]
59    RequestsDone,
60    #[cfg(feature = "websocket")]
61    #[error("Invalid Url: {0}")]
62    InvalidUrl(#[from] UrlError),
63    #[cfg(feature = "proxy")]
64    #[error("Proxy Connect: {0}")]
65    Proxy(#[from] ProxyError),
66    #[cfg(feature = "websocket")]
67    #[error("Websocket response validation error: ")]
68    ResponseValidation(#[from] crate::websockets::ValidationError),
69}
70
/// Eventloop with all the state of a connection
///
/// Holds the connection options, the MQTT state, both ends of the requests
/// channel, the list of pending requests and the (optional) live network
/// connection together with its keepalive timer.
pub struct EventLoop {
    /// Options of the current mqtt connection
    pub mqtt_options: MqttOptions,
    /// Current state of the connection
    pub state: MqttState,
    /// Receiving end of the requests channel
    requests_rx: Receiver<Request>,
    /// Sending end of the requests channel, used to hand requests to this event loop
    pub(crate) requests_tx: Sender<Request>,
    /// Pending packets from last session
    pub pending: VecDeque<Request>,
    /// Network connection to the broker; `None` while disconnected
    network: Option<Network>,
    /// Keepalive timer; `None` until one is created after a successful connect
    keepalive_timeout: Option<Pin<Box<Sleep>>>,
    /// Network-level options (e.g. connection timeout) used when connecting and flushing
    pub network_options: NetworkOptions,
}
89
/// Events which can be yielded by the event loop
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// Wraps an [`Incoming`] value, e.g. the ConnAck yielded right after connecting
    Incoming(Incoming),
    /// Wraps an [`Outgoing`] value
    Outgoing(Outgoing),
}
96
97impl EventLoop {
98    /// New MQTT `EventLoop`
99    ///
100    /// When connection encounters critical errors (like auth failure), user has a choice to
101    /// access and update `options`, `state` and `requests`.
102    pub fn new(mqtt_options: MqttOptions, cap: usize) -> EventLoop {
103        let (requests_tx, requests_rx) = bounded(cap);
104        let pending = VecDeque::new();
105        let max_inflight = mqtt_options.inflight;
106        let manual_acks = mqtt_options.manual_acks;
107        let max_outgoing_packet_size = mqtt_options.max_outgoing_packet_size;
108
109        EventLoop {
110            mqtt_options,
111            state: MqttState::new(max_inflight, manual_acks, max_outgoing_packet_size),
112            requests_tx,
113            requests_rx,
114            pending,
115            network: None,
116            keepalive_timeout: None,
117            network_options: NetworkOptions::new(),
118        }
119    }
120
121    /// Last session might contain packets which aren't acked. MQTT says these packets should be
122    /// republished in the next session. Move pending messages from state to eventloop, drops the
123    /// underlying network connection and clears the keepalive timeout if any.
124    ///
125    /// > NOTE: Use only when EventLoop is blocked on network and unable to immediately handle disconnect.
126    /// > Also, while this helps prevent data loss, the pending list length should be managed properly.
127    /// > For this reason we recommend setting [`AsycClient`](crate::AsyncClient)'s channel capacity to `0`.
128    pub fn clean(&mut self) {
129        self.network = None;
130        self.keepalive_timeout = None;
131        self.pending.extend(self.state.clean());
132
133        // drain requests from channel which weren't yet received
134        let requests_in_channel = self.requests_rx.drain();
135        self.pending.extend(requests_in_channel);
136    }
137
138    /// Yields Next notification or outgoing request and periodically pings
139    /// the broker. Continuing to poll will reconnect to the broker if there is
140    /// a disconnection.
141    /// **NOTE** Don't block this while iterating
142    pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
143        if self.network.is_none() {
144            let (network, connack) = match time::timeout(
145                Duration::from_secs(self.network_options.connection_timeout()),
146                connect(&self.mqtt_options, self.network_options.clone()),
147            )
148            .await
149            {
150                Ok(inner) => inner?,
151                Err(_) => return Err(ConnectionError::NetworkTimeout),
152            };
153            self.network = Some(network);
154
155            if self.keepalive_timeout.is_none() && !self.mqtt_options.keep_alive.is_zero() {
156                self.keepalive_timeout = Some(Box::pin(time::sleep(self.mqtt_options.keep_alive)));
157            }
158
159            return Ok(Event::Incoming(connack));
160        }
161
162        match self.select().await {
163            Ok(v) => Ok(v),
164            Err(e) => {
165                self.clean();
166                Err(e)
167            }
168        }
169    }
170
    /// Select on network and requests and generate keepalive pings when necessary
    ///
    /// Returns a buffered event immediately if the state already holds one.
    /// Otherwise waits on whichever of these completes first: a network read,
    /// the next pending/user request (subject to flow control), or the
    /// keepalive timer. Each branch flushes the write buffer (bounded by the
    /// connection timeout) and yields the first queued event.
    ///
    /// # Panics
    ///
    /// Panics if there is no network connection; `poll` only calls this once
    /// `self.network` is set.
    async fn select(&mut self) -> Result<Event, ConnectionError> {
        let network = self.network.as_mut().unwrap();
        // let await_acks = self.state.await_acks;
        let inflight_full = self.state.inflight >= self.mqtt_options.inflight;
        let collision = self.state.collision.is_some();
        // The connection timeout is reused to bound every flush below.
        let network_timeout = Duration::from_secs(self.network_options.connection_timeout());

        // Read buffered events from previous polls before calling a new poll
        if let Some(event) = self.state.events.pop_front() {
            return Ok(event);
        }

        // Placeholder future handed to the keepalive branch when no keepalive timer exists.
        // That branch's guard requires a timer, so the placeholder is never the one selected.
        let mut no_sleep = Box::pin(time::sleep(Duration::ZERO));
        select! {
            // Pull a bunch of packets from network, reply in bunch and yield the first item
            o = network.readb(&mut self.state) => {
                o?;
                // flush all the acks and return first incoming packet
                match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
                    Ok(inner) => inner?,
                    Err(_)=> return Err(ConnectionError::FlushTimeout),
                };
                // NOTE(review): assumes a successful `readb` queued at least one event — confirm.
                Ok(self.state.events.pop_front().unwrap())
            },
            // Handles pending and new requests.
            // If available, prioritises pending requests from previous session.
            // Else, pulls next request from user requests channel.
            // If conditions in the below branch are for flow control.
            // The branch is disabled if there's no pending messages and new user requests
            // cannot be serviced due to flow control.
            // We read next user request only when inflight messages are < configured inflight
            // and there are no collisions while handling previous outgoing requests.
            //
            // Flow control is based on ack count. If inflight packet count in the buffer is
            // less than max_inflight setting, next outgoing request will progress. For this
            // to work correctly, broker should ack in sequence (a lot of brokers won't)
            //
            // E.g If max inflight = 5, user requests will be blocked when inflight queue
            // looks like this                 -> [1, 2, 3, 4, 5].
            // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5].
            // This pulls next user request. But because max packet id = max_inflight, next
            // user request's packet id will roll to 1. This replaces existing packet id 1.
            // Resulting in a collision
            //
            // Eventloop can stop receiving outgoing user requests when previous outgoing
            // request collided. I.e collision state. Collision state will be cleared only
            // when correct ack is received
            // Full inflight queue will look like -> [1a, 2, 3, 4, 5].
            // If 3 is acked instead of 1 first   -> [1a, 2, x, 4, 5].
            // After collision with pkid 1        -> [1b ,2, x, 4, 5].
            // 1a is saved to state and event loop is set to collision mode stopping new
            // outgoing requests (along with 1b).
            o = Self::next_request(
                &mut self.pending,
                &self.requests_rx,
                self.mqtt_options.pending_throttle
            ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
                Ok(request) => {
                    self.state.handle_outgoing_packet(request)?;
                    match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
                        Ok(inner) => inner?,
                        Err(_)=> return Err(ConnectionError::FlushTimeout),
                    };
                    // NOTE(review): assumes handling an outgoing request queued an event — confirm.
                    Ok(self.state.events.pop_front().unwrap())
                }
                Err(_) => Err(ConnectionError::RequestsDone),
            },
            // We generate pings irrespective of network activity. This keeps the ping logic
            // simple. We can change this behavior in future if necessary (to prevent extra pings)
            _ = self.keepalive_timeout.as_mut().unwrap_or(&mut no_sleep),
                if self.keepalive_timeout.is_some() && !self.mqtt_options.keep_alive.is_zero() => {
                // Re-arm the timer one keepalive interval from now before sending the ping.
                let timeout = self.keepalive_timeout.as_mut().unwrap();
                timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive);

                self.state.handle_outgoing_packet(Request::PingReq(PingReq))?;
                match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
                    Ok(inner) => inner?,
                    Err(_)=> return Err(ConnectionError::FlushTimeout),
                };
                Ok(self.state.events.pop_front().unwrap())
            }
        }
    }
257
258    pub fn network_options(&self) -> NetworkOptions {
259        self.network_options.clone()
260    }
261
262    pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self {
263        self.network_options = network_options;
264        self
265    }
266
267    async fn next_request(
268        pending: &mut VecDeque<Request>,
269        rx: &Receiver<Request>,
270        pending_throttle: Duration,
271    ) -> Result<Request, ConnectionError> {
272        if !pending.is_empty() {
273            time::sleep(pending_throttle).await;
274            // We must call .pop_front() AFTER sleep() otherwise we would have
275            // advanced the iterator but the future might be canceled before return
276            Ok(pending.pop_front().unwrap())
277        } else {
278            match rx.recv_async().await {
279                Ok(r) => Ok(r),
280                Err(_) => Err(ConnectionError::RequestsDone),
281            }
282        }
283    }
284}
285
286/// This stream internally processes requests from the request stream provided to the eventloop
287/// while also consuming byte stream from the network and yielding mqtt packets as the output of
288/// the stream.
289/// This function (for convenience) includes internal delays for users to perform internal sleeps
290/// between re-connections so that cancel semantics can be used during this sleep
291async fn connect(
292    mqtt_options: &MqttOptions,
293    network_options: NetworkOptions,
294) -> Result<(Network, Incoming), ConnectionError> {
295    // connect to the broker
296    let mut network = network_connect(mqtt_options, network_options).await?;
297
298    // make MQTT connection request (which internally awaits for ack)
299    let packet = mqtt_connect(mqtt_options, &mut network).await?;
300
301    Ok((network, packet))
302}
303
304pub(crate) async fn socket_connect(
305    host: String,
306    network_options: NetworkOptions,
307) -> io::Result<TcpStream> {
308    let addrs = lookup_host(host).await?;
309    let mut last_err = None;
310
311    for addr in addrs {
312        let socket = match addr {
313            SocketAddr::V4(_) => TcpSocket::new_v4()?,
314            SocketAddr::V6(_) => TcpSocket::new_v6()?,
315        };
316
317        if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
318            socket.set_send_buffer_size(send_buff_size).unwrap();
319        }
320        if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
321            socket.set_recv_buffer_size(recv_buffer_size).unwrap();
322        }
323
324        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
325        {
326            if let Some(bind_device) = &network_options.bind_device {
327                // call the bind_device function only if the bind_device network option is defined
328                // If binding device is None or an empty string it removes the binding,
329                // which is causing PermissionDenied errors in AWS environment (lambda function).
330                socket.bind_device(Some(bind_device.as_bytes()))?;
331            }
332        }
333
334        match socket.connect(addr).await {
335            Ok(s) => return Ok(s),
336            Err(e) => {
337                last_err = Some(e);
338            }
339        };
340    }
341
342    Err(last_err.unwrap_or_else(|| {
343        io::Error::new(
344            io::ErrorKind::InvalidInput,
345            "could not resolve to any address",
346        )
347    }))
348}
349
/// Opens the transport selected by `options.transport()` and wraps it in a [`Network`].
///
/// Unix sockets are connected directly from `broker_addr`. For every other
/// transport a TCP stream is opened first (through the configured proxy when
/// the `proxy` feature is enabled and a proxy is set) and then, depending on
/// the transport, used as-is, wrapped in TLS, or upgraded to a websocket.
/// Every resulting `Network` is created with `options.max_incoming_packet_size`.
///
/// Errors from any step are returned as a [`ConnectionError`].
async fn network_connect(
    options: &MqttOptions,
    network_options: NetworkOptions,
) -> Result<Network, ConnectionError> {
    // Process Unix files early, as proxy is not supported for them.
    #[cfg(unix)]
    if matches!(options.transport(), Transport::Unix) {
        // Here `broker_addr` is used as a filesystem path.
        let file = options.broker_addr.as_str();
        let socket = UnixStream::connect(Path::new(file)).await?;
        let network = Network::new(socket, options.max_incoming_packet_size);
        return Ok(network);
    }

    // For websockets domain and port are taken directly from `broker_addr` (which is a url).
    let (domain, port) = match options.transport() {
        #[cfg(feature = "websocket")]
        Transport::Ws => split_url(&options.broker_addr)?,
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(_) => split_url(&options.broker_addr)?,
        _ => options.broker_address(),
    };

    // Base byte stream: via the proxy when one is configured, otherwise a direct TCP socket.
    let tcp_stream: Box<dyn N> = {
        #[cfg(feature = "proxy")]
        match options.proxy() {
            Some(proxy) => proxy.connect(&domain, port, network_options).await?,
            None => {
                let addr = format!("{domain}:{port}");
                let tcp = socket_connect(addr, network_options).await?;
                Box::new(tcp)
            }
        }
        #[cfg(not(feature = "proxy"))]
        {
            let addr = format!("{domain}:{port}");
            let tcp = socket_connect(addr, network_options).await?;
            Box::new(tcp)
        }
    };

    // Layer the selected transport on top of the base stream.
    let network = match options.transport() {
        Transport::Tcp => Network::new(tcp_stream, options.max_incoming_packet_size),
        #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
        Transport::Tls(tls_config) => {
            let socket =
                tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
                    .await?;
            Network::new(socket, options.max_incoming_packet_size)
        }
        // Unix transport already returned at the top of this function.
        #[cfg(unix)]
        Transport::Unix => unreachable!(),
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            // Build the handshake request from the broker url and set the `mqtt` subprotocol header.
            let mut request = options.broker_addr.as_str().into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            // Let the user-supplied modifier, if any, adjust the request before it is sent.
            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tcp_stream).await?;
            validate_response_headers(response)?;

            Network::new(WsStream::new(socket), options.max_incoming_packet_size)
        }
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(tls_config) => {
            // Same request preparation as the plain websocket arm above.
            let mut request = options.broker_addr.as_str().into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            // The handshake runs over TLS using a connector built from the transport's TLS config.
            let connector = tls::rustls_connector(&tls_config).await?;

            let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
                request,
                tcp_stream,
                Some(connector),
            )
            .await?;
            validate_response_headers(response)?;

            Network::new(WsStream::new(socket), options.max_incoming_packet_size)
        }
    };

    Ok(network)
}
445
446async fn mqtt_connect(
447    options: &MqttOptions,
448    network: &mut Network,
449) -> Result<Incoming, ConnectionError> {
450    let keep_alive = options.keep_alive().as_secs() as u16;
451    let clean_session = options.clean_session();
452    let last_will = options.last_will();
453
454    let mut connect = Connect::new(options.client_id());
455    connect.keep_alive = keep_alive;
456    connect.clean_session = clean_session;
457    connect.last_will = last_will;
458
459    if let Some((username, password)) = options.credentials() {
460        let login = Login::new(username, password);
461        connect.login = Some(login);
462    }
463
464    // send mqtt connect packet
465    network.connect(connect).await?;
466
467    // validate connack
468    match network.read().await? {
469        Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
470            Ok(Packet::ConnAck(connack))
471        }
472        Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
473        packet => Err(ConnectionError::NotConnAck(packet)),
474    }
475}