1use anyhow::Result;
31use chrono::Utc;
32use rocket::http::Status;
33use rocket::request::{FromRequest, Outcome};
34use rocket::{Request, State, get, routes};
35use rumqttc::{AsyncClient, MqttOptions, QoS, Transport};
36use sqlx::ConnectOptions;
37use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
38use std::collections::HashMap;
39use std::env;
40use std::sync::{Mutex, OnceLock};
41use url::Url;
42
43mod device;
44mod invite;
45mod mqtt;
46mod user;
47
/// Base URL of the SpeechBrain service (read from `SPEECHBRAIN_URL` at startup) that is
/// handed to route handlers as managed Rocket state.
#[derive(Debug, Clone)]
pub struct SpeechbrainUrl(pub String);
/// Homepage URL (read from `HOMEPAGE_URL` at startup) that the root route redirects to;
/// handed to route handlers as managed Rocket state.
#[derive(Debug, Clone)]
pub struct HomepageUrl(pub String);
52
53#[derive(Clone)]
55pub struct Token(pub String);
56
57#[rocket::async_trait]
59impl<'r> FromRequest<'r> for Token {
60 type Error = &'static str;
61
62 async fn from_request(req: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
64 let auth_header = req.headers().get_one("Authorization");
65 match auth_header {
66 Some(auth) if auth.starts_with("Bearer ") => {
67 let token_str = &auth[7..];
68 Outcome::Success(Token(token_str.to_string()))
69 }
70 _ => Outcome::Error((
71 Status::Unauthorized,
72 "Missing or invalid Authorization header",
73 )),
74 }
75 }
76}
77
78type RecentCommands = Mutex<HashMap<String, (String, i64)>>;
80type PendingPings = Mutex<HashMap<String, (i64, tokio::sync::oneshot::Sender<()>)>>;
82type PendingConfigUpdates = Mutex<HashMap<String, tokio::sync::oneshot::Sender<()>>>;
84
85pub static RECENT_COMMANDS: OnceLock<RecentCommands> = OnceLock::new();
87pub static PENDING_PINGS: OnceLock<PendingPings> = OnceLock::new();
89pub static PENDING_CONFIG_UPDATES: OnceLock<PendingConfigUpdates> = OnceLock::new();
91
92#[tokio::main]
96async fn main() -> Result<()> {
97 dotenv::dotenv().ok();
98
99 let db_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
101 let mqtt_host = env::var("MQTT_HOST").expect("MQTT_HOST must be set");
102 let mqtt_port: u16 = env::var("MQTT_PORT")
103 .map(|s| s.parse().unwrap())
104 .unwrap_or(1883);
105 let mqtt_tls: bool = env::var("MQTT_TLS")
106 .map(|s| s.parse().unwrap())
107 .unwrap_or(false);
108 let mqtt_username = env::var("MQTT_USERNAME").ok();
109 let mqtt_password = env::var("MQTT_PASSWORD").ok();
110 let port: u16 = env::var("PORT")
111 .unwrap_or("8000".to_string())
112 .parse()
113 .unwrap();
114 let speechbrain_url =
115 SpeechbrainUrl(env::var("SPEECHBRAIN_URL").unwrap_or("http://localhost:5008".to_string()));
116 let homepage_url =
117 HomepageUrl(env::var("HOMEPAGE_URL").unwrap_or("https://example.com".to_string()));
118 RECENT_COMMANDS.set(Mutex::new(HashMap::new())).unwrap();
119 PENDING_PINGS.set(Mutex::new(HashMap::new())).unwrap();
120 PENDING_CONFIG_UPDATES
121 .set(Mutex::new(HashMap::new()))
122 .unwrap();
123
124 let url = Url::parse(&db_url)?;
126 let options = PgConnectOptions::from_url(&url)?.ssl_mode(PgSslMode::Require);
127 let db_pool = PgPoolOptions::new()
128 .max_connections(5)
129 .connect_with(options)
130 .await?;
131
132 sqlx::query("CREATE TABLE IF NOT EXISTS devices ( uuid uuid PRIMARY KEY, user_id VARCHAR(255), last_heard timestamptz NOT NULL, uptime_ms bigint NOT NULL, hashed_passphrase VARCHAR(255), locked_down_at timestamptz)")
134 .execute(&db_pool)
135 .await?;
136
137 sqlx::query("CREATE TABLE IF NOT EXISTS users ( firebase_uid VARCHAR(255) PRIMARY KEY, hashed_password VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, phone_number VARCHAR(255), name VARCHAR(255) NOT NULL, current_token VARCHAR(255), voice_embeddings BYTEA, created_at timestamptz NOT NULL DEFAULT NOW(), last_login timestamptz)")
139 .execute(&db_pool)
140 .await?;
141
142 sqlx::query("CREATE TABLE IF NOT EXISTS logs ( id SERIAL PRIMARY KEY, device_id VARCHAR(255) NOT NULL, timestamp timestamptz NOT NULL DEFAULT NOW(), event_type VARCHAR(10) NOT NULL, reason VARCHAR(20) NOT NULL, user_id VARCHAR(255), created_at timestamptz NOT NULL DEFAULT NOW())")
144 .execute(&db_pool)
145 .await?;
146
147 sqlx::query("CREATE TABLE IF NOT EXISTS invites ( id SERIAL PRIMARY KEY, device_id UUID NOT NULL, sender_id VARCHAR(255) NOT NULL, receiver_id VARCHAR(255) NOT NULL, status INTEGER NOT NULL DEFAULT 0, expiry_timestamp BIGINT NOT NULL, created_at timestamptz NOT NULL DEFAULT NOW(), FOREIGN KEY (device_id) REFERENCES devices(uuid) ON DELETE CASCADE)")
149 .execute(&db_pool)
150 .await?;
151
152 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS user_id VARCHAR(255)")
154 .execute(&db_pool)
155 .await?;
156 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS hashed_passphrase VARCHAR(255)")
157 .execute(&db_pool)
158 .await?;
159 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS wifi_ssid VARCHAR(255)")
160 .execute(&db_pool)
161 .await?;
162 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS backend_url VARCHAR(255)")
163 .execute(&db_pool)
164 .await?;
165 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS mqtt_broker_url VARCHAR(255)")
166 .execute(&db_pool)
167 .await?;
168 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS mqtt_heartbeat_enable BOOLEAN")
169 .execute(&db_pool)
170 .await?;
171 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS mqtt_heartbeat_interval_sec INTEGER")
172 .execute(&db_pool)
173 .await?;
174 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS audio_record_timeout_sec INTEGER")
175 .execute(&db_pool)
176 .await?;
177 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS lock_timeout_ms INTEGER")
178 .execute(&db_pool)
179 .await?;
180 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS lock_state VARCHAR(10)")
181 .execute(&db_pool)
182 .await?;
183 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS pairing_timeout_sec INTEGER")
184 .execute(&db_pool)
185 .await?;
186 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS locked_down_at timestamptz")
187 .execute(&db_pool)
188 .await?;
189 sqlx::query(
190 "ALTER TABLE devices ADD COLUMN IF NOT EXISTS voice_detection_enable BOOLEAN DEFAULT true",
191 )
192 .execute(&db_pool)
193 .await?;
194 sqlx::query(
195 "ALTER TABLE devices ADD COLUMN IF NOT EXISTS voice_invite_enable BOOLEAN DEFAULT true",
196 )
197 .execute(&db_pool)
198 .await?;
199 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS voice_threshold FLOAT8 DEFAULT 0.60")
200 .execute(&db_pool)
201 .await?;
202 sqlx::query("ALTER TABLE devices ADD COLUMN IF NOT EXISTS vad_rms_threshold INTEGER DEFAULT 1000")
203 .execute(&db_pool)
204 .await?;
205 sqlx::query("ALTER TABLE users ADD COLUMN IF NOT EXISTS voice_embeddings BYTEA")
206 .execute(&db_pool)
207 .await?;
208
209 let mut mqtt_options = MqttOptions::new("backend", mqtt_host, mqtt_port);
211 if let Some(user) = mqtt_username {
212 mqtt_options.set_credentials(user, mqtt_password.unwrap_or_default());
213 }
214 if mqtt_tls {
215 mqtt_options.set_transport(Transport::tls_with_default_config());
216 }
217
218 let (mqtt_client, mut eventloop) = AsyncClient::new(mqtt_options, 10);
219
220 mqtt_client
222 .subscribe("lockwise/+/status", QoS::AtMostOnce)
223 .await?;
224
225 let db_pool_clone = db_pool.clone();
227 tokio::spawn(async move {
228 mqtt::handle_mqtt_events(&db_pool_clone, &mut eventloop).await;
229 });
230
231 let db_pool_cleanup = db_pool.clone();
233 tokio::spawn(async move {
234 loop {
235 tokio::time::sleep(std::time::Duration::from_secs(3600)).await; let one_month_ago = Utc::now() - chrono::Duration::days(30);
237 let _ = sqlx::query("DELETE FROM logs WHERE timestamp < $1")
238 .bind(one_month_ago)
239 .execute(&db_pool_cleanup)
240 .await;
241 }
242 });
243
244 tokio::spawn(async move {
246 rocket::build()
247 .configure(
248 rocket::Config::figment()
249 .merge(("port", port))
250 .merge(("address", "0.0.0.0"))
251 .merge(("workers", num_cpus::get())),
252 )
253 .manage(db_pool)
254 .manage(mqtt_client)
255 .manage(speechbrain_url)
256 .manage(homepage_url)
257 .mount(
258 "/",
259 routes![
260 index,
261 health,
262 device::control_device,
263 device::control_temp_device,
264 device::get_accessible_devices,
265 device::get_device,
266 device::get_devices,
267 device::get_logs,
268 device::get_notifications,
269 device::get_temp_device,
270 device::get_temp_devices_status,
271 device::lockdown_device,
272 device::ping_device,
273 device::ping_temp_device,
274 device::reboot_device,
275 device::register_device,
276 device::unpair_device,
277 device::update_config,
278 device::verify_voice,
279 invite::accept_invite,
280 invite::cancel_invite,
281 invite::create_invite,
282 invite::get_invites,
283 invite::reject_invite,
284 invite::update_invite,
285 user::delete_account,
286 user::delete_voice,
287 user::login_user,
288 user::logout_user,
289 user::register_user,
290 user::register_voice,
291 user::update_password,
292 user::update_phone,
293 user::verify_password,
294 user::voice_status
295 ],
296 )
297 .launch()
298 .await
299 .unwrap();
300 });
301
302 tokio::signal::ctrl_c().await?;
304 Ok(())
305}
306
307#[get("/")]
309fn index(homepage_url: &State<HomepageUrl>) -> rocket::response::Redirect {
310 rocket::response::Redirect::found(homepage_url.0.as_str().to_string())
311}
312
313#[get("/health")]
315fn health() -> &'static str {
316 "OK"
317}