1use anyhow::Result;
6use argon2::{Argon2, PasswordHasher, PasswordVerifier, password_hash::PasswordHash};
7use base64::Engine;
8use chrono::Utc;
9use reqwest::Client;
10use rocket::http::Status;
11use rocket::request::{FromRequest, Outcome};
12use rocket::{Request, State, get, post};
13use rumqttc::{AsyncClient, QoS};
14use serde::{Deserialize, Serialize};
15use sqlx::{PgPool, Row};
16use tokio::io::AsyncReadExt;
17use uuid::Uuid;
18
19use super::SpeechbrainUrl;
20use super::Token;
21use super::mqtt::publish_control_message;
22
23#[derive(Clone)]
25pub struct DeviceToken(pub String);
26
27#[rocket::async_trait]
30impl<'r> FromRequest<'r> for DeviceToken {
31 type Error = &'static str;
32
33 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
35 let auth_header = req.headers().get_one("Authorization");
36 match auth_header {
37 Some(auth) if auth.starts_with("Bearer ") => {
38 let token_str = &auth[7..];
39 Outcome::Success(DeviceToken(token_str.to_string()))
40 }
41 _ => Outcome::Error((
42 Status::Unauthorized,
43 "Missing or invalid Authorization header",
44 )),
45 }
46 }
47}
48
49#[derive(Deserialize)]
51pub struct ControlRequest {
52 command: String,
54 user_id: String,
56}
57
58#[derive(Deserialize)]
60pub struct LockStatusMessage {
61 pub lock: String,
63 pub reason: String,
65 #[allow(dead_code)]
67 pub uptime_ms: u64,
68 pub timestamp: u64,
70}
71
72#[derive(Serialize)]
74pub struct LogEntry {
75 id: i32,
77 device_id: String,
79 timestamp: chrono::DateTime<chrono::Utc>,
81 event_type: String,
83 reason: String,
85 user_id: Option<String>,
87 user_name: Option<String>,
89}
90
91#[derive(Deserialize)]
93pub struct UpdateConfigRequest {
94 configs: Vec<ConfigItem>,
96}
97
98#[derive(Deserialize, Debug)]
100pub struct ConfigItem {
101 key: String,
103 value: String,
105}
106
107#[derive(Deserialize)]
109pub struct RegisterDeviceRequest {
110 device_id: String,
112 user_key: String,
114 user_id: String,
116}
117
118#[derive(sqlx::FromRow)]
120struct DeviceVoiceRow {
121 user_id: Option<String>,
123 voice_invite_enable: Option<bool>,
125 voice_threshold: Option<f64>,
127 hashed_passphrase: Option<String>,
129}
130
131#[post("/update_config/<uuid>", data = "<request>")]
134pub async fn update_config(
135 token: Token,
136 uuid: &str,
137 request: rocket::serde::json::Json<UpdateConfigRequest>,
138 db_pool: &State<PgPool>,
139 mqtt_client: &State<AsyncClient>,
140) -> Result<(), Status> {
141 let uuid_parsed = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
142
143 let user_row: Option<(String,)> =
145 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
146 .bind(&token.0)
147 .fetch_optional(&**db_pool)
148 .await
149 .map_err(|_| Status::InternalServerError)?;
150 let firebase_uid = match user_row {
151 Some((uid,)) => uid,
152 None => return Err(Status::Unauthorized),
153 };
154
155 let row: Option<(Option<String>,)> =
157 sqlx::query_as("SELECT user_id FROM devices WHERE uuid = $1")
158 .bind(uuid_parsed)
159 .fetch_optional(&**db_pool)
160 .await
161 .map_err(|_| Status::InternalServerError)?;
162 if let Some((Some(owner_id),)) = row {
163 if firebase_uid != owner_id {
164 return Err(Status::Unauthorized);
165 }
166 } else {
167 return Err(Status::NotFound);
168 }
169
170 for config in &request.configs {
172 match config.key.as_str() {
173 "wifi_ssid" => {
174 if config.value.is_empty() {
175 return Err(Status::BadRequest);
176 }
177 }
178 "wifi_pass" => {} "audio_timeout" => {
180 let val: i32 = config.value.parse().map_err(|_| Status::BadRequest)?;
181 if !(3..=10).contains(&val) {
182 return Err(Status::BadRequest);
183 }
184 }
185 "lock_timeout" => {
186 let val: i32 = config.value.parse().map_err(|_| Status::BadRequest)?;
187 if !(5000..=300000).contains(&val) {
188 return Err(Status::BadRequest);
190 }
191 }
192 "pairing_timeout" => {
193 let val: i32 = config.value.parse().map_err(|_| Status::BadRequest)?;
194 if !(60..=600).contains(&val) {
195 return Err(Status::BadRequest);
196 }
197 }
198 "voice_detection_enable" => {
199 let val: i32 = config.value.parse().map_err(|_| Status::BadRequest)?;
200 if !(0..=1).contains(&val) {
201 return Err(Status::BadRequest);
202 }
203 }
204 "voice_invite_enable" => {
205 let val: i32 = config.value.parse().map_err(|_| Status::BadRequest)?;
206 if !(0..=1).contains(&val) {
207 return Err(Status::BadRequest);
208 }
209 }
210 "voice_threshold" => {
211 let val: f64 = config.value.parse().map_err(|_| Status::BadRequest)?;
212 if !(0.20..=0.90).contains(&val) {
213 return Err(Status::BadRequest);
214 }
215 }
216 "vad_rms_threshold" => {
217 let val: i32 = config.value.parse().map_err(|_| Status::BadRequest)?;
218 if !(500..=25000).contains(&val) {
219 return Err(Status::BadRequest);
220 }
221 }
222 _ => {
223 return Err(Status::BadRequest);
224 }
225 }
226 }
227
228 let mut backend_configs = Vec::new();
230 let mut device_configs = Vec::new();
231
232 for config in &request.configs {
233 match config.key.as_str() {
234 "voice_threshold" | "voice_invite_enable" => {
235 backend_configs.push(config);
236 }
237 _ => {
238 device_configs.push(config);
239 }
240 }
241 }
242
243 for config in backend_configs {
245 match config.key.as_str() {
246 "voice_threshold" => {
247 let threshold: f64 = config.value.parse().map_err(|_| Status::BadRequest)?;
248 sqlx::query("UPDATE devices SET voice_threshold = $1 WHERE uuid = $2")
249 .bind(threshold)
250 .bind(uuid_parsed)
251 .execute(&**db_pool)
252 .await
253 .map_err(|_| Status::InternalServerError)?;
254 }
255 "voice_invite_enable" => {
256 let enable: i32 = config.value.parse().map_err(|_| Status::BadRequest)?;
257 let enable_bool = enable == 1;
258 sqlx::query("UPDATE devices SET voice_invite_enable = $1 WHERE uuid = $2")
259 .bind(enable_bool)
260 .bind(uuid_parsed)
261 .execute(&**db_pool)
262 .await
263 .map_err(|_| Status::InternalServerError)?;
264 }
265 "vad_rms_threshold" => {
266 let threshold: i32 = config.value.parse().map_err(|_| Status::BadRequest)?;
267 sqlx::query("UPDATE devices SET vad_rms_threshold = $1 WHERE uuid = $2")
268 .bind(threshold)
269 .bind(uuid_parsed)
270 .execute(&**db_pool)
271 .await
272 .map_err(|_| Status::InternalServerError)?;
273 }
274 _ => {} }
276 }
277
278 for config in device_configs {
280 let (tx, rx) = tokio::sync::oneshot::channel::<()>();
281 {
282 let updates_mutex = super::PENDING_CONFIG_UPDATES.get().unwrap();
283 let mut updates = updates_mutex.lock().unwrap();
284 updates.insert(uuid.to_string(), tx);
285 }
286
287 let topic = format!("lockwise/{}/control", uuid);
289 let msg = serde_cbor::to_vec(&serde_json::json!({
290 "command": "update_config",
291 "key": config.key,
292 "value": config.value
293 }))
294 .map_err(|_| Status::InternalServerError)?;
295 mqtt_client
296 .publish(topic, QoS::AtMostOnce, false, msg)
297 .await
298 .map_err(|_| Status::InternalServerError)?;
299
300 match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
302 Ok(Ok(())) => {}
303 Ok(Err(_)) => {
304 return Err(Status::InternalServerError);
305 }
306 Err(_) => {
307 return Err(Status::RequestTimeout);
308 }
309 }
310 }
311
312 Ok(())
313}
314
315#[post("/reboot/<uuid>")]
317pub async fn reboot_device(
318 token: Token,
319 uuid: &str,
320 db_pool: &State<PgPool>,
321 mqtt_client: &State<AsyncClient>,
322) -> Result<(), Status> {
323 let uuid_parsed = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
324
325 let user_row: Option<(String,)> =
327 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
328 .bind(&token.0)
329 .fetch_optional(&**db_pool)
330 .await
331 .map_err(|_| Status::InternalServerError)?;
332 let firebase_uid = match user_row {
333 Some((uid,)) => uid,
334 None => return Err(Status::Unauthorized),
335 };
336
337 let row: Option<(Option<String>,)> =
339 sqlx::query_as("SELECT user_id FROM devices WHERE uuid = $1")
340 .bind(uuid_parsed)
341 .fetch_optional(&**db_pool)
342 .await
343 .map_err(|_| Status::InternalServerError)?;
344 if let Some((Some(owner_id),)) = row {
345 if firebase_uid != owner_id {
346 return Err(Status::Unauthorized);
347 }
348 } else {
349 return Err(Status::NotFound);
350 }
351
352 publish_control_message(mqtt_client, uuid_parsed, "REBOOT".to_string())
354 .await
355 .map_err(|_| Status::InternalServerError)?;
356
357 Ok(())
358}
359
360#[post("/lockdown/<uuid>")]
362pub async fn lockdown_device(
363 token: Token,
364 uuid: &str,
365 db_pool: &State<PgPool>,
366 mqtt_client: &State<AsyncClient>,
367) -> Result<(), Status> {
368 let uuid_parsed = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
369
370 let user_row: Option<(String,)> =
372 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
373 .bind(&token.0)
374 .fetch_optional(&**db_pool)
375 .await
376 .map_err(|_| Status::InternalServerError)?;
377 let firebase_uid = match user_row {
378 Some((uid,)) => uid,
379 None => return Err(Status::Unauthorized),
380 };
381
382 let row: Option<(Option<String>,)> =
384 sqlx::query_as("SELECT user_id FROM devices WHERE uuid = $1")
385 .bind(uuid_parsed)
386 .fetch_optional(&**db_pool)
387 .await
388 .map_err(|_| Status::InternalServerError)?;
389 if let Some((Some(owner_id),)) = row {
390 if firebase_uid != owner_id {
391 return Err(Status::Unauthorized);
392 }
393 } else {
394 return Err(Status::NotFound);
395 }
396
397 publish_control_message(mqtt_client, uuid_parsed, "LOCKDOWN".to_string())
399 .await
400 .map_err(|_| Status::InternalServerError)?;
401
402 Ok(())
403}
404
405#[post("/ping/<uuid>")]
407pub async fn ping_device(
408 token: Token,
409 uuid: &str,
410 db_pool: &State<PgPool>,
411 mqtt_client: &State<AsyncClient>,
412) -> Result<(), Status> {
413 let uuid_parsed = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
414
415 let user_row: Option<(String,)> =
417 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
418 .bind(&token.0)
419 .fetch_optional(&**db_pool)
420 .await
421 .map_err(|_| Status::InternalServerError)?;
422 let firebase_uid = match user_row {
423 Some((uid,)) => uid,
424 None => return Err(Status::Unauthorized),
425 };
426
427 let row: Option<(Option<String>,)> =
429 sqlx::query_as("SELECT user_id FROM devices WHERE uuid = $1")
430 .bind(uuid_parsed)
431 .fetch_optional(&**db_pool)
432 .await
433 .map_err(|_| Status::InternalServerError)?;
434 if let Some((db_user_id_opt,)) = row {
435 let has_access = if let Some(db_user_id) = db_user_id_opt {
436 if firebase_uid == db_user_id {
438 true
439 } else {
440 let now = Utc::now().timestamp_millis();
442 let invite_row: Option<(i32,)> = sqlx::query_as(
443 "SELECT id FROM invites WHERE device_id = $1 AND receiver_id = $2 AND status = 1 AND expiry_timestamp > $3"
444 )
445 .bind(uuid_parsed)
446 .bind(&firebase_uid)
447 .bind(now)
448 .fetch_optional(&**db_pool)
449 .await
450 .map_err(|_| Status::InternalServerError)?;
451 invite_row.is_some()
452 }
453 } else {
454 false };
456
457 if !has_access {
458 return Err(Status::Unauthorized);
459 }
460 } else {
461 return Err(Status::NotFound);
462 }
463
464 publish_control_message(mqtt_client, uuid_parsed, "PING".to_string())
466 .await
467 .map_err(|_| Status::InternalServerError)?;
468
469 let (tx, rx) = tokio::sync::oneshot::channel::<()>();
471 let start = chrono::Utc::now().timestamp_millis();
472 {
473 let pings_mutex = super::PENDING_PINGS.get().unwrap();
474 let mut pings = pings_mutex.lock().unwrap();
475 pings.insert(uuid.to_string(), (start, tx));
476 }
477
478 tokio::time::timeout(std::time::Duration::from_secs(10), rx)
480 .await
481 .map_err(|_| Status::RequestTimeout)?
482 .map_err(|_| Status::InternalServerError)?;
483
484 Ok(())
485}
486
487#[get("/devices")]
489pub async fn get_devices(token: Token, db_pool: &State<PgPool>) -> Result<String, Status> {
490 let user_row: Option<(String,)> =
492 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
493 .bind(&token.0)
494 .fetch_optional(&**db_pool)
495 .await
496 .map_err(|_| Status::InternalServerError)?;
497 let firebase_uid = match user_row {
498 Some((uid,)) => uid,
499 None => return Err(Status::Unauthorized),
500 };
501
502 let rows = sqlx::query(
503 "SELECT uuid, user_id, last_heard, uptime_ms, wifi_ssid, backend_url, mqtt_broker_url, mqtt_heartbeat_enable, mqtt_heartbeat_interval_sec, audio_record_timeout_sec, lock_timeout_ms, pairing_timeout_sec, lock_state, locked_down_at, voice_detection_enable, voice_invite_enable, voice_threshold, vad_rms_threshold FROM devices WHERE user_id = $1",
504 )
505 .bind(&firebase_uid)
506 .fetch_all(&**db_pool)
507 .await
508 .map_err(|_| Status::InternalServerError)?;
509
510 let rows_vec: Vec<_> = rows;
511 let devices: Vec<serde_json::Value> = rows_vec
512 .into_iter()
513 .map(|row| {
514 let db_uuid: Uuid = row.get(0);
515 let last_heard: chrono::DateTime<chrono::Utc> = row.get(2);
516 let uptime_ms: Option<i64> = row.get(3);
517 let wifi_ssid: Option<String> = row.get(4);
518 let backend_url: Option<String> = row.get(5);
519 let mqtt_broker_url: Option<String> = row.get(6);
520 let mqtt_heartbeat_enable: Option<bool> = row.get(7);
521 let mqtt_heartbeat_interval_sec: Option<i32> = row.get(8);
522 let audio_record_timeout_sec: Option<i32> = row.get(9);
523 let lock_timeout_ms: Option<i32> = row.get(10);
524 let pairing_timeout_sec: Option<i32> = row.get(11);
525 let lock_state: Option<String> = row.get(12);
526 let locked_down_at: Option<chrono::DateTime<chrono::Utc>> = row.get(13);
527 let voice_detection_enable: Option<bool> = row.get(14);
528 let voice_invite_enable: Option<bool> = row.get(15);
529 let voice_threshold: Option<f64> = row.get(16);
530 let vad_rms_threshold: Option<i32> = row.get(17);
531 serde_json::json!({
532 "uuid": db_uuid.to_string(),
533 "user_id": firebase_uid,
534 "last_heard": last_heard.timestamp_millis(),
535 "uptime_ms": uptime_ms,
536 "wifi_ssid": wifi_ssid,
537 "backend_url": backend_url,
538 "mqtt_broker_url": mqtt_broker_url,
539 "mqtt_heartbeat_enable": mqtt_heartbeat_enable,
540 "mqtt_heartbeat_interval_sec": mqtt_heartbeat_interval_sec,
541 "audio_record_timeout_sec": audio_record_timeout_sec,
542 "lock_timeout_ms": lock_timeout_ms,
543 "pairing_timeout_sec": pairing_timeout_sec,
544 "lock_state": lock_state,
545 "locked_down_at": locked_down_at.map(|dt| dt.timestamp_millis()),
546 "voice_detection_enable": voice_detection_enable,
547 "voice_invite_enable": voice_invite_enable,
548 "voice_threshold": voice_threshold,
549 "vad_rms_threshold": vad_rms_threshold
550 })
551 })
552 .collect();
553
554 Ok(serde_json::to_string(&devices).unwrap())
555}
556
557#[post("/register_device", data = "<request>")]
559pub async fn register_device(
560 request: rocket::serde::json::Json<RegisterDeviceRequest>,
561 db_pool: &State<PgPool>,
562) -> Result<(), Status> {
563 let device_uuid = Uuid::parse_str(&request.device_id).map_err(|_| Status::BadRequest)?;
564
565 let salt =
567 argon2::password_hash::SaltString::generate(&mut argon2::password_hash::rand_core::OsRng);
568 let argon2 = Argon2::default();
569 let hashed_key = argon2
570 .hash_password(request.user_key.as_bytes(), &salt)
571 .map_err(|_| Status::InternalServerError)?
572 .to_string();
573
574 sqlx::query(
576 "INSERT INTO devices (uuid, user_id, hashed_passphrase, last_heard, uptime_ms) VALUES ($1, $2, $3, NOW(), 0)
577 ON CONFLICT (uuid) DO UPDATE SET user_id = $2, hashed_passphrase = $3, last_heard = NOW()"
578 )
579 .bind(device_uuid)
580 .bind(&request.user_id)
581 .bind(&hashed_key)
582 .execute(&**db_pool)
583 .await
584 .map_err(|_| Status::InternalServerError)?;
585
586 sqlx::query("DELETE FROM logs WHERE device_id = $1")
588 .bind(device_uuid.to_string())
589 .execute(&**db_pool)
590 .await
591 .map_err(|_| Status::InternalServerError)?;
592
593 sqlx::query("DELETE FROM invites WHERE device_id = $1")
595 .bind(device_uuid)
596 .execute(&**db_pool)
597 .await
598 .map_err(|_| Status::InternalServerError)?;
599
600 Ok(())
601}
602
603#[post("/control/<uuid>", data = "<request>")]
605pub async fn control_device(
606 token: Token,
607 uuid: &str,
608 request: rocket::serde::json::Json<ControlRequest>,
609 db_pool: &State<PgPool>,
610 mqtt_client: &State<AsyncClient>,
611) -> Result<(), Status> {
612 let uuid = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
613
614 let user_row: Option<(String,)> =
616 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
617 .bind(&token.0)
618 .fetch_optional(&**db_pool)
619 .await
620 .map_err(|_| Status::InternalServerError)?;
621 let firebase_uid = match user_row {
622 Some((uid,)) => uid,
623 None => return Err(Status::Unauthorized),
624 };
625
626 if firebase_uid != request.user_id {
628 return Err(Status::Unauthorized);
629 }
630
631 let device_row: Option<(Option<String>,)> =
633 sqlx::query_as("SELECT user_id FROM devices WHERE uuid = $1")
634 .bind(uuid)
635 .fetch_optional(&**db_pool)
636 .await
637 .map_err(|_| Status::InternalServerError)?;
638
639 let has_access = if let Some((Some(owner_id),)) = device_row {
640 if request.user_id == owner_id {
642 true
643 } else {
644 let now = Utc::now().timestamp_millis();
646 let invite_row: Option<(i32,)> = sqlx::query_as(
647 "SELECT id FROM invites WHERE device_id = $1 AND receiver_id = $2 AND status = 1 AND expiry_timestamp > $3"
648 )
649 .bind(uuid)
650 .bind(&request.user_id)
651 .bind(now)
652 .fetch_optional(&**db_pool)
653 .await
654 .map_err(|_| Status::InternalServerError)?;
655 invite_row.is_some()
656 }
657 } else {
658 false };
660
661 if !has_access {
662 return Err(Status::Unauthorized);
663 }
664
665 let now = chrono::Utc::now().timestamp();
667 {
668 let commands_mutex = super::RECENT_COMMANDS.get().unwrap();
669 let mut commands = commands_mutex.lock().unwrap();
670 commands.insert(uuid.to_string(), (firebase_uid.clone(), now));
671 }
672
673 publish_control_message(mqtt_client, uuid, request.command.clone())
674 .await
675 .map_err(|_| Status::InternalServerError)?;
676 Ok(())
677}
678
679#[post("/unpair/<uuid>")]
681pub async fn unpair_device(
682 token: Token,
683 uuid: &str,
684 db_pool: &State<PgPool>,
685) -> Result<(), Status> {
686 let uuid = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
687
688 let user_row: Option<(String,)> =
690 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
691 .bind(&token.0)
692 .fetch_optional(&**db_pool)
693 .await
694 .map_err(|_| Status::InternalServerError)?;
695 let firebase_uid = match user_row {
696 Some((uid,)) => uid,
697 None => return Err(Status::Unauthorized),
698 };
699
700 let row: Option<(Option<String>,)> =
702 sqlx::query_as("SELECT user_id FROM devices WHERE uuid = $1")
703 .bind(uuid)
704 .fetch_optional(&**db_pool)
705 .await
706 .map_err(|_| Status::InternalServerError)?;
707 if let Some((Some(db_user_id),)) = row {
708 if firebase_uid != db_user_id {
709 return Err(Status::Unauthorized);
710 }
711 } else {
712 return Err(Status::Unauthorized); }
714
715 sqlx::query("UPDATE devices SET user_id = NULL WHERE uuid = $1")
717 .bind(uuid)
718 .execute(&**db_pool)
719 .await
720 .map_err(|_| Status::InternalServerError)?;
721
722 sqlx::query("DELETE FROM logs WHERE device_id = $1")
724 .bind(uuid.to_string())
725 .execute(&**db_pool)
726 .await
727 .map_err(|_| Status::InternalServerError)?;
728
729 sqlx::query("DELETE FROM invites WHERE device_id = $1")
731 .bind(uuid)
732 .execute(&**db_pool)
733 .await
734 .map_err(|_| Status::InternalServerError)?;
735
736 Ok(())
737}
738
739#[get("/device/<uuid>")]
741pub async fn get_device(
742 token: Token,
743 uuid: &str,
744 db_pool: &State<PgPool>,
745) -> Result<String, Status> {
746 let uuid = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
747
748 let user_row: Option<(String,)> =
750 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
751 .bind(&token.0)
752 .fetch_optional(&**db_pool)
753 .await
754 .map_err(|_| Status::InternalServerError)?;
755 let firebase_uid = match user_row {
756 Some((uid,)) => uid,
757 None => return Err(Status::Unauthorized),
758 };
759
760 let row = sqlx::query("SELECT uuid, user_id, last_heard, uptime_ms, wifi_ssid, backend_url, mqtt_broker_url, mqtt_heartbeat_enable, mqtt_heartbeat_interval_sec, audio_record_timeout_sec, lock_timeout_ms, pairing_timeout_sec, lock_state, locked_down_at, voice_detection_enable, voice_invite_enable, voice_threshold, vad_rms_threshold FROM devices WHERE uuid = $1")
762 .bind(uuid)
763 .fetch_optional(&**db_pool)
764 .await
765 .map_err(|_| Status::InternalServerError)?;
766 if let Some(row) = row {
767 let db_uuid: Uuid = row.get(0);
768 let db_user_id_opt: Option<String> = row.get(1);
769 let last_heard: chrono::DateTime<chrono::Utc> = row.get(2);
770 let uptime_ms: Option<i64> = row.get(3);
771 let wifi_ssid: Option<String> = row.get(4);
772 let backend_url: Option<String> = row.get(5);
773 let mqtt_broker_url: Option<String> = row.get(6);
774 let mqtt_heartbeat_enable: Option<bool> = row.get(7);
775 let mqtt_heartbeat_interval_sec: Option<i32> = row.get(8);
776 let audio_record_timeout_sec: Option<i32> = row.get(9);
777 let lock_timeout_ms: Option<i32> = row.get(10);
778 let pairing_timeout_sec: Option<i32> = row.get(11);
779 let lock_state: Option<String> = row.get(12);
780 let locked_down_at: Option<chrono::DateTime<chrono::Utc>> = row.get(13);
781 let voice_detection_enable: Option<bool> = row.get(14);
782 let voice_invite_enable: Option<bool> = row.get(15);
783 let voice_threshold: Option<f64> = row.get(16);
784 let vad_rms_threshold: Option<i32> = row.get(17);
785
786 let has_access = if let Some(db_user_id) = db_user_id_opt {
787 if firebase_uid == db_user_id {
789 true
790 } else {
791 let now = Utc::now().timestamp_millis();
793 let invite_row: Option<(i32,)> = sqlx::query_as(
794 "SELECT id FROM invites WHERE device_id = $1 AND receiver_id = $2 AND status = 1 AND expiry_timestamp > $3"
795 )
796 .bind(uuid)
797 .bind(&firebase_uid)
798 .bind(now)
799 .fetch_optional(&**db_pool)
800 .await
801 .map_err(|_| Status::InternalServerError)?;
802 invite_row.is_some()
803 }
804 } else {
805 false };
807
808 if !has_access {
809 return Err(Status::Unauthorized);
810 }
811 let device = serde_json::json!({
812 "uuid": db_uuid.to_string(),
813 "user_id": firebase_uid,
814 "last_heard": last_heard.timestamp_millis(),
815 "uptime_ms": uptime_ms,
816 "wifi_ssid": wifi_ssid,
817 "backend_url": backend_url,
818 "mqtt_broker_url": mqtt_broker_url,
819 "mqtt_heartbeat_enable": mqtt_heartbeat_enable,
820 "mqtt_heartbeat_interval_sec": mqtt_heartbeat_interval_sec,
821 "audio_record_timeout_sec": audio_record_timeout_sec,
822 "lock_timeout_ms": lock_timeout_ms,
823 "pairing_timeout_sec": pairing_timeout_sec,
824 "lock_state": lock_state,
825 "locked_down_at": locked_down_at.map(|dt| dt.timestamp_millis()),
826 "voice_detection_enable": voice_detection_enable,
827 "voice_invite_enable": voice_invite_enable,
828 "voice_threshold": voice_threshold,
829 "vad_rms_threshold": vad_rms_threshold
830 });
831 Ok(device.to_string())
832 } else {
833 Err(Status::NotFound)
834 }
835}
836
837#[get("/temp_device/<uuid>")]
839pub async fn get_temp_device(
840 token: Token,
841 uuid: &str,
842 db_pool: &State<PgPool>,
843) -> Result<String, Status> {
844 let uuid_parsed = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
845
846 let user_row: Option<(String,)> =
848 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
849 .bind(&token.0)
850 .fetch_optional(&**db_pool)
851 .await
852 .map_err(|_| Status::InternalServerError)?;
853 let firebase_uid = match user_row {
854 Some((uid,)) => uid,
855 None => return Err(Status::Unauthorized),
856 };
857
858 let now = Utc::now().timestamp_millis();
860 let invite_row: Option<(i32,)> = sqlx::query_as(
861 "SELECT id FROM invites WHERE device_id = $1 AND receiver_id = $2 AND status = 1 AND expiry_timestamp > $3"
862 )
863 .bind(uuid_parsed)
864 .bind(&firebase_uid)
865 .bind(now)
866 .fetch_optional(&**db_pool)
867 .await
868 .map_err(|_| Status::InternalServerError)?;
869 if invite_row.is_none() {
870 return Err(Status::Unauthorized);
871 }
872
873 let row = sqlx::query("SELECT uuid, user_id, last_heard, uptime_ms, wifi_ssid, backend_url, mqtt_broker_url, mqtt_heartbeat_enable, mqtt_heartbeat_interval_sec, audio_record_timeout_sec, lock_timeout_ms, pairing_timeout_sec, lock_state, locked_down_at, voice_detection_enable, voice_invite_enable, voice_threshold, vad_rms_threshold FROM devices WHERE uuid = $1")
875 .bind(uuid_parsed)
876 .fetch_optional(&**db_pool)
877 .await
878 .map_err(|_| Status::InternalServerError)?;
879 if let Some(row) = row {
880 let db_uuid: Uuid = row.get(0);
881 let last_heard: chrono::DateTime<chrono::Utc> = row.get(2);
882 let uptime_ms: Option<i64> = row.get(3);
883 let wifi_ssid: Option<String> = row.get(4);
884 let backend_url: Option<String> = row.get(5);
885 let mqtt_broker_url: Option<String> = row.get(6);
886 let mqtt_heartbeat_enable: Option<bool> = row.get(7);
887 let mqtt_heartbeat_interval_sec: Option<i32> = row.get(8);
888 let audio_record_timeout_sec: Option<i32> = row.get(9);
889 let lock_timeout_ms: Option<i32> = row.get(10);
890 let pairing_timeout_sec: Option<i32> = row.get(11);
891 let lock_state: Option<String> = row.get(12);
892 let locked_down_at: Option<chrono::DateTime<chrono::Utc>> = row.get(13);
893 let voice_detection_enable: Option<bool> = row.get(14);
894 let voice_invite_enable: Option<bool> = row.get(15);
895 let voice_threshold: Option<f64> = row.get(16);
896 let vad_rms_threshold: Option<i32> = row.get(17);
897
898 let device = serde_json::json!({
899 "uuid": db_uuid.to_string(),
900 "user_id": firebase_uid,
901 "last_heard": last_heard.timestamp_millis(),
902 "uptime_ms": uptime_ms,
903 "wifi_ssid": wifi_ssid,
904 "backend_url": backend_url,
905 "mqtt_broker_url": mqtt_broker_url,
906 "mqtt_heartbeat_enable": mqtt_heartbeat_enable,
907 "mqtt_heartbeat_interval_sec": mqtt_heartbeat_interval_sec,
908 "audio_record_timeout_sec": audio_record_timeout_sec,
909 "lock_timeout_ms": lock_timeout_ms,
910 "pairing_timeout_sec": pairing_timeout_sec,
911 "lock_state": lock_state,
912 "locked_down_at": locked_down_at.map(|dt| dt.timestamp_millis()),
913 "voice_detection_enable": voice_detection_enable,
914 "voice_invite_enable": voice_invite_enable,
915 "voice_threshold": voice_threshold,
916 "vad_rms_threshold": vad_rms_threshold
917 });
918 Ok(device.to_string())
919 } else {
920 Err(Status::NotFound)
921 }
922}
923
924#[post("/temp_control/<uuid>", data = "<request>")]
926pub async fn control_temp_device(
927 token: Token,
928 uuid: &str,
929 request: rocket::serde::json::Json<ControlRequest>,
930 db_pool: &State<PgPool>,
931 mqtt_client: &State<AsyncClient>,
932) -> Result<(), Status> {
933 let uuid_parsed = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
934
935 let user_row: Option<(String,)> =
937 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
938 .bind(&token.0)
939 .fetch_optional(&**db_pool)
940 .await
941 .map_err(|_| Status::InternalServerError)?;
942 let firebase_uid = match user_row {
943 Some((uid,)) => uid,
944 None => return Err(Status::Unauthorized),
945 };
946
947 if firebase_uid != request.user_id {
949 return Err(Status::Unauthorized);
950 }
951
952 let now = Utc::now().timestamp_millis();
954 let invite_row: Option<(i32,)> = sqlx::query_as(
955 "SELECT id FROM invites WHERE device_id = $1 AND receiver_id = $2 AND status = 1 AND expiry_timestamp > $3"
956 )
957 .bind(uuid_parsed)
958 .bind(&request.user_id)
959 .bind(now)
960 .fetch_optional(&**db_pool)
961 .await
962 .map_err(|_| Status::InternalServerError)?;
963 if invite_row.is_none() {
964 return Err(Status::Unauthorized);
965 }
966
967 let now_ts = chrono::Utc::now().timestamp();
969 {
970 let commands_mutex = super::RECENT_COMMANDS.get().unwrap();
971 let mut commands = commands_mutex.lock().unwrap();
972 commands.insert(uuid.to_string(), (firebase_uid.clone(), now_ts));
973 }
974
975 publish_control_message(mqtt_client, uuid_parsed, request.command.clone())
976 .await
977 .map_err(|_| Status::InternalServerError)?;
978 Ok(())
979}
980
981#[post("/temp_ping/<uuid>")]
983pub async fn ping_temp_device(
984 token: Token,
985 uuid: &str,
986 db_pool: &State<PgPool>,
987 mqtt_client: &State<AsyncClient>,
988) -> Result<(), Status> {
989 let uuid_parsed = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
990
991 let user_row: Option<(String,)> =
993 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
994 .bind(&token.0)
995 .fetch_optional(&**db_pool)
996 .await
997 .map_err(|_| Status::InternalServerError)?;
998 let firebase_uid = match user_row {
999 Some((uid,)) => uid,
1000 None => return Err(Status::Unauthorized),
1001 };
1002
1003 let now = Utc::now().timestamp_millis();
1005 let invite_row: Option<(i32,)> = sqlx::query_as(
1006 "SELECT id FROM invites WHERE device_id = $1 AND receiver_id = $2 AND status = 1 AND expiry_timestamp > $3"
1007 )
1008 .bind(uuid_parsed)
1009 .bind(&firebase_uid)
1010 .bind(now)
1011 .fetch_optional(&**db_pool)
1012 .await
1013 .map_err(|_| Status::InternalServerError)?;
1014 if invite_row.is_none() {
1015 return Err(Status::Unauthorized);
1016 }
1017
1018 publish_control_message(mqtt_client, uuid_parsed, "PING".to_string())
1020 .await
1021 .map_err(|_| Status::InternalServerError)?;
1022
1023 let (tx, rx) = tokio::sync::oneshot::channel::<()>();
1025 let start = chrono::Utc::now().timestamp_millis();
1026 {
1027 let pings_mutex = super::PENDING_PINGS.get().unwrap();
1028 let mut pings = pings_mutex.lock().unwrap();
1029 pings.insert(uuid.to_string(), (start, tx));
1030 }
1031
1032 tokio::time::timeout(std::time::Duration::from_secs(10), rx)
1034 .await
1035 .map_err(|_| Status::RequestTimeout)?
1036 .map_err(|_| Status::InternalServerError)?;
1037 Ok(())
1038}
1039
1040#[get("/temp_devices_status")]
1042pub async fn get_temp_devices_status(
1043 token: Token,
1044 db_pool: &State<PgPool>,
1045) -> Result<String, Status> {
1046 let user_row: Option<(String,)> =
1048 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
1049 .bind(&token.0)
1050 .fetch_optional(&**db_pool)
1051 .await
1052 .map_err(|_| Status::InternalServerError)?;
1053 let firebase_uid = match user_row {
1054 Some((uid,)) => uid,
1055 None => return Err(Status::Unauthorized),
1056 };
1057
1058 let rows = sqlx::query(
1059 "SELECT d.uuid, d.user_id, d.last_heard, d.uptime_ms, d.wifi_ssid, d.backend_url, d.mqtt_broker_url, d.mqtt_heartbeat_enable, d.mqtt_heartbeat_interval_sec, d.audio_record_timeout_sec, d.lock_timeout_ms, d.pairing_timeout_sec, d.lock_state, d.locked_down_at, d.voice_detection_enable, d.voice_invite_enable, d.voice_threshold, d.vad_rms_threshold FROM devices d JOIN invites i ON d.uuid = i.device_id WHERE i.receiver_id = $1 AND i.status = 1 AND i.expiry_timestamp > $2"
1060 )
1061 .bind(&firebase_uid)
1062 .bind(Utc::now().timestamp_millis())
1063 .fetch_all(&**db_pool)
1064 .await
1065 .map_err(|_| Status::InternalServerError)?;
1066
1067 let devices: Vec<serde_json::Value> = rows
1068 .into_iter()
1069 .map(|row| {
1070 let db_uuid: Uuid = row.get(0);
1071 let last_heard: chrono::DateTime<chrono::Utc> = row.get(2);
1072 let uptime_ms: Option<i64> = row.get(3);
1073 let wifi_ssid: Option<String> = row.get(4);
1074 let backend_url: Option<String> = row.get(5);
1075 let mqtt_broker_url: Option<String> = row.get(6);
1076 let mqtt_heartbeat_enable: Option<bool> = row.get(7);
1077 let mqtt_heartbeat_interval_sec: Option<i32> = row.get(8);
1078 let audio_record_timeout_sec: Option<i32> = row.get(9);
1079 let lock_timeout_ms: Option<i32> = row.get(10);
1080 let pairing_timeout_sec: Option<i32> = row.get(11);
1081 let lock_state: Option<String> = row.get(12);
1082 let locked_down_at: Option<chrono::DateTime<chrono::Utc>> = row.get(13);
1083 let voice_detection_enable: Option<bool> = row.get(14);
1084 let voice_invite_enable: Option<bool> = row.get(15);
1085 let voice_threshold: Option<f64> = row.get(16);
1086 let vad_rms_threshold: Option<i32> = row.get(17);
1087 serde_json::json!({
1088 "uuid": db_uuid.to_string(),
1089 "user_id": firebase_uid,
1090 "last_heard": last_heard.timestamp_millis(),
1091 "uptime_ms": uptime_ms,
1092 "wifi_ssid": wifi_ssid,
1093 "backend_url": backend_url,
1094 "mqtt_broker_url": mqtt_broker_url,
1095 "mqtt_heartbeat_enable": mqtt_heartbeat_enable,
1096 "mqtt_heartbeat_interval_sec": mqtt_heartbeat_interval_sec,
1097 "audio_record_timeout_sec": audio_record_timeout_sec,
1098 "lock_timeout_ms": lock_timeout_ms,
1099 "pairing_timeout_sec": pairing_timeout_sec,
1100 "lock_state": lock_state,
1101 "locked_down_at": locked_down_at.map(|dt| dt.timestamp_millis()),
1102 "voice_detection_enable": voice_detection_enable,
1103 "voice_invite_enable": voice_invite_enable,
1104 "voice_threshold": voice_threshold,
1105 "vad_rms_threshold": vad_rms_threshold
1106 })
1107 })
1108 .collect();
1109
1110 Ok(serde_json::to_string(&devices).unwrap())
1111}
1112
1113#[get("/logs/<uuid>")]
1115pub async fn get_logs(token: Token, uuid: &str, db_pool: &State<PgPool>) -> Result<String, Status> {
1116 let uuid_parsed = Uuid::parse_str(uuid).map_err(|_| Status::BadRequest)?;
1117
1118 let user_row: Option<(String,)> =
1120 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
1121 .bind(&token.0)
1122 .fetch_optional(&**db_pool)
1123 .await
1124 .map_err(|_| Status::InternalServerError)?;
1125 let firebase_uid = match user_row {
1126 Some((uid,)) => uid,
1127 None => {
1128 return Err(Status::Unauthorized);
1129 }
1130 };
1131
1132 let row: Option<(Option<String>,)> =
1134 sqlx::query_as("SELECT user_id FROM devices WHERE uuid = $1")
1135 .bind(uuid_parsed)
1136 .fetch_optional(&**db_pool)
1137 .await
1138 .map_err(|_| Status::InternalServerError)?;
1139 if let Some((Some(db_user_id),)) = row {
1140 if firebase_uid != db_user_id {
1141 return Err(Status::Unauthorized);
1142 }
1143 } else {
1144 return Err(Status::Unauthorized); }
1146
1147 let rows = sqlx::query("SELECT l.id, l.device_id, l.timestamp, l.event_type, l.reason, l.user_id, u.name as user_name FROM logs l LEFT JOIN users u ON l.user_id = u.firebase_uid WHERE l.device_id = $1 ORDER BY l.timestamp DESC LIMIT 1000")
1149 .bind(uuid_parsed.to_string())
1150 .fetch_all(&**db_pool)
1151 .await
1152 .map_err(|_| {
1153 Status::InternalServerError
1154 })?;
1155 let logs: Vec<LogEntry> = rows
1156 .into_iter()
1157 .map(|row| LogEntry {
1158 id: row.get(0),
1159 device_id: row.get(1),
1160 timestamp: row.get(2),
1161 event_type: row.get(3),
1162 reason: row.get(4),
1163 user_id: row.get(5),
1164 user_name: row.get(6),
1165 })
1166 .collect();
1167
1168 Ok(serde_json::to_string(&logs).unwrap())
1169}
1170
1171#[get("/notifications?<devices>")]
1173pub async fn get_notifications(
1174 token: Token,
1175 devices: Option<String>,
1176 db_pool: &State<PgPool>,
1177) -> Result<String, Status> {
1178 let user_row: Option<(String,)> =
1180 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
1181 .bind(&token.0)
1182 .fetch_optional(&**db_pool)
1183 .await
1184 .map_err(|_| Status::InternalServerError)?;
1185 let firebase_uid = match user_row {
1186 Some((uid,)) => uid,
1187 None => {
1188 return Err(Status::Unauthorized);
1189 }
1190 };
1191
1192 let logs: Vec<LogEntry> = if let Some(devices_str) = devices {
1194 let device_uuids: Vec<Uuid> = devices_str
1195 .split(',')
1196 .filter_map(|s| Uuid::parse_str(s.trim()).ok())
1197 .collect();
1198 if device_uuids.is_empty() {
1199 let rows = sqlx::query("SELECT l.id, l.device_id, l.timestamp, l.event_type, l.reason, l.user_id, u.name as user_name FROM logs l LEFT JOIN users u ON l.user_id = u.firebase_uid WHERE l.device_id IN (SELECT uuid::text FROM devices WHERE user_id = $1) ORDER BY l.timestamp DESC LIMIT 1000")
1200 .bind(firebase_uid)
1201 .fetch_all(&**db_pool)
1202 .await
1203 .map_err(|_| {
1204 Status::InternalServerError
1205 })?;
1206 rows.into_iter()
1207 .map(|row| LogEntry {
1208 id: row.get(0),
1209 device_id: row.get(1),
1210 timestamp: row.get(2),
1211 event_type: row.get(3),
1212 reason: row.get(4),
1213 user_id: row.get(5),
1214 user_name: row.get(6),
1215 })
1216 .collect()
1217 } else {
1218 let device_strings: Vec<String> = device_uuids.iter().map(|u| u.to_string()).collect();
1219 let rows = sqlx::query("SELECT l.id, l.device_id, l.timestamp, l.event_type, l.reason, l.user_id, u.name as user_name FROM logs l LEFT JOIN users u ON l.user_id = u.firebase_uid WHERE l.device_id IN (SELECT uuid::text FROM devices WHERE user_id = $1) AND l.device_id = ANY($2) ORDER BY l.timestamp DESC LIMIT 1000")
1220 .bind(firebase_uid)
1221 .bind(&device_strings)
1222 .fetch_all(&**db_pool)
1223 .await
1224 .map_err(|_| {
1225 Status::InternalServerError
1226 })?;
1227 rows.into_iter()
1228 .map(|row| LogEntry {
1229 id: row.get(0),
1230 device_id: row.get(1),
1231 timestamp: row.get(2),
1232 event_type: row.get(3),
1233 reason: row.get(4),
1234 user_id: row.get(5),
1235 user_name: row.get(6),
1236 })
1237 .collect()
1238 }
1239 } else {
1240 let rows = sqlx::query("SELECT l.id, l.device_id, l.timestamp, l.event_type, l.reason, l.user_id, u.name as user_name FROM logs l LEFT JOIN users u ON l.user_id = u.firebase_uid WHERE l.device_id IN (SELECT uuid::text FROM devices WHERE user_id = $1) ORDER BY l.timestamp DESC LIMIT 1000")
1241 .bind(firebase_uid)
1242 .fetch_all(&**db_pool)
1243 .await
1244 .map_err(|_| {
1245 Status::InternalServerError
1246 })?;
1247 rows.into_iter()
1248 .map(|row| LogEntry {
1249 id: row.get(0),
1250 device_id: row.get(1),
1251 timestamp: row.get(2),
1252 event_type: row.get(3),
1253 reason: row.get(4),
1254 user_id: row.get(5),
1255 user_name: row.get(6),
1256 })
1257 .collect()
1258 };
1259
1260 Ok(serde_json::to_string(&logs).unwrap())
1261}
1262
1263#[post("/verify_voice/<device_id>", data = "<audio_data>")]
1265pub async fn verify_voice(
1266 device_id: &str,
1267 device_token: DeviceToken,
1268 audio_data: rocket::data::Data<'_>,
1269 db_pool: &State<PgPool>,
1270 speechbrain_url: &State<SpeechbrainUrl>,
1271) -> Result<rocket::serde::json::Json<serde_json::Value>, Status> {
1272 let device_uuid = Uuid::parse_str(device_id).map_err(|_| Status::BadRequest)?;
1273
1274 let device_row: Option<DeviceVoiceRow> = sqlx::query_as(
1276 "SELECT user_id, voice_invite_enable, voice_threshold, hashed_passphrase FROM devices WHERE uuid = $1",
1277 )
1278 .bind(device_uuid)
1279 .fetch_optional(&**db_pool)
1280 .await
1281 .map_err(|_| Status::InternalServerError)?;
1282
1283 let (user_id, voice_invite_enable, voice_threshold, _hashed_passphrase) = match device_row {
1284 Some(row)
1285 if row.user_id.is_some()
1286 && row.voice_invite_enable.is_some()
1287 && row.voice_threshold.is_some() =>
1288 {
1289 let uid = row.user_id.unwrap();
1290 let vie = row.voice_invite_enable.unwrap();
1291 let vt = row.voice_threshold.unwrap();
1292 let hp = row.hashed_passphrase;
1293 println!(
1294 "DEBUG: Device found, user_id: {}, voice_invite_enable: {}, voice_threshold: {}",
1295 uid, vie, vt
1296 );
1297 if let Some(ref hash) = hp {
1299 let parsed_hash =
1300 PasswordHash::new(hash).map_err(|_| Status::InternalServerError)?;
1301 let argon2 = Argon2::default();
1302 if argon2
1303 .verify_password(device_token.0.as_bytes(), &parsed_hash)
1304 .is_err()
1305 {
1306 return Err(Status::Unauthorized);
1307 }
1308 } else {
1309 return Err(Status::Unauthorized);
1311 }
1312 (uid, vie, vt, hp)
1313 }
1314 _ => {
1315 return Err(Status::BadRequest);
1316 }
1317 };
1318
1319 let mut user_embeddings = Vec::new();
1321 let mut user_ids = Vec::new();
1322
1323 let owner_row: Option<(Option<Vec<u8>>,)> =
1325 sqlx::query_as("SELECT voice_embeddings FROM users WHERE firebase_uid = $1")
1326 .bind(&user_id)
1327 .fetch_optional(&**db_pool)
1328 .await
1329 .map_err(|_| Status::InternalServerError)?;
1330
1331 if let Some((Some(emb),)) = owner_row {
1332 println!(
1333 "DEBUG: Found voice embeddings for owner {} ({} bytes)",
1334 user_id,
1335 emb.len()
1336 );
1337 user_embeddings.push(base64::engine::general_purpose::STANDARD.encode(&emb));
1338 user_ids.push(user_id.clone());
1339 } else {
1340 return Err(Status::BadRequest); }
1342
1343 if voice_invite_enable {
1345 let now = Utc::now().timestamp_millis();
1346 let invite_rows: Vec<(String, Vec<u8>)> = sqlx::query_as(
1347 "SELECT u.firebase_uid, u.voice_embeddings FROM users u JOIN invites i ON u.firebase_uid = i.receiver_id WHERE i.device_id = $1 AND i.status = 1 AND i.expiry_timestamp > $2 AND u.voice_embeddings IS NOT NULL"
1348 )
1349 .bind(device_uuid)
1350 .bind(now)
1351 .fetch_all(&**db_pool)
1352 .await
1353 .map_err(|_| {
1354 Status::InternalServerError
1355 })?;
1356
1357 for (invite_user_id, emb) in invite_rows {
1358 println!(
1359 "DEBUG: Found voice embeddings for invited user {} ({} bytes)",
1360 invite_user_id,
1361 emb.len()
1362 );
1363 user_embeddings.push(base64::engine::general_purpose::STANDARD.encode(&emb));
1364 user_ids.push(invite_user_id);
1365 }
1366 }
1367
1368 if user_embeddings.is_empty() {
1369 return Err(Status::BadRequest);
1370 }
1371
1372 println!(
1373 "DEBUG: Collected {} embeddings from users: {:?}",
1374 user_embeddings.len(),
1375 user_ids
1376 );
1377
1378 let mut data = Vec::new();
1380 audio_data
1381 .open(rocket::data::ByteUnit::max_value())
1382 .read_to_end(&mut data)
1383 .await
1384 .map_err(|_| Status::BadRequest)?;
1385
1386 if data.is_empty() {
1387 return Err(Status::BadRequest);
1388 }
1389
1390 println!(
1392 "DEBUG: Calling speechbrain verify service at {}/verify",
1393 speechbrain_url.0.as_str()
1394 );
1395 let client = Client::new();
1396 let base64_data = base64::engine::general_purpose::STANDARD.encode(&data);
1397
1398 let response = client
1399 .post(format!("{}/verify", speechbrain_url.0.as_str()))
1400 .header("Content-Type", "application/json")
1401 .json(&serde_json::json!({
1402 "pcm_base64": base64_data,
1403 "candidates": user_embeddings
1404 }))
1405 .send()
1406 .await
1407 .map_err(|_| Status::InternalServerError)?;
1408
1409 println!(
1410 "DEBUG: Speechbrain verify response status: {}",
1411 response.status()
1412 );
1413
1414 if !response.status().is_success() {
1415 return Err(Status::InternalServerError);
1416 }
1417
1418 let verify_response: serde_json::Value = response.json().await.map_err(|e| {
1419 println!(
1420 "DEBUG: Failed to parse speechbrain verify response: {:?}",
1421 e
1422 );
1423 Status::InternalServerError
1424 })?;
1425
1426 let best_index = verify_response["best_index"]
1427 .as_u64()
1428 .ok_or(Status::InternalServerError)? as usize;
1429
1430 let score = verify_response["score"]
1431 .as_f64()
1432 .ok_or(Status::InternalServerError)?;
1433
1434 println!(
1435 "DEBUG: Verification best_index: {}, score: {}",
1436 best_index, score
1437 );
1438
1439 if score > voice_threshold && best_index < user_ids.len() {
1440 println!(
1441 "DEBUG: Score {} > {}, allowing unlock for user at index {}",
1442 score, voice_threshold, best_index
1443 );
1444
1445 let matched_user_id = &user_ids[best_index];
1446
1447 let now = chrono::Utc::now().timestamp();
1449 {
1450 let commands_mutex = super::RECENT_COMMANDS.get().unwrap();
1451 let mut commands = commands_mutex.lock().unwrap();
1452 commands.insert(device_id.to_string(), (matched_user_id.clone(), now));
1453 }
1454
1455 println!(
1456 "DEBUG: Stored recent voice verification for user {}",
1457 matched_user_id
1458 );
1459 Ok(rocket::serde::json::Json(
1460 serde_json::json!({"index": best_index}),
1461 ))
1462 } else {
1463 println!(
1464 "DEBUG: Score {} <= {} or invalid index {}, denying unlock",
1465 score, voice_threshold, best_index
1466 );
1467 Err(Status::Forbidden)
1468 }
1469}
1470
1471#[get("/accessible_devices")]
1473pub async fn get_accessible_devices(
1474 token: Token,
1475 db_pool: &State<PgPool>,
1476) -> Result<String, Status> {
1477 let user_row: Option<(String,)> =
1479 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
1480 .bind(&token.0)
1481 .fetch_optional(&**db_pool)
1482 .await
1483 .map_err(|_| Status::InternalServerError)?;
1484 let user_id = match user_row {
1485 Some((uid,)) => uid,
1486 None => return Err(Status::Unauthorized),
1487 };
1488
1489 let invites: Vec<(i32, uuid::Uuid, String, i64)> =
1491 sqlx::query_as("SELECT id, device_id, sender_id, expiry_timestamp FROM invites WHERE receiver_id = $1 AND status = 1 AND expiry_timestamp > $2")
1492 .bind(&user_id)
1493 .bind(Utc::now().timestamp_millis())
1494 .fetch_all(&**db_pool)
1495 .await
1496 .map_err(|_| Status::InternalServerError)?;
1497
1498 let devices: Vec<serde_json::Value> = invites
1499 .into_iter()
1500 .map(|(id, device_id, sender_id, expiry)| {
1501 serde_json::json!({
1502 "id": id,
1503 "device_id": device_id.to_string(),
1504 "sender_id": sender_id,
1505 "expiry_timestamp": expiry
1506 })
1507 })
1508 .collect();
1509
1510 Ok(serde_json::to_string(&devices).unwrap())
1511}