1#[cfg(feature = "use-rustls")]
2use rustls_pemfile::Item;
3#[cfg(feature = "use-rustls")]
4use tokio_rustls::rustls::{
5 self,
6 pki_types::{InvalidDnsNameError, ServerName},
7 ClientConfig, RootCertStore,
8};
9#[cfg(feature = "use-rustls")]
10use tokio_rustls::TlsConnector as RustlsConnector;
11
12#[cfg(feature = "use-rustls")]
13use std::convert::TryFrom;
14#[cfg(feature = "use-rustls")]
15use std::io::{BufReader, Cursor};
16#[cfg(feature = "use-rustls")]
17use std::sync::Arc;
18
19use crate::framed::N;
20use crate::TlsConfiguration;
21
22#[cfg(feature = "use-native-tls")]
23use tokio_native_tls::TlsConnector as NativeTlsConnector;
24
25#[cfg(feature = "use-native-tls")]
26use tokio_native_tls::native_tls::{Error as NativeTlsError, Identity};
27
28use std::io;
29use std::net::AddrParseError;
30
31#[derive(Debug, thiserror::Error)]
32pub enum Error {
33 #[error("Addr")]
35 Addr(#[from] AddrParseError),
36 #[error("I/O: {0}")]
38 Io(#[from] io::Error),
39 #[cfg(feature = "use-rustls")]
40 #[error("Web Pki: {0}")]
42 WebPki(#[from] webpki::Error),
43 #[cfg(feature = "use-rustls")]
45 #[error("DNS name")]
46 DNSName(#[from] InvalidDnsNameError),
47 #[cfg(feature = "use-rustls")]
48 #[error("TLS error: {0}")]
50 TLS(#[from] rustls::Error),
51 #[cfg(feature = "use-rustls")]
52 #[error("No valid CA certificate provided")]
54 NoValidCertInChain,
55 #[cfg(feature = "use-rustls")]
56 #[error("No valid certificate for client authentication in chain")]
58 NoValidClientCertInChain,
59 #[cfg(feature = "use-rustls")]
60 #[error("No valid key in chain")]
62 NoValidKeyInChain,
63 #[cfg(feature = "use-native-tls")]
64 #[error("Native TLS error {0}")]
65 NativeTls(#[from] NativeTlsError),
66}
67
/// Builds a [`RustlsConnector`] from a rustls-flavoured [`TlsConfiguration`].
///
/// * `TlsConfiguration::Simple`:
///   - `ca` is parsed as PEM, and every certificate rustls can parse becomes a trusted root.
///   - If `client_auth` is set, its first element is parsed as a PEM certificate chain and its
///     second as a PEM private key. The first SEC1, PKCS#1 or PKCS#8 key found is used.
///   - Any `alpn` protocols are appended to the resulting config.
/// * `TlsConfiguration::Rustls`: the caller-supplied config is reused as-is.
///
/// No asynchronous work is performed; the function is `async` only for API compatibility.
///
/// # Errors
///
/// * [`Error::NoValidCertInChain`] if `ca` yields no usable root certificate.
/// * [`Error::NoValidClientCertInChain`] if the client certificate PEM has no certificate.
/// * [`Error::NoValidKeyInChain`] if the client key PEM has no supported private key.
/// * [`Error::TLS`] if rustls rejects the client certificate/key pair.
/// * Any error raised while reading the PEM data, converted through [`Error`]'s `From` impls.
///
/// # Panics
///
/// Panics if `tls_config` is a variant belonging to a different TLS backend. This is only
/// possible when more than one TLS backend feature is enabled.
#[cfg(feature = "use-rustls")]
pub async fn rustls_connector(tls_config: &TlsConfiguration) -> Result<RustlsConnector, Error> {
    let config = match tls_config {
        TlsConfiguration::Simple {
            ca,
            alpn,
            client_auth,
        } => {
            let root_cert_store = rustls_root_store_from_pem(ca)?;
            let builder = ClientConfig::builder().with_root_certificates(root_cert_store);

            let mut config = if let Some(client) = client_auth.as_ref() {
                // Borrow the PEM buffers instead of cloning them: `Cursor` reads any
                // `AsRef<[u8]>`, so copying the bytes first is pure overhead.
                let certs = rustls_client_certs_from_pem(&client.0)?;
                let key = rustls_private_key_from_pem(&client.1)?;
                builder.with_client_auth_cert(certs, key)?
            } else {
                builder.with_no_client_auth()
            };

            if let Some(alpn) = alpn.as_ref() {
                config.alpn_protocols.extend_from_slice(alpn);
            }

            Arc::new(config)
        }
        TlsConfiguration::Rustls(tls_client_config) => tls_client_config.clone(),
        // Only reachable when another TLS backend is compiled in as well; with rustls alone,
        // the arms above are exhaustive and this pattern would trigger a lint.
        #[allow(unreachable_patterns)]
        _ => unreachable!("This cannot be called for other TLS backends than Rustls"),
    };

    Ok(RustlsConnector::from(config))
}

/// Parses PEM-encoded CA certificates into a [`RootCertStore`].
///
/// Certificates rustls cannot parse are silently skipped by `add_parsable_certificates`.
/// An error is returned only if no certificate ends up in the store.
#[cfg(feature = "use-rustls")]
fn rustls_root_store_from_pem(ca: &[u8]) -> Result<RootCertStore, Error> {
    let certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca)))
        .collect::<Result<Vec<_>, _>>()?;

    let mut root_cert_store = RootCertStore::empty();
    root_cert_store.add_parsable_certificates(certs);

    if root_cert_store.is_empty() {
        return Err(Error::NoValidCertInChain);
    }
    Ok(root_cert_store)
}

/// Parses the PEM-encoded client certificate chain; fails if it contains no certificate.
#[cfg(feature = "use-rustls")]
fn rustls_client_certs_from_pem(
    pem: &[u8],
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, Error> {
    let certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(pem)))
        .collect::<Result<Vec<_>, _>>()?;

    if certs.is_empty() {
        return Err(Error::NoValidClientCertInChain);
    }
    Ok(certs)
}

/// Returns the first SEC1, PKCS#1 or PKCS#8 private key found in the PEM data.
///
/// Other PEM items (e.g. certificates) are skipped.
#[cfg(feature = "use-rustls")]
fn rustls_private_key_from_pem(
    pem: &[u8],
) -> Result<rustls::pki_types::PrivateKeyDer<'static>, Error> {
    let mut key_buffer = BufReader::new(Cursor::new(pem));

    loop {
        match rustls_pemfile::read_one(&mut key_buffer)? {
            Some(Item::Sec1Key(key)) => return Ok(key.into()),
            Some(Item::Pkcs1Key(key)) => return Ok(key.into()),
            Some(Item::Pkcs8Key(key)) => return Ok(key.into()),
            // Not a private key: keep scanning (`Item` is non-exhaustive, hence the wildcard).
            Some(_) => continue,
            None => return Err(Error::NoValidKeyInChain),
        }
    }
}
138
/// Builds a `tokio_native_tls` connector from a native-tls-flavoured [`TlsConfiguration`].
///
/// * `TlsConfiguration::Native` uses `native_tls::TlsConnector::new()`.
/// * `TlsConfiguration::SimpleNative`:
///   - Adds `ca` (PEM) as a root certificate.
///   - If `client_auth` is set, presents the PKCS#12 identity built from its DER bytes and
///     password.
///
/// No asynchronous work is performed; the function is `async` only for API compatibility.
///
/// # Errors
///
/// Returns [`Error::NativeTls`] if the CA certificate, the identity or the connector itself
/// cannot be created.
///
/// # Panics
///
/// Panics if `tls_config` is a variant belonging to a different TLS backend. This is only
/// possible when more than one TLS backend feature is enabled.
#[cfg(feature = "use-native-tls")]
pub async fn native_tls_connector(
    tls_config: &TlsConfiguration,
) -> Result<NativeTlsConnector, Error> {
    let native_connector = match tls_config {
        TlsConfiguration::Native => native_tls::TlsConnector::new()?,
        TlsConfiguration::SimpleNative { ca, client_auth } => {
            let root_certificate = native_tls::Certificate::from_pem(ca)?;

            let mut builder = native_tls::TlsConnector::builder();
            builder.add_root_certificate(root_certificate);

            if let Some((der, password)) = client_auth {
                builder.identity(Identity::from_pkcs12(der, password)?);
            }

            builder.build()?
        }
        // Only reachable when another TLS backend is compiled in as well.
        #[allow(unreachable_patterns)]
        _ => unreachable!("This cannot be called for other TLS backends than Native TLS"),
    };

    Ok(NativeTlsConnector::from(native_connector))
}
164
165pub async fn tls_connect(
166 addr: &str,
167 _port: u16,
168 tls_config: &TlsConfiguration,
169 tcp: Box<dyn N>,
170) -> Result<Box<dyn N>, Error> {
171 let tls: Box<dyn N> = match tls_config {
172 #[cfg(feature = "use-rustls")]
173 TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => {
174 let connector = rustls_connector(tls_config).await?;
175 let domain = ServerName::try_from(addr)?.to_owned();
176 Box::new(connector.connect(domain, tcp).await?)
177 }
178 #[cfg(feature = "use-native-tls")]
179 TlsConfiguration::Native | TlsConfiguration::SimpleNative { .. } => {
180 let connector = native_tls_connector(tls_config).await?;
181 Box::new(connector.connect(addr, tcp).await?)
182 }
183 #[allow(unreachable_patterns)]
184 _ => panic!("Unknown or not enabled TLS backend configuration"),
185 };
186 Ok(tls)
187}