1use super::framed::Network;
2use super::mqttbytes::v5::*;
3use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport};
4use crate::eventloop::socket_connect;
5use crate::framed::N;
6
7use flume::{bounded, Receiver, Sender};
8use tokio::select;
9use tokio::time::{self, error::Elapsed, Instant, Sleep};
10
11use std::collections::VecDeque;
12use std::convert::TryInto;
13use std::io;
14use std::pin::Pin;
15use std::time::Duration;
16
17use super::mqttbytes::v5::ConnectReturnCode;
18
19#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
20use crate::tls;
21
22#[cfg(unix)]
23use {std::path::Path, tokio::net::UnixStream};
24
25#[cfg(feature = "websocket")]
26use {
27 crate::websockets::{split_url, validate_response_headers, UrlError},
28 async_tungstenite::tungstenite::client::IntoClientRequest,
29 ws_stream_tungstenite::WsStream,
30};
31
32#[cfg(feature = "proxy")]
33use crate::proxy::ProxyError;
34
35#[derive(Debug, thiserror::Error)]
37pub enum ConnectionError {
38 #[error("Mqtt state: {0}")]
39 MqttState(#[from] StateError),
40 #[error("Timeout")]
41 Timeout(#[from] Elapsed),
42 #[cfg(feature = "websocket")]
43 #[error("Websocket: {0}")]
44 Websocket(#[from] async_tungstenite::tungstenite::error::Error),
45 #[cfg(feature = "websocket")]
46 #[error("Websocket Connect: {0}")]
47 WsConnect(#[from] http::Error),
48 #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
49 #[error("TLS: {0}")]
50 Tls(#[from] tls::Error),
51 #[error("I/O: {0}")]
52 Io(#[from] io::Error),
53 #[error("Connection refused, return code: `{0:?}`")]
54 ConnectionRefused(ConnectReturnCode),
55 #[error("Expected ConnAck packet, received: {0:?}")]
56 NotConnAck(Box<Packet>),
57 #[error("Requests done")]
58 RequestsDone,
59 #[cfg(feature = "websocket")]
60 #[error("Invalid Url: {0}")]
61 InvalidUrl(#[from] UrlError),
62 #[cfg(feature = "proxy")]
63 #[error("Proxy Connect: {0}")]
64 Proxy(#[from] ProxyError),
65 #[cfg(feature = "websocket")]
66 #[error("Websocket response validation error: ")]
67 ResponseValidation(#[from] crate::websockets::ValidationError),
68}
69
/// Drives a single MQTT connection.
///
/// The loop connects on demand and multiplexes user requests with network reads. It also
/// sends keep-alive pings and hands out [`Event`]s through [`EventLoop::poll`].
///
/// NOTE: struct fields are dropped in declaration order; keep the order as-is.
pub struct EventLoop {
    /// Options used to connect. `keep_alive` may be overwritten with the broker's
    /// Server Keep Alive during the handshake.
    pub options: MqttOptions,
    /// Protocol state: inflight tracking, collision state, buffered events and the
    /// outgoing write buffer.
    pub state: MqttState,
    /// Receiving half of the bounded user-request channel; drained into `pending` on `clean`.
    requests_rx: Receiver<Request>,
    /// Sending half of the request channel (presumably cloned into client handles — verify).
    pub(crate) requests_tx: Sender<Request>,
    /// Requests to send before any new ones from `requests_rx`. This includes requests
    /// recovered by `clean` after a connection error.
    pub pending: VecDeque<Request>,
    /// Active connection; `None` until `poll` connects, and again after `clean`.
    network: Option<Network>,
    /// Timer that triggers a PINGREQ when it elapses; `None` while disconnected.
    keepalive_timeout: Option<Pin<Box<Sleep>>>,
}
87
88#[derive(Debug, Clone, PartialEq, Eq)]
90pub enum Event {
91 Incoming(Incoming),
92 Outgoing(Outgoing),
93}
94
95impl EventLoop {
96 pub fn new(options: MqttOptions, cap: usize) -> EventLoop {
101 let (requests_tx, requests_rx) = bounded(cap);
102 let pending = VecDeque::new();
103 let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX);
104 let manual_acks = options.manual_acks;
105
106 EventLoop {
107 options,
108 state: MqttState::new(inflight_limit, manual_acks),
109 requests_tx,
110 requests_rx,
111 pending,
112 network: None,
113 keepalive_timeout: None,
114 }
115 }
116
117 pub fn clean(&mut self) {
125 self.network = None;
126 self.keepalive_timeout = None;
127 self.pending.extend(self.state.clean());
128
129 let requests_in_channel = self.requests_rx.drain();
131 self.pending.extend(requests_in_channel);
132 }
133
134 pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
139 if self.network.is_none() {
140 let (network, connack) = time::timeout(
141 Duration::from_secs(self.options.connection_timeout()),
142 connect(&mut self.options),
143 )
144 .await??;
145 self.network = Some(network);
146
147 if self.keepalive_timeout.is_none() {
148 self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive)));
149 }
150
151 self.state.handle_incoming_packet(connack)?;
152 }
153
154 match self.select().await {
155 Ok(v) => Ok(v),
156 Err(e) => {
157 self.clean();
158 Err(e)
159 }
160 }
161 }
162
163 async fn select(&mut self) -> Result<Event, ConnectionError> {
165 let network = self.network.as_mut().unwrap();
166 let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight;
169 let collision = self.state.collision.is_some();
170
171 if let Some(event) = self.state.events.pop_front() {
173 return Ok(event);
174 }
175
176 select! {
179 o = Self::next_request(
208 &mut self.pending,
209 &self.requests_rx,
210 self.options.pending_throttle
211 ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
212 Ok(request) => {
213 self.state.handle_outgoing_packet(request)?;
214 network.flush(&mut self.state.write).await?;
215 Ok(self.state.events.pop_front().unwrap())
216 }
217 Err(_) => Err(ConnectionError::RequestsDone),
218 },
219 o = network.readb(&mut self.state) => {
221 o?;
222 network.flush(&mut self.state.write).await?;
224 Ok(self.state.events.pop_front().unwrap())
225 },
226 _ = self.keepalive_timeout.as_mut().unwrap() => {
229 let timeout = self.keepalive_timeout.as_mut().unwrap();
230 timeout.as_mut().reset(Instant::now() + self.options.keep_alive);
231
232 self.state.handle_outgoing_packet(Request::PingReq)?;
233 network.flush(&mut self.state.write).await?;
234 Ok(self.state.events.pop_front().unwrap())
235 }
236 }
237 }
238
239 async fn next_request(
240 pending: &mut VecDeque<Request>,
241 rx: &Receiver<Request>,
242 pending_throttle: Duration,
243 ) -> Result<Request, ConnectionError> {
244 if !pending.is_empty() {
245 time::sleep(pending_throttle).await;
246 Ok(pending.pop_front().unwrap())
249 } else {
250 match rx.recv_async().await {
251 Ok(r) => Ok(r),
252 Err(_) => Err(ConnectionError::RequestsDone),
253 }
254 }
255 }
256}
257
258async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), ConnectionError> {
264 let mut network = network_connect(options).await?;
266
267 let packet = mqtt_connect(options, &mut network).await?;
269
270 Ok((network, packet))
276}
277
/// Opens the transport-level connection to the broker described by `options`.
///
/// The transport comes from `options.transport()`, depending on enabled features:
/// - a Unix domain socket (unix targets only),
/// - plain TCP,
/// - TLS over TCP,
/// - WebSocket or secure WebSocket.
///
/// With the `proxy` feature and a configured proxy, the TCP stream is opened through the
/// proxy. No MQTT packets are exchanged here.
///
/// # Errors
///
/// Returns a [`ConnectionError`] in these cases:
/// - the configured maximum packet size cannot be converted;
/// - opening the socket, proxy tunnel, TLS session or WebSocket handshake fails;
/// - the WebSocket response fails header validation.
async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
    // Start from the default incoming-size limit. A `max_packet_size` in the connect
    // properties overrides it.
    let mut max_incoming_pkt_size = Some(options.default_max_incoming_size);

    if let Some(connect_props) = &options.connect_properties {
        if let Some(max_size) = connect_props.max_packet_size {
            let max_size = max_size.try_into().map_err(StateError::Coversion)?;
            max_incoming_pkt_size = Some(max_size);
        }
    }

    // Unix sockets connect to the path in `broker_addr` and skip the TCP/proxy/TLS setup below.
    #[cfg(unix)]
    if matches!(options.transport(), Transport::Unix) {
        let file = options.broker_addr.as_str();
        let socket = UnixStream::connect(Path::new(file)).await?;
        let network = Network::new(socket, max_incoming_pkt_size);
        return Ok(network);
    }

    // For WebSocket transports, `broker_addr` is a URL, so host and port are parsed from it.
    // Other transports use `broker_address()`.
    let (domain, port) = match options.transport() {
        #[cfg(feature = "websocket")]
        Transport::Ws => split_url(&options.broker_addr)?,
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(_) => split_url(&options.broker_addr)?,
        _ => options.broker_address(),
    };

    // Raw byte stream that every non-Unix transport is layered on.
    let tcp_stream: Box<dyn N> = {
        #[cfg(feature = "proxy")]
        match options.proxy() {
            Some(proxy) => {
                proxy
                    .connect(&domain, port, options.network_options())
                    .await?
            }
            None => {
                let addr = format!("{domain}:{port}");
                let tcp = socket_connect(addr, options.network_options()).await?;
                Box::new(tcp)
            }
        }
        #[cfg(not(feature = "proxy"))]
        {
            let addr = format!("{domain}:{port}");
            let tcp = socket_connect(addr, options.network_options()).await?;
            Box::new(tcp)
        }
    };

    let network = match options.transport() {
        Transport::Tcp => Network::new(tcp_stream, max_incoming_pkt_size),
        // TLS is negotiated on top of the (possibly proxied) stream opened above.
        #[cfg(any(feature = "use-native-tls", feature = "use-rustls"))]
        Transport::Tls(tls_config) => {
            let socket =
                tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
                    .await?;
            Network::new(socket, max_incoming_pkt_size)
        }
        // Handled by the early return above.
        #[cfg(unix)]
        Transport::Unix => unreachable!(),
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            // Request the "mqtt" WebSocket subprotocol.
            let mut request = options.broker_addr.as_str().into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            // Give the user a chance to adjust the upgrade request (e.g. extra headers).
            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tcp_stream).await?;
            validate_response_headers(response)?;

            Network::new(WsStream::new(socket), max_incoming_pkt_size)
        }
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(tls_config) => {
            // Same upgrade request as `Ws`, but the handshake runs over a rustls connector.
            let mut request = options.broker_addr.as_str().into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            let connector = tls::rustls_connector(&tls_config).await?;

            let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
                request,
                tcp_stream,
                Some(connector),
            )
            .await?;
            validate_response_headers(response)?;

            Network::new(WsStream::new(socket), max_incoming_pkt_size)
        }
    };

    Ok(network)
}
384
385async fn mqtt_connect(
386 options: &mut MqttOptions,
387 network: &mut Network,
388) -> Result<Incoming, ConnectionError> {
389 let keep_alive = options.keep_alive().as_secs() as u16;
390 let clean_start = options.clean_start();
391 let client_id = options.client_id();
392 let properties = options.connect_properties();
393
394 let connect = Connect {
395 keep_alive,
396 client_id,
397 clean_start,
398 properties,
399 };
400
401 network.connect(connect, options).await?;
403
404 match network.read().await? {
406 Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
407 if let Some(props) = &connack.properties {
409 if let Some(keep_alive) = props.server_keep_alive {
410 options.keep_alive = Duration::from_secs(keep_alive as u64);
411 }
412 }
413 Ok(Packet::ConnAck(connack))
414 }
415 Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
416 packet => Err(ConnectionError::NotConnAck(Box::new(packet))),
417 }
418}