1use super::*;
2use bytes::{Buf, Bytes};
3
4#[derive(Debug, Clone, PartialEq, Eq)]
6pub struct Connect {
7 pub protocol: Protocol,
9 pub keep_alive: u16,
11 pub client_id: String,
13 pub clean_session: bool,
15 pub last_will: Option<LastWill>,
17 pub login: Option<Login>,
19}
20
21impl Connect {
22 pub fn new<S: Into<String>>(id: S) -> Connect {
23 Connect {
24 protocol: Protocol::V4,
25 keep_alive: 10,
26 client_id: id.into(),
27 clean_session: true,
28 last_will: None,
29 login: None,
30 }
31 }
32
33 pub fn set_login<U: Into<String>, P: Into<String>>(&mut self, u: U, p: P) -> &mut Connect {
34 let login = Login {
35 username: u.into(),
36 password: p.into(),
37 };
38
39 self.login = Some(login);
40 self
41 }
42
43 fn len(&self) -> usize {
44 let mut len = 2 + "MQTT".len() + 1 + 1 + 2; len += 2 + self.client_id.len();
50
51 if let Some(last_will) = &self.last_will {
53 len += last_will.len();
54 }
55
56 if let Some(login) = &self.login {
58 len += login.len();
59 }
60
61 len
62 }
63
64 pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Connect, Error> {
65 let variable_header_index = fixed_header.fixed_header_len;
66 bytes.advance(variable_header_index);
67
68 let protocol_name = read_mqtt_string(&mut bytes)?;
70 let protocol_level = read_u8(&mut bytes)?;
71 if protocol_name != "MQTT" {
72 return Err(Error::InvalidProtocol);
73 }
74
75 let protocol = match protocol_level {
76 4 => Protocol::V4,
77 5 => Protocol::V5,
78 num => return Err(Error::InvalidProtocolLevel(num)),
79 };
80
81 let connect_flags = read_u8(&mut bytes)?;
82 let clean_session = (connect_flags & 0b10) != 0;
83 let keep_alive = read_u16(&mut bytes)?;
84
85 let client_id = read_mqtt_string(&mut bytes)?;
86 let last_will = LastWill::read(connect_flags, &mut bytes)?;
87 let login = Login::read(connect_flags, &mut bytes)?;
88
89 let connect = Connect {
90 protocol,
91 keep_alive,
92 client_id,
93 clean_session,
94 last_will,
95 login,
96 };
97
98 Ok(connect)
99 }
100
101 pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
102 let len = self.len();
103 buffer.put_u8(0b0001_0000);
104 let count = write_remaining_length(buffer, len)?;
105 write_mqtt_string(buffer, "MQTT");
106
107 match self.protocol {
108 Protocol::V4 => buffer.put_u8(0x04),
109 Protocol::V5 => buffer.put_u8(0x05),
110 }
111
112 let flags_index = 1 + count + 2 + 4 + 1;
113
114 let mut connect_flags = 0;
115 if self.clean_session {
116 connect_flags |= 0x02;
117 }
118
119 buffer.put_u8(connect_flags);
120 buffer.put_u16(self.keep_alive);
121 write_mqtt_string(buffer, &self.client_id);
122
123 if let Some(last_will) = &self.last_will {
124 connect_flags |= last_will.write(buffer)?;
125 }
126
127 if let Some(login) = &self.login {
128 connect_flags |= login.write(buffer);
129 }
130
131 buffer[flags_index] = connect_flags;
133 Ok(len)
134 }
135}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
139pub struct LastWill {
140 pub topic: String,
141 pub message: Bytes,
142 pub qos: QoS,
143 pub retain: bool,
144}
145
146impl LastWill {
147 pub fn new(
148 topic: impl Into<String>,
149 payload: impl Into<Vec<u8>>,
150 qos: QoS,
151 retain: bool,
152 ) -> LastWill {
153 LastWill {
154 topic: topic.into(),
155 message: Bytes::from(payload.into()),
156 qos,
157 retain,
158 }
159 }
160
161 fn len(&self) -> usize {
162 let mut len = 0;
163 len += 2 + self.topic.len() + 2 + self.message.len();
164 len
165 }
166
167 fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<LastWill>, Error> {
168 let last_will = match connect_flags & 0b100 {
169 0 if (connect_flags & 0b0011_1000) != 0 => {
170 return Err(Error::IncorrectPacketFormat);
171 }
172 0 => None,
173 _ => {
174 let will_topic = read_mqtt_string(bytes)?;
175 let will_message = read_mqtt_bytes(bytes)?;
176 let will_qos = qos((connect_flags & 0b11000) >> 3)?;
177 Some(LastWill {
178 topic: will_topic,
179 message: will_message,
180 qos: will_qos,
181 retain: (connect_flags & 0b0010_0000) != 0,
182 })
183 }
184 };
185
186 Ok(last_will)
187 }
188
189 fn write(&self, buffer: &mut BytesMut) -> Result<u8, Error> {
190 let mut connect_flags = 0;
191
192 connect_flags |= 0x04 | (self.qos as u8) << 3;
193 if self.retain {
194 connect_flags |= 0x20;
195 }
196
197 write_mqtt_string(buffer, &self.topic);
198 write_mqtt_bytes(buffer, &self.message);
199 Ok(connect_flags)
200 }
201}
202
203#[derive(Debug, Clone, PartialEq, Eq)]
204pub struct Login {
205 pub username: String,
206 pub password: String,
207}
208
209impl Login {
210 pub fn new<U: Into<String>, P: Into<String>>(u: U, p: P) -> Login {
211 Login {
212 username: u.into(),
213 password: p.into(),
214 }
215 }
216
217 fn read(connect_flags: u8, bytes: &mut Bytes) -> Result<Option<Login>, Error> {
218 let username = match connect_flags & 0b1000_0000 {
219 0 => String::new(),
220 _ => read_mqtt_string(bytes)?,
221 };
222
223 let password = match connect_flags & 0b0100_0000 {
224 0 => String::new(),
225 _ => read_mqtt_string(bytes)?,
226 };
227
228 if username.is_empty() && password.is_empty() {
229 Ok(None)
230 } else {
231 Ok(Some(Login { username, password }))
232 }
233 }
234
235 fn len(&self) -> usize {
236 let mut len = 0;
237
238 if !self.username.is_empty() {
239 len += 2 + self.username.len();
240 }
241
242 if !self.password.is_empty() {
243 len += 2 + self.password.len();
244 }
245
246 len
247 }
248
249 fn write(&self, buffer: &mut BytesMut) -> u8 {
250 let mut connect_flags = 0;
251 if !self.username.is_empty() {
252 connect_flags |= 0x80;
253 write_mqtt_string(buffer, &self.username);
254 }
255
256 if !self.password.is_empty() {
257 connect_flags |= 0x40;
258 write_mqtt_string(buffer, &self.password);
259 }
260
261 connect_flags
262 }
263
264 pub fn validate(&self, username: &str, password: &str) -> bool {
265 (self.username == *username) && (self.password == *password)
266 }
267}
268
#[cfg(test)]
mod test {
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    #[test]
    fn connect_parsing_works() {
        // One CONNECT frame followed by bytes that lie past its end; only
        // `frame_length()` bytes are split off and handed to the parser.
        let frame: &[&[u8]] = &[
            &[0x10, 39],               // CONNECT, remaining length 39
            &[0x00, 0x04],
            b"MQTT",                   // protocol name
            &[0x04],                   // protocol level
            &[0b1100_1110],            // username | password | will QoS 1 | will | clean session
            &[0x00, 0x0a],             // keep alive = 10
            &[0x00, 0x04],
            b"test",                   // client id
            &[0x00, 0x02],
            b"/a",                     // will topic
            &[0x00, 0x07],
            b"offline",                // will message
            &[0x00, 0x04],
            b"rumq",                   // username
            &[0x00, 0x02],
            b"mq",                     // password
            &[0xDE, 0xAD, 0xBE, 0xEF], // bytes past the end of this frame
        ];

        let mut stream = BytesMut::new();
        stream.extend_from_slice(&frame.concat());

        let fixed_header = parse_fixed_header(stream.iter()).unwrap();
        let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze();
        let parsed = Connect::read(fixed_header, connect_bytes).unwrap();

        let expected = Connect {
            protocol: Protocol::V4,
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(LastWill::new("/a", "offline", QoS::AtLeastOnce, false)),
            login: Some(Login::new("rumq", "mq")),
        };
        assert_eq!(parsed, expected);
    }

