1use crate::{framed::Network, Transport};
2use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError};
3use crate::{MqttOptions, Outgoing};
4
5use crate::framed::N;
6use crate::mqttbytes::v4::*;
7use flume::{bounded, Receiver, Sender};
8use tokio::net::{lookup_host, TcpSocket, TcpStream};
9use tokio::select;
10use tokio::time::{self, Instant, Sleep};
11
12use std::collections::VecDeque;
13use std::io;
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::time::Duration;
17
18#[cfg(unix)]
19use {std::path::Path, tokio::net::UnixStream};
20
21#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
22use crate::tls;
23
24#[cfg(feature = "websocket")]
25use {
26 crate::websockets::{split_url, validate_response_headers, UrlError},
27 async_tungstenite::tungstenite::client::IntoClientRequest,
28 ws_stream_tungstenite::WsStream,
29};
30
31#[cfg(feature = "proxy")]
32use crate::proxy::ProxyError;
33
34#[derive(Debug, thiserror::Error)]
36pub enum ConnectionError {
37 #[error("Mqtt state: {0}")]
38 MqttState(#[from] StateError),
39 #[error("Network timeout")]
40 NetworkTimeout,
41 #[error("Flush timeout")]
42 FlushTimeout,
43 #[cfg(feature = "websocket")]
44 #[error("Websocket: {0}")]
45 Websocket(#[from] async_tungstenite::tungstenite::error::Error),
46 #[cfg(feature = "websocket")]
47 #[error("Websocket Connect: {0}")]
48 WsConnect(#[from] http::Error),
49 #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
50 #[error("TLS: {0}")]
51 Tls(#[from] tls::Error),
52 #[error("I/O: {0}")]
53 Io(#[from] io::Error),
54 #[error("Connection refused, return code: `{0:?}`")]
55 ConnectionRefused(ConnectReturnCode),
56 #[error("Expected ConnAck packet, received: {0:?}")]
57 NotConnAck(Packet),
58 #[error("Requests done")]
59 RequestsDone,
60 #[cfg(feature = "websocket")]
61 #[error("Invalid Url: {0}")]
62 InvalidUrl(#[from] UrlError),
63 #[cfg(feature = "proxy")]
64 #[error("Proxy Connect: {0}")]
65 Proxy(#[from] ProxyError),
66 #[cfg(feature = "websocket")]
67 #[error("Websocket response validation error: ")]
68 ResponseValidation(#[from] crate::websockets::ValidationError),
69}
70
71pub struct EventLoop {
73 pub mqtt_options: MqttOptions,
75 pub state: MqttState,
77 requests_rx: Receiver<Request>,
79 pub(crate) requests_tx: Sender<Request>,
81 pub pending: VecDeque<Request>,
83 network: Option<Network>,
85 keepalive_timeout: Option<Pin<Box<Sleep>>>,
87 pub network_options: NetworkOptions,
88}
89
90#[derive(Debug, Clone, PartialEq, Eq)]
92pub enum Event {
93 Incoming(Incoming),
94 Outgoing(Outgoing),
95}
96
97impl EventLoop {
98 pub fn new(mqtt_options: MqttOptions, cap: usize) -> EventLoop {
103 let (requests_tx, requests_rx) = bounded(cap);
104 let pending = VecDeque::new();
105 let max_inflight = mqtt_options.inflight;
106 let manual_acks = mqtt_options.manual_acks;
107 let max_outgoing_packet_size = mqtt_options.max_outgoing_packet_size;
108
109 EventLoop {
110 mqtt_options,
111 state: MqttState::new(max_inflight, manual_acks, max_outgoing_packet_size),
112 requests_tx,
113 requests_rx,
114 pending,
115 network: None,
116 keepalive_timeout: None,
117 network_options: NetworkOptions::new(),
118 }
119 }
120
121 pub fn clean(&mut self) {
129 self.network = None;
130 self.keepalive_timeout = None;
131 self.pending.extend(self.state.clean());
132
133 let requests_in_channel = self.requests_rx.drain();
135 self.pending.extend(requests_in_channel);
136 }
137
138 pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
143 if self.network.is_none() {
144 let (network, connack) = match time::timeout(
145 Duration::from_secs(self.network_options.connection_timeout()),
146 connect(&self.mqtt_options, self.network_options.clone()),
147 )
148 .await
149 {
150 Ok(inner) => inner?,
151 Err(_) => return Err(ConnectionError::NetworkTimeout),
152 };
153 self.network = Some(network);
154
155 if self.keepalive_timeout.is_none() && !self.mqtt_options.keep_alive.is_zero() {
156 self.keepalive_timeout = Some(Box::pin(time::sleep(self.mqtt_options.keep_alive)));
157 }
158
159 return Ok(Event::Incoming(connack));
160 }
161
162 match self.select().await {
163 Ok(v) => Ok(v),
164 Err(e) => {
165 self.clean();
166 Err(e)
167 }
168 }
169 }
170
171 async fn select(&mut self) -> Result<Event, ConnectionError> {
173 let network = self.network.as_mut().unwrap();
174 let inflight_full = self.state.inflight >= self.mqtt_options.inflight;
176 let collision = self.state.collision.is_some();
177 let network_timeout = Duration::from_secs(self.network_options.connection_timeout());
178
179 if let Some(event) = self.state.events.pop_front() {
181 return Ok(event);
182 }
183
184 let mut no_sleep = Box::pin(time::sleep(Duration::ZERO));
185 select! {
188 o = network.readb(&mut self.state) => {
190 o?;
191 match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
193 Ok(inner) => inner?,
194 Err(_)=> return Err(ConnectionError::FlushTimeout),
195 };
196 Ok(self.state.events.pop_front().unwrap())
197 },
198 o = Self::next_request(
227 &mut self.pending,
228 &self.requests_rx,
229 self.mqtt_options.pending_throttle
230 ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
231 Ok(request) => {
232 self.state.handle_outgoing_packet(request)?;
233 match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
234 Ok(inner) => inner?,
235 Err(_)=> return Err(ConnectionError::FlushTimeout),
236 };
237 Ok(self.state.events.pop_front().unwrap())
238 }
239 Err(_) => Err(ConnectionError::RequestsDone),
240 },
241 _ = self.keepalive_timeout.as_mut().unwrap_or(&mut no_sleep),
244 if self.keepalive_timeout.is_some() && !self.mqtt_options.keep_alive.is_zero() => {
245 let timeout = self.keepalive_timeout.as_mut().unwrap();
246 timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive);
247
248 self.state.handle_outgoing_packet(Request::PingReq(PingReq))?;
249 match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
250 Ok(inner) => inner?,
251 Err(_)=> return Err(ConnectionError::FlushTimeout),
252 };
253 Ok(self.state.events.pop_front().unwrap())
254 }
255 }
256 }
257
258 pub fn network_options(&self) -> NetworkOptions {
259 self.network_options.clone()
260 }
261
262 pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self {
263 self.network_options = network_options;
264 self
265 }
266
267 async fn next_request(
268 pending: &mut VecDeque<Request>,
269 rx: &Receiver<Request>,
270 pending_throttle: Duration,
271 ) -> Result<Request, ConnectionError> {
272 if !pending.is_empty() {
273 time::sleep(pending_throttle).await;
274 Ok(pending.pop_front().unwrap())
277 } else {
278 match rx.recv_async().await {
279 Ok(r) => Ok(r),
280 Err(_) => Err(ConnectionError::RequestsDone),
281 }
282 }
283 }
284}
285
286async fn connect(
292 mqtt_options: &MqttOptions,
293 network_options: NetworkOptions,
294) -> Result<(Network, Incoming), ConnectionError> {
295 let mut network = network_connect(mqtt_options, network_options).await?;
297
298 let packet = mqtt_connect(mqtt_options, &mut network).await?;
300
301 Ok((network, packet))
302}
303
304pub(crate) async fn socket_connect(
305 host: String,
306 network_options: NetworkOptions,
307) -> io::Result<TcpStream> {
308 let addrs = lookup_host(host).await?;
309 let mut last_err = None;
310
311 for addr in addrs {
312 let socket = match addr {
313 SocketAddr::V4(_) => TcpSocket::new_v4()?,
314 SocketAddr::V6(_) => TcpSocket::new_v6()?,
315 };
316
317 if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
318 socket.set_send_buffer_size(send_buff_size).unwrap();
319 }
320 if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
321 socket.set_recv_buffer_size(recv_buffer_size).unwrap();
322 }
323
324 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
325 {
326 if let Some(bind_device) = &network_options.bind_device {
327 socket.bind_device(Some(bind_device.as_bytes()))?;
331 }
332 }
333
334 match socket.connect(addr).await {
335 Ok(s) => return Ok(s),
336 Err(e) => {
337 last_err = Some(e);
338 }
339 };
340 }
341
342 Err(last_err.unwrap_or_else(|| {
343 io::Error::new(
344 io::ErrorKind::InvalidInput,
345 "could not resolve to any address",
346 )
347 }))
348}
349
/// Opens the transport selected by `options.transport()` and wraps it in a [`Network`] that
/// enforces `options.max_incoming_packet_size`.
///
/// Only the transport is set up here (Unix socket, TCP, TLS or websocket, depending on the
/// transport and enabled features). The MQTT CONNECT/CONNACK exchange is not performed.
async fn network_connect(
    options: &MqttOptions,
    network_options: NetworkOptions,
) -> Result<Network, ConnectionError> {
    // Unix domain sockets skip TCP entirely; `broker_addr` is used as the socket path.
    #[cfg(unix)]
    if matches!(options.transport(), Transport::Unix) {
        let file = options.broker_addr.as_str();
        let socket = UnixStream::connect(Path::new(file)).await?;
        let network = Network::new(socket, options.max_incoming_packet_size);
        return Ok(network);
    }

    // Websocket transports take host and port from the broker URL; all others use
    // `broker_address()`.
    let (domain, port) = match options.transport() {
        #[cfg(feature = "websocket")]
        Transport::Ws => split_url(&options.broker_addr)?,
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(_) => split_url(&options.broker_addr)?,
        _ => options.broker_address(),
    };

    // Base byte stream for every non-Unix transport: a tunnel through the configured proxy
    // (when the `proxy` feature is on and a proxy is set) or a direct TCP connection.
    let tcp_stream: Box<dyn N> = {
        #[cfg(feature = "proxy")]
        match options.proxy() {
            Some(proxy) => proxy.connect(&domain, port, network_options).await?,
            None => {
                let addr = format!("{domain}:{port}");
                let tcp = socket_connect(addr, network_options).await?;
                Box::new(tcp)
            }
        }
        #[cfg(not(feature = "proxy"))]
        {
            let addr = format!("{domain}:{port}");
            let tcp = socket_connect(addr, network_options).await?;
            Box::new(tcp)
        }
    };

    // Layer TLS and/or the websocket handshake on top of the base stream as required.
    let network = match options.transport() {
        Transport::Tcp => Network::new(tcp_stream, options.max_incoming_packet_size),
        #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
        Transport::Tls(tls_config) => {
            let socket =
                tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
                    .await?;
            Network::new(socket, options.max_incoming_packet_size)
        }
        // Already handled by the early return at the top of this function.
        #[cfg(unix)]
        Transport::Unix => unreachable!(),
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            let mut request = options.broker_addr.as_str().into_client_request()?;
            // Request the "mqtt" websocket subprotocol. The value is a constant, so this parse
            // is not expected to fail.
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            // Let the configured request modifier, if any, rewrite the upgrade request.
            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tcp_stream).await?;
            validate_response_headers(response)?;

            Network::new(WsStream::new(socket), options.max_incoming_packet_size)
        }
        #[cfg(all(feature = "use-rustls", feature = "websocket"))]
        Transport::Wss(tls_config) => {
            let mut request = options.broker_addr.as_str().into_client_request()?;
            // Same subprotocol header as the plain websocket transport.
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());

            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }

            // TLS is negotiated by the websocket client using a rustls connector.
            let connector = tls::rustls_connector(&tls_config).await?;

            let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
                request,
                tcp_stream,
                Some(connector),
            )
            .await?;
            validate_response_headers(response)?;

            Network::new(WsStream::new(socket), options.max_incoming_packet_size)
        }
    };

    Ok(network)
}
445
446async fn mqtt_connect(
447 options: &MqttOptions,
448 network: &mut Network,
449) -> Result<Incoming, ConnectionError> {
450 let keep_alive = options.keep_alive().as_secs() as u16;
451 let clean_session = options.clean_session();
452 let last_will = options.last_will();
453
454 let mut connect = Connect::new(options.client_id());
455 connect.keep_alive = keep_alive;
456 connect.clean_session = clean_session;
457 connect.last_will = last_will;
458
459 if let Some((username, password)) = options.credentials() {
460 let login = Login::new(username, password);
461 connect.login = Some(login);
462 }
463
464 network.connect(connect).await?;
466
467 match network.read().await? {
469 Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => {
470 Ok(Packet::ConnAck(connack))
471 }
472 Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
473 packet => Err(ConnectionError::NotConnAck(packet)),
474 }
475}