rumqttc/mqttbytes/v4/
connect.rs

1use super::*;
2use bytes::{Buf, Bytes};
3
4/// Connection packet initiated by the client
5#[derive(Debug, Clone, PartialEq, Eq)]
6pub struct Connect {
7    /// Mqtt protocol version
8    pub protocol: Protocol,
9    /// Mqtt keep alive time
10    pub keep_alive: u16,
11    /// Client Id
12    pub client_id: String,
13    /// Clean session. Asks the broker to clear previous state
14    pub clean_session: bool,
15    /// Will that broker needs to publish when the client disconnects
16    pub last_will: Option<LastWill>,
17    /// Login credentials
18    pub login: Option<Login>,
19}
20
21impl Connect {
22    pub fn new<S: Into<String>>(id: S) -> Connect {
23        Connect {
24            protocol: Protocol::V4,
25            keep_alive: 10,
26            client_id: id.into(),
27            clean_session: true,
28            last_will: None,
29            login: None,
30        }
31    }
32
33    pub fn set_login<U: Into<String>, P: Into<String>>(&mut self, u: U, p: P) -> &mut Connect {
34        let login = Login {
35            username: u.into(),
36            password: p.into(),
37        };
38
39        self.login = Some(login);
40        self
41    }
42
43    fn len(&self) -> usize {
44        let mut len = 2 + "MQTT".len() // protocol name
45                              + 1            // protocol version
46                              + 1            // connect flags
47                              + 2; // keep alive
48
49        len += 2 + self.client_id.len();
50
51        // last will len
52        if let Some(last_will) = &self.last_will {
53            len += last_will.len();
54        }
55
56        // username and password len
57        if let Some(login) = &self.login {
58            len += login.len();
59        }
60
61        len
62    }
63
64    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Connect, Error> {
65        let variable_header_index = fixed_header.fixed_header_len;
66        bytes.advance(variable_header_index);
67
68        // Variable header
69        let protocol_name = read_mqtt_string(&mut bytes)?;
70        let protocol_level = read_u8(&mut bytes)?;
71        if protocol_name != "MQTT" {
72            return Err(Error::InvalidProtocol);
73        }
74
75        let protocol = match protocol_level {
76            4 => Protocol::V4,
77            5 => Protocol::V5,
78            num => return Err(Error::InvalidProtocolLevel(num)),
79        };
80
81        let connect_flags = read_u8(&mut bytes)?;
82        let clean_session = (connect_flags & 0b10) != 0;
83        let keep_alive = read_u16(&mut bytes)?;
84
85        let client_id = read_mqtt_string(&mut bytes)?;
86        let last_will = LastWill::read(connect_flags, &mut bytes)?;
87        let login = Login::read(connect_flags, &mut bytes)?;
88
89        let connect = Connect {
90            protocol,
91            keep_alive,
92            client_id,
93            clean_session,
94            last_will,
95            login,
96        };
97
98        Ok(connect)
99    }
100
101    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
102        let len = self.len();
103        buffer.put_u8(0b0001_0000);
104        let count = write_remaining_length(buffer, len)?;
105        write_mqtt_string(buffer, "MQTT");
106
107        match self.protocol {
108            Protocol::V4 => buffer.put_u8(0x04),
109            Protocol::V5 => buffer.put_u8(0x05),
110        }
111
112        let flags_index = 1 + count + 2 + 4 + 1;
113
114        let mut connect_flags = 0;
115        if self.clean_session {
116            connect_flags |= 0x02;
117        }
118
119        buffer.put_u8(connect_flags);
120        buffer.put_u16(self.keep_alive);
121        write_mqtt_string(buffer, &self.client_id);
122
123        if let Some(last_will) = &self.last_will {
124            connect_flags |= last_will.write(buffer)?;
125        }
126
127        if let Some(login) = &self.login {
128            connect_flags |= login.write(buffer);
129        }
130
131        // update connect flags
132        buffer[flags_index] = connect_flags;
133        Ok(len)
134    }
135}
136
137/// LastWill that broker forwards on behalf of the client
138#[derive(Debug, Clone, PartialEq, Eq)]
139pub struct LastWill {
140    pub topic: String,
141    pub message: Bytes,
142    pub qos: QoS,
143    pub retain: bool,
144}
145
146impl LastWill {
147    pub fn new(
148        topic: impl Into<String>,
149        payload: impl Into<Vec<u8>>,
150        qos: QoS,
151        retain: bool,
152    ) -> LastWill {
153        LastWill {
154            topic: topic.into(),
155            message: Bytes::from(payload.into()),
156            qos,
157            retain,
158        }
159    }
160
161    fn len(&self) -> usize {
162        let mut len = 0;
163        len += 2 + self.topic.len() + 2 + self.message.len();
164        len
165    }
166
167    fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<LastWill>, Error> {
168        let last_will = match connect_flags & 0b100 {
169            0 if (connect_flags & 0b0011_1000) != 0 => {
170                return Err(Error::IncorrectPacketFormat);
171            }
172            0 => None,
173            _ => {
174                let will_topic = read_mqtt_string(bytes)?;
175                let will_message = read_mqtt_bytes(bytes)?;
176                let will_qos = qos((connect_flags & 0b11000) >> 3)?;
177                Some(LastWill {
178                    topic: will_topic,
179                    message: will_message,
180                    qos: will_qos,
181                    retain: (connect_flags & 0b0010_0000) != 0,
182                })
183            }
184        };
185
186        Ok(last_will)
187    }
188
189    fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
190        let mut connect_flags = 0;
191
192        connect_flags |= 0x04 | (self.qos as u8) << 3;
193        if self.retain {
194            connect_flags |= 0x20;
195        }
196
197        write_mqtt_string(buffer, &self.topic);
198        write_mqtt_bytes(buffer, &self.message);
199        Ok(connect_flags)
200    }
201}
202
/// Username/password credentials carried in the CONNECT payload.
///
/// Empty fields are treated as absent when the packet is encoded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Login {
    /// Username sent to the broker.
    pub username: String,
    /// Password sent to the broker.
    pub password: String,
}
208
209impl Login {
210    pub fn new<U: Into<String>, P: Into<String>>(u: U, p: P) -> Login {
211        Login {
212            username: u.into(),
213            password: p.into(),
214        }
215    }
216
217    fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<Login>, Error> {
218        let username = match connect_flags & 0b1000_0000 {
219            0 => String::new(),
220            _ => read_mqtt_string(bytes)?,
221        };
222
223        let password = match connect_flags & 0b0100_0000 {
224            0 => String::new(),
225            _ => read_mqtt_string(bytes)?,
226        };
227
228        if username.is_empty() && password.is_empty() {
229            Ok(None)
230        } else {
231            Ok(Some(Login { username, password }))
232        }
233    }
234
235    fn len(&self) -> usize {
236        let mut len = 0;
237
238        if !self.username.is_empty() {
239            len += 2 + self.username.len();
240        }
241
242        if !self.password.is_empty() {
243            len += 2 + self.password.len();
244        }
245
246        len
247    }
248
249    fn write(&self, buffer: &mut BytesMut) -> u8 {
250        let mut connect_flags = 0;
251        if !self.username.is_empty() {
252            connect_flags |= 0x80;
253            write_mqtt_string(buffer, &self.username);
254        }
255
256        if !self.password.is_empty() {
257            connect_flags |= 0x40;
258            write_mqtt_string(buffer, &self.password);
259        }
260
261        connect_flags
262    }
263
264    pub fn validate(&self, username: &str, password: &str) -> bool {
265        (self.username == *username) && (self.password == *password)
266    }
267}
268
#[cfg(test)]
mod test {
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    #[test]
    fn connect_parsing_works() {
        let mut stream = BytesMut::new();
        let packetstream = &[
            0x10,
            39, // packet type, flags and remaining len
            0x00,
            0x04,
            b'M',
            b'Q',
            b'T',
            b'T',
            0x04,        // variable header. protocol level
            0b1100_1110, // variable header. +username, +password, -will retain, will qos=1, +last_will, +clean_session
            0x00,
            0x0a, // variable header. keep alive = 10 sec
            0x00,
            0x04,
            b't',
            b'e',
            b's',
            b't', // payload. client_id
            0x00,
            0x02,
            b'/',
            b'a', // payload. will topic = '/a'
            0x00,
            0x07,
            b'o',
            b'f',
            b'f',
            b'l',
            b'i',
            b'n',
            b'e', // payload. will msg = 'offline'
            0x00,
            0x04,
            b'r',
            b'u',
            b'm',
            b'q', // payload. username = 'rumq'
            0x00,
            0x02,
            b'm',
            b'q', // payload. password = 'mq'
            0xDE,
            0xAD,
            0xBE,
            0xEF, // extra packets in the stream
        ];

        stream.extend_from_slice(&packetstream[..]);
        let fixed_header = parse_fixed_header(stream.iter()).unwrap();
        let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze();
        let packet = Connect::read(fixed_header, connect_bytes).unwrap();

        assert_eq!(
            packet,
            Connect {
                protocol: Protocol::V4,
                keep_alive: 10,
                client_id: "test".to_owned(),
                clean_session: true,
                last_will: Some(LastWill::new("/a", "offline", QoS::AtLeastOnce, false)),
                login: Some(Login::new("rumq", "mq")),
            }
        );

        // Only the CONNECT frame is consumed; the trailing bytes stay in the stream.
        assert_eq!(&stream[..], &[0xDEu8, 0xAD, 0xBE, 0xEF][..]);
    }

