1use anyhow::Result;
32use chrono::Utc;
33use rocket::futures::{SinkExt, StreamExt};
34use rocket::http::Status;
35use rocket::request::{FromRequest, Outcome};
36use rocket::{Request, State, get, routes};
37use rocket_ws::{Message, WebSocket};
38use rumqttc::{AsyncClient, MqttOptions, QoS, Transport};
39use sqlx::ConnectOptions;
40use sqlx::PgPool;
41use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
42use std::collections::HashMap;
43use std::env;
44use std::sync::{Mutex, OnceLock};
45use tokio::sync::broadcast;
46use url::Url;
47
48mod device;
49mod invite;
50mod mqtt;
51mod user;
52
/// Newtype around the `SPEECHBRAIN_URL` setting.
///
/// `main` builds it from the environment (defaulting to
/// `http://localhost:5008`) and registers it as Rocket managed state.
pub struct SpeechbrainUrl(pub String);
/// Newtype around the `HOMEPAGE_URL` setting.
///
/// `main` builds it from the environment (defaulting to
/// `https://example.com`) and registers it as Rocket managed state; the `/`
/// route redirects to it.
pub struct HomepageUrl(pub String);
57
/// Raw credential string supplied by a client.
///
/// Obtained through the `FromRequest` implementation in this file, which
/// reads it from an `Authorization: Bearer …` header or a `token` query
/// parameter. The guard itself does not validate the value.
#[derive(Clone)]
pub struct Token(pub String);
61
62#[rocket::async_trait]
64impl<'r> FromRequest<'r> for Token {
65 type Error = &'static str;
66
67 async fn from_request(req: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
69 let auth_header = req.headers().get_one("Authorization");
71 if let Some(token_str) = auth_header.and_then(|auth| auth.strip_prefix("Bearer ")) {
72 return Outcome::Success(Token(token_str.to_string()));
73 }
74
75 if let Some(token_str) = req.uri().query().and_then(|q| {
77 url::form_urlencoded::parse(q.as_bytes())
78 .find(|(k, _)| k == "token")
79 .map(|(_, v)| v.into_owned())
80 }) {
81 return Outcome::Success(Token(token_str));
82 }
83
84 Outcome::Error((
85 Status::Unauthorized,
86 "Missing or invalid Authorization header or token query parameter",
87 ))
88 }
89}
90
/// Mutex-guarded map from a `String` key to a `(String, i64)` pair.
// NOTE(review): presumably device id -> (last command, timestamp); confirm
// against the modules that read and write `RECENT_COMMANDS`.
type RecentCommands = Mutex<HashMap<String, (String, i64)>>;
/// Mutex-guarded map from a `String` key to an `i64` plus a one-shot sender
/// carrying `()`.
// NOTE(review): presumably device id -> (ping start time, completion signal);
// confirm against the ping handlers.
type PendingPings = Mutex<HashMap<String, (i64, tokio::sync::oneshot::Sender<()>)>>;
/// Mutex-guarded map from a `String` key to a one-shot sender carrying `()`.
// NOTE(review): presumably device id -> "config applied" signal; confirm
// against the config-update handlers.
type PendingConfigUpdates = Mutex<HashMap<String, tokio::sync::oneshot::Sender<()>>>;

/// Sending half of a broadcast channel of `String` messages.
type DeviceUpdateSender = broadcast::Sender<String>;
/// Sending half of a broadcast channel of `String` messages for one user.
type UserBroadcast = broadcast::Sender<String>;
/// Mutex-guarded map of [`UserBroadcast`] senders; `websocket_updates` keys
/// it by the user's `firebase_uid`.
type UserBroadcasts = Mutex<HashMap<String, UserBroadcast>>;
104
// Process-wide shared state. Every cell below is filled exactly once by
// `main` before the MQTT and HTTP tasks are spawned.