    /// Bytes `connect_encoding_works` expects the encoder to produce.
    fn sample_bytes() -> Vec<u8> {
        let frame: &[&[u8]] = &[
            &[0x10, 39],    // CONNECT, remaining length 39
            &[0x00, 0x04],
            b"MQTT",        // protocol name
            &[0x04],        // protocol level
            &[0b1100_1110], // username | password | will QoS 1 | will | clean session
            &[0x00, 0x0a],  // keep alive = 10
            &[0x00, 0x04],
            b"test",        // client id
            &[0x00, 0x02],
            b"/a",          // will topic
            &[0x00, 0x07],
            b"offline",     // will message
            &[0x00, 0x04],
            b"rust",        // username
            &[0x00, 0x02],
            b"mq",          // password
        ];
        frame.concat()
    }

    #[test]
    fn connect_encoding_works() {
        let connect = Connect {
            protocol: Protocol::V4,
            keep_alive: 10,
            client_id: "test".to_owned(),
            clean_session: true,
            last_will: Some(LastWill::new("/a", "offline", QoS::AtLeastOnce, false)),
            login: Some(Login::new("rust", "mq")),
        };

        let mut encoded = BytesMut::new();
        connect.write(&mut encoded).unwrap();

        assert_eq!(encoded, sample_bytes());
    }
}