    fn sample_bytes() -> Vec<u8> {
        vec![
            0x10,
            39,
            0x00,
            0x04,
            b'M',
            b'Q',
            b'T',
            b'T',
            0x04,
            0b1100_1110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session
            0x00,
            0x0a, // 10 sec
            0x00,
            0x04,
            b't',
            b'e',
            b's',
            b't', // client_id
            0x00,
            0x02,
            b'/',
            b'a', // will topic = '/a'
            0x00,
            0x07,
            b'o',
            b'f',
            b'f',
            b'l',
            b'i',
            b'n',
            b'e', // will msg = 'offline'
            0x00,
            0x04,
            b'r',
            b'u',
            b's',
            b't', // username = 'rust'
            0x00,
            0x02,
            b'm',
            b'q', // password = 'mq'
        ]
    }

    #[test]
    fn connect_encoding_works() {
        let connect = Connect {
            protocol: Protocol::V4,
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(LastWill::new("/a", "offline", QoS::AtLeastOnce, false)),
            login: Some(Login::new("rust", "mq")),
        };

        let mut buf = BytesMut::new();
        connect.write(&mut buf).unwrap();

        assert_eq!(buf, sample_bytes());
    }

    #[test]
    fn connect_write_then_read_round_trips() {
        let connect = Connect {
            protocol: Protocol::V4,
            keep_alive: 30,
            client_id: "roundtrip".to_owned(),
            clean_session: false,
            last_will: Some(LastWill::new("/will", "bye", QoS::ExactlyOnce, true)),
            login: Some(Login::new("user", "pass")),
        };

        let mut buf = BytesMut::new();
        connect.write(&mut buf).unwrap();

        let fixed_header = parse_fixed_header(buf.iter()).unwrap();
        let frame = buf.split_to(fixed_header.frame_length()).freeze();
        assert_eq!(Connect::read(fixed_header, frame).unwrap(), connect);

        // The encoded remaining length must cover exactly the bytes written.
        assert!(buf.is_empty());
    }
}
409}