/// Global [`RecentCommands`] map; starts out empty.
pub static RECENT_COMMANDS: OnceLock<RecentCommands> = OnceLock::new();
/// Global [`PendingPings`] map; starts out empty.
pub static PENDING_PINGS: OnceLock<PendingPings> = OnceLock::new();
/// Global [`PendingConfigUpdates`] map; starts out empty.
pub static PENDING_CONFIG_UPDATES: OnceLock<PendingConfigUpdates> = OnceLock::new();
/// Sender of the broadcast channel (capacity 100) created in `main`.
pub static DEVICE_UPDATE_TX: OnceLock<DeviceUpdateSender> = OnceLock::new();
/// Global [`UserBroadcasts`] map; starts out empty, and `websocket_updates`
/// inserts a channel for a user on first connection.
pub static USER_BROADCASTS: OnceLock<UserBroadcasts> = OnceLock::new();
115
116#[tokio::main]
120async fn main() -> Result<()> {
121 dotenv::dotenv().ok();
122
123 let db_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
125 let mqtt_host = env::var("MQTT_HOST").expect("MQTT_HOST must be set");
126 let mqtt_port: u16 = env::var("MQTT_PORT")
127 .map(|s| s.parse().unwrap())
128 .unwrap_or(1883);
129 let mqtt_tls: bool = env::var("MQTT_TLS")
130 .map(|s| s.parse().unwrap())
131 .unwrap_or(false);
132 let mqtt_username = env::var("MQTT_USERNAME").ok();
133 let mqtt_password = env::var("MQTT_PASSWORD").ok();
134 let port: u16 = env::var("PORT")
135 .unwrap_or("8000".to_string())
136 .parse()
137 .unwrap();
138 let speechbrain_url =
139 SpeechbrainUrl(env::var("SPEECHBRAIN_URL").unwrap_or("http://localhost:5008".to_string()));
140 let homepage_url =
141 HomepageUrl(env::var("HOMEPAGE_URL").unwrap_or("https://example.com".to_string()));
142 RECENT_COMMANDS.set(Mutex::new(HashMap::new())).unwrap();
143 PENDING_PINGS.set(Mutex::new(HashMap::new())).unwrap();
144 PENDING_CONFIG_UPDATES
145 .set(Mutex::new(HashMap::new()))
146 .unwrap();
147 let (tx, _rx) = broadcast::channel(100);
148 DEVICE_UPDATE_TX.set(tx).unwrap();
149 USER_BROADCASTS.set(Mutex::new(HashMap::new())).unwrap();
150
151 let url = Url::parse(&db_url)?;
153 let options = PgConnectOptions::from_url(&url)?.ssl_mode(PgSslMode::Require);
154 let db_pool = PgPoolOptions::new()
155 .max_connections(5)
156 .connect_with(options)
157 .await?;
158
159 sqlx::query("CREATE TABLE IF NOT EXISTS devices ( uuid uuid PRIMARY KEY, user_id VARCHAR(255), last_heard timestamptz NOT NULL, uptime_ms bigint NOT NULL, hashed_passphrase VARCHAR(255), locked_down_at timestamptz)")
161 .execute(&db_pool)
162 .await?;
163
164 sqlx::query("CREATE TABLE IF NOT EXISTS users ( firebase_uid VARCHAR(255) PRIMARY KEY, hashed_password VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, phone_number VARCHAR(255), name VARCHAR(255) NOT NULL, current_token VARCHAR(255), voice_embeddings BYTEA, created_at timestamptz NOT NULL DEFAULT NOW(), last_login timestamptz)")
166 .execute(&db_pool)
167 .await?;
168
169 sqlx::query("CREATE TABLE IF NOT EXISTS logs ( id SERIAL PRIMARY KEY, device_id VARCHAR(255) NOT NULL, timestamp timestamptz NOT NULL DEFAULT NOW(), event_type VARCHAR(10) NOT NULL, reason VARCHAR(20) NOT NULL, user_id VARCHAR(255), created_at timestamptz NOT NULL DEFAULT NOW())")
171 .execute(&db_pool)
172 .await?;
173
174 sqlx::query("CREATE TABLE IF NOT EXISTS invites ( id SERIAL PRIMARY KEY, device_id UUID NOT NULL, sender_id VARCHAR(255) NOT NULL, receiver_id VARCHAR(255) NOT NULL, status INTEGER NOT NULL DEFAULT 0, expiry_timestamp BIGINT NOT NULL, created_at timestamptz NOT NULL DEFAULT NOW(), FOREIGN KEY (device_id) REFERENCES devices(uuid) ON DELETE CASCADE)")
176 .execute(&db_pool)
177 .await?;
178
179 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS user_id VARCHAR(255)")
181 .execute(&db_pool)
182 .await?;
183 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS hashed_passphrase VARCHAR(255)")
184 .execute(&db_pool)
185 .await?;
186 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS wifi_ssid VARCHAR(255)")
187 .execute(&db_pool)
188 .await?;
189 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS backend_url VARCHAR(255)")
190 .execute(&db_pool)
191 .await?;
192 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS mqtt_broker_url VARCHAR(255)")
193 .execute(&db_pool)
194 .await?;
195 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS mqtt_heartbeat_enable BOOLEAN")
196 .execute(&db_pool)
197 .await?;
198 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS mqtt_heartbeat_interval_sec INTEGER")
199 .execute(&db_pool)
200 .await?;
201 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS audio_record_timeout_sec INTEGER")
202 .execute(&db_pool)
203 .await?;
204 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS lock_timeout_ms INTEGER")
205 .execute(&db_pool)
206 .await?;
207 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS lock_state VARCHAR(10)")
208 .execute(&db_pool)
209 .await?;
210 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS pairing_timeout_sec INTEGER")
211 .execute(&db_pool)
212 .await?;
213 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS locked_down_at timestamptz")
214 .execute(&db_pool)
215 .await?;
216 sqlx::query(
217 "ALTER TABLE devices ADD COLUMN IF NOT EXISTS voice_detection_enable BOOLEAN DEFAULT true",
218 )
219 .execute(&db_pool)
220 .await?;
221 sqlx::query(
222 "ALTER TABLE devices ADD COLUMN IF NOT EXISTS voice_invite_enable BOOLEAN DEFAULT true",
223 )
224 .execute(&db_pool)
225 .await?;
226 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS voice_threshold FLOAT8 DEFAULT 0.60")
227 .execute(&db_pool)
228 .await?;
229 sqlx::query(
230 "ALTER TABLE devices ADD COLUMN IF NOT EXISTS vad_rms_threshold INTEGER DEFAULT 1000",
231 )
232 .execute(&db_pool)
233 .await?;
234 sqlx::query("ALTER TABLE users ADD COLUMN IF NOT EXISTS voice_embeddings BYTEA")
235 .execute(&db_pool)
236 .await?;
237
238 let mut mqtt_options = MqttOptions::new("backend", mqtt_host, mqtt_port);
240 if let Some(user) = mqtt_username {
241 mqtt_options.set_credentials(user, mqtt_password.unwrap_or_default());
242 }
243 if mqtt_tls {
244 mqtt_options.set_transport(Transport::tls_with_default_config());
245 }
246
247 let (mqtt_client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
248
249 mqtt_client
251 .subscribe("lockwise/+/status", QoS::AtMostOnce)
252 .await?;
253
254 let db_pool_clone = db_pool.clone();
256 tokio::spawn(async move {
257 mqtt::handle_mqtt_events(&db_pool_clone, &mut eventloop).await;
258 });
259
260 let db_pool_cleanup = db_pool.clone();
262 tokio::spawn(async move {
263 loop {
264 tokio::time::sleep(std::time::Duration::from_secs(3600)).await; let one_month_ago = Utc::now() - chrono::Duration::days(30);
266 let _ = sqlx::query("DELETE FROM logs WHERE timestamp < $1")
267 .bind(one_month_ago)
268 .execute(&db_pool_cleanup)
269 .await;
270 }
271 });
272
273 tokio::spawn(async move {
275 rocket::build()
276 .configure(
277 rocket::Config::figment()
278 .merge(("port", port))
279 .merge(("address", "0.0.0.0"))
280 .merge(("workers", num_cpus::get())),
281 )
282 .manage(db_pool)
283 .manage(mqtt_client)
284 .manage(speechbrain_url)
285 .manage(homepage_url)
286 .mount(
287 "/",
288 routes![
289 index,
290 health,
291 websocket_updates,
292 device::control_device,
293 device::control_temp_device,
294 device::get_accessible_devices,
295 device::get_device,
296 device::get_devices,
297 device::get_logs,
298 device::get_notifications,
299 device::get_temp_device,
300 device::get_temp_devices_status,
301 device::lockdown_device,
302 device::ping_device,
303 device::ping_temp_device,
304 device::reboot_device,
305 device::register_device,
306 device::unpair_device,
307 device::update_config,
308 device::verify_voice,
309 invite::accept_invite,
310 invite::cancel_invite,
311 invite::create_invite,
312 invite::get_invites,
313 invite::reject_invite,
314 invite::update_invite,
315 user::delete_account,
316 user::delete_voice,
317 user::login_user,
318 user::logout_user,
319 user::register_user,
320 user::register_voice,
321 user::update_password,
322 user::update_phone,
323 user::verify_password,
324 user::voice_status
325 ],
326 )
327 .launch()
328 .await
329 .unwrap();
330 });
331
332 tokio::signal::ctrl_c().await?;
334 Ok(())
335}
336
337#[get("/")]
339fn index(homepage_url: &State<HomepageUrl>) -> rocket::response::Redirect {
340 rocket::response::Redirect::found(homepage_url.0.as_str().to_string())
341}
342
343#[get("/health")]
345fn health() -> &'static str {
346 "OK"
347}
348
349#[get("/ws/updates")]
351fn websocket_updates(
352 ws: WebSocket,
353 token: Token,
354 db_pool: &State<PgPool>,
355) -> rocket_ws::Channel<'static> {
356 let pool = (**db_pool).clone();
357 ws.channel(move |mut stream| {
358 Box::pin(async move {
359 let user_row: Option<(String,)> =
361 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
362 .bind(&token.0)
363 .fetch_optional(&pool)
364 .await
365 .unwrap_or(None);
366 let user_id = match user_row {
367 Some((uid,)) => uid,
368 None => {
369 let _ = stream.close(None).await;
371 return Ok(());
372 }
373 };
374
375 let user_broadcasts = USER_BROADCASTS.get().unwrap();
377 let tx = {
378 let mut broadcasts = user_broadcasts.lock().unwrap();
379 broadcasts
380 .entry(user_id.clone())
381 .or_insert_with(|| broadcast::channel(100).0)
382 .clone()
383 };
384 let mut rx = tx.subscribe();
385
386 if stream
388 .send(Message::Text("Connected".to_string()))
389 .await
390 .is_err()
391 {
392 return Ok(());
393 }
394
395 loop {
396 tokio::select! {
397 msg = rx.recv() => {
398 match msg {
399 Ok(update) => {
400 if stream.send(Message::Text(update)).await.is_err() {
401 break;
402 }
403 }
404 Err(_) => break,
405 }
406 }
407 msg = stream.next() => {
408 match msg {
409 Some(Ok(Message::Close(_))) | None => break,
410 _ => {} }
412 }
413 }
414 }
415
416 Ok(())
417 })
418 })
419}