rumqttc/
tls.rs

1#[cfg(feature = "use-rustls")]
2use rustls_pemfile::Item;
3#[cfg(feature = "use-rustls")]
4use tokio_rustls::rustls::{
5    self,
6    pki_types::{InvalidDnsNameError, ServerName},
7    ClientConfig, RootCertStore,
8};
9#[cfg(feature = "use-rustls")]
10use tokio_rustls::TlsConnector as RustlsConnector;
11
12#[cfg(feature = "use-rustls")]
13use std::convert::TryFrom;
14#[cfg(feature = "use-rustls")]
15use std::io::{BufReader, Cursor};
16#[cfg(feature = "use-rustls")]
17use std::sync::Arc;
18
19use crate::framed::N;
20use crate::TlsConfiguration;
21
22#[cfg(feature = "use-native-tls")]
23use tokio_native_tls::TlsConnector as NativeTlsConnector;
24
25#[cfg(feature = "use-native-tls")]
26use tokio_native_tls::native_tls::{Error as NativeTlsError, Identity};
27
28use std::io;
29use std::net::AddrParseError;
30
31#[derive(Debug, thiserror::Error)]
32pub enum Error {
33    /// Error parsing IP address
34    #[error("Addr")]
35    Addr(#[from] AddrParseError),
36    /// I/O related error
37    #[error("I/O: {0}")]
38    Io(#[from] io::Error),
39    #[cfg(feature = "use-rustls")]
40    /// Certificate/Name validation error
41    #[error("Web Pki: {0}")]
42    WebPki(#[from] webpki::Error),
43    /// Invalid DNS name
44    #[cfg(feature = "use-rustls")]
45    #[error("DNS name")]
46    DNSName(#[from] InvalidDnsNameError),
47    #[cfg(feature = "use-rustls")]
48    /// Error from rustls module
49    #[error("TLS error: {0}")]
50    TLS(#[from] rustls::Error),
51    #[cfg(feature = "use-rustls")]
52    /// No valid CA cert found
53    #[error("No valid CA certificate provided")]
54    NoValidCertInChain,
55    #[cfg(feature = "use-rustls")]
56    /// No valid client cert found
57    #[error("No valid certificate for client authentication in chain")]
58    NoValidClientCertInChain,
59    #[cfg(feature = "use-rustls")]
60    /// No valid key found
61    #[error("No valid key in chain")]
62    NoValidKeyInChain,
63    #[cfg(feature = "use-native-tls")]
64    #[error("Native TLS error {0}")]
65    NativeTls(#[from] NativeTlsError),
66}
67
/// Builds a rustls [`RustlsConnector`] from a [`TlsConfiguration`].
///
/// For `TlsConfiguration::Simple`:
/// * the PEM-encoded `ca` certificates become the only trusted roots;
/// * if `client_auth` is set, its first element is read as a PEM certificate chain and its
///   second element as PEM containing a SEC1, PKCS#1 or PKCS#8 private key;
/// * any `alpn` protocols are appended to the config's ALPN list.
///
/// For `TlsConfiguration::Rustls` the supplied shared `ClientConfig` is used unchanged.
///
/// # Errors
///
/// * [`Error::Io`] if the CA, client certificate or key PEM data cannot be read.
/// * [`Error::NoValidCertInChain`] if no usable CA certificate ends up in the root store.
/// * [`Error::NoValidClientCertInChain`] if client auth is configured but no certificate
///   was found in its PEM data.
/// * [`Error::NoValidKeyInChain`] if the key PEM data contains no supported private key.
/// * [`Error::TLS`] if rustls rejects the client certificate/key pair.
///
/// # Panics
///
/// Panics if `tls_config` describes a non-rustls backend.
#[cfg(feature = "use-rustls")]
pub async fn rustls_connector(tls_config: &TlsConfiguration) -> Result<RustlsConnector, Error> {
    let config = match tls_config {
        TlsConfiguration::Simple {
            ca,
            alpn,
            client_auth,
        } => {
            // Trust exactly the CA certificates supplied by the caller.
            let mut root_cert_store = RootCertStore::empty();
            let certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca)))
                .collect::<Result<Vec<_>, _>>()?;

            // Unparsable certificates are skipped, so an empty store afterwards means
            // nothing usable was provided.
            root_cert_store.add_parsable_certificates(certs);

            if root_cert_store.is_empty() {
                return Err(Error::NoValidCertInChain);
            }

            let config = ClientConfig::builder().with_root_certificates(root_cert_store);

            // Attach the client certificate chain and key when client auth is configured.
            // The PEM buffers are only read, so borrow them rather than cloning them.
            let mut config = if let Some((cert_pem, key_pem)) = client_auth {
                let certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(cert_pem)))
                    .collect::<Result<Vec<_>, _>>()?;
                if certs.is_empty() {
                    return Err(Error::NoValidClientCertInChain);
                }

                let mut key_buffer = BufReader::new(Cursor::new(key_pem));

                // Skip PEM items that are not private keys until the first supported key
                // is found; running out of items means no key was provided.
                let key = loop {
                    match rustls_pemfile::read_one(&mut key_buffer)? {
                        Some(Item::Sec1Key(key)) => break key.into(),
                        Some(Item::Pkcs1Key(key)) => break key.into(),
                        Some(Item::Pkcs8Key(key)) => break key.into(),
                        None => return Err(Error::NoValidKeyInChain),
                        _ => {}
                    }
                };

                config.with_client_auth_cert(certs, key)?
            } else {
                config.with_no_client_auth()
            };

            if let Some(alpn) = alpn {
                config.alpn_protocols.extend_from_slice(alpn);
            }

            Arc::new(config)
        }
        TlsConfiguration::Rustls(tls_client_config) => tls_client_config.clone(),
        // Only reachable when another TLS backend feature is also compiled in.
        #[allow(unreachable_patterns)]
        _ => unreachable!("This cannot be called for other TLS backends than Rustls"),
    };

    Ok(RustlsConnector::from(config))
}
138
/// Builds a native-tls [`NativeTlsConnector`] from a [`TlsConfiguration`].
///
/// For `TlsConfiguration::SimpleNative`:
/// * the PEM-encoded `ca` certificate is added as a trusted root;
/// * if `client_auth` is set, it is used as a `(pkcs12_archive, password)` pair to build
///   the client identity.
///
/// `TlsConfiguration::Native` produces a connector with native-tls default settings.
///
/// # Errors
///
/// Returns [`Error::NativeTls`] if the CA certificate or PKCS#12 identity cannot be parsed,
/// or if the connector cannot be built.
///
/// # Panics
///
/// Panics if `tls_config` describes a non-native-tls backend.
#[cfg(feature = "use-native-tls")]
pub async fn native_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<NativeTlsConnector, Error> {
    // Resolve native-tls types through the `tokio_native_tls` re-export, exactly like the
    // `Identity`/`Error` imports, so every type comes from the crate version that
    // `NativeTlsConnector::from` expects. `TlsConnector` here is the synchronous
    // native-tls connector, not the async wrapper returned by this function.
    use tokio_native_tls::native_tls::{Certificate, TlsConnector};

    let connector = match tls_config {
        TlsConfiguration::SimpleNative { ca, client_auth } => {
            let cert = Certificate::from_pem(ca)?;

            let mut connector_builder = TlsConnector::builder();
            connector_builder.add_root_certificate(cert);

            if let Some((pkcs12, password)) = client_auth {
                let identity = Identity::from_pkcs12(pkcs12, password)?;
                connector_builder.identity(identity);
            }

            connector_builder.build()?
        }
        TlsConfiguration::Native => TlsConnector::new()?,
        // Only reachable when another TLS backend feature is also compiled in.
        #[allow(unreachable_patterns)]
        _ => unreachable!("This cannot be called for other TLS backends than Native TLS"),
    };

    Ok(connector.into())
}
164
165pub async fn tls_connect(
166    addr: &str,
167    _port: u16,
168    tls_config: &TlsConfiguration,
169    tcp: Box<dyn N>,
170) -> Result<Box<dyn N>, Error> {
171    let tls: Box<dyn N> = match tls_config {
172        #[cfg(feature = "use-rustls")]
173        TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => {
174            let connector = rustls_connector(tls_config).await?;
175            let domain = ServerName::try_from(addr)?.to_owned();
176            Box::new(connector.connect(domain, tcp).await?)
177        }
178        #[cfg(feature = "use-native-tls")]
179        TlsConfiguration::Native | TlsConfiguration::SimpleNative { .. } => {
180            let connector = native_tls_connector(tls_config).await?;
181            Box::new(connector.connect(addr, tcp).await?)
182        }
183        #[allow(unreachable_patterns)]
184        _ => panic!("Unknown or not enabled TLS backend configuration"),
185    };
186    Ok(tls)
187}