1use anyhow::Result;
6use argon2::password_hash::PasswordHash;
7use argon2::{Argon2, PasswordHasher, PasswordVerifier};
8use base64::Engine;
9use reqwest::Client;
10use rocket::http::Status;
11use rocket::{State, get, post};
12use serde::Deserialize;
13use sqlx::PgPool;
14use tokio::io::AsyncReadExt;
15use uuid::Uuid;
16
17use super::{SpeechbrainUrl, Token};
18
/// JSON body accepted by `POST /register`.
///
/// NOTE(review): no `Debug` derive, which keeps the plaintext `password` out of
/// debug output — keep it that way. Field order is significant for serde's
/// sequence-form (JSON array) deserialization, so do not reorder fields.
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Firebase UID of the new account; written to `users.firebase_uid`.
    firebase_uid: String,
    /// Plaintext password; `register_user` persists only its Argon2 hash.
    password: String,
    /// Email address, stored as-is (not validated in this module).
    email: String,
    /// Phone number, stored as-is (not validated in this module).
    phone_number: String,
    /// Name, stored as-is in `users.name`.
    name: String,
}
33
/// JSON body accepted by `POST /login`.
///
/// NOTE(review): no `Debug` derive, keeping the plaintext password out of debug
/// output. Field order is significant for sequence-form deserialization.
#[derive(Deserialize)]
pub struct LoginRequest {
    /// Firebase UID of the account to log into; matched against `users.firebase_uid`.
    firebase_uid: String,
    /// Plaintext password, verified against the stored Argon2 hash.
    password: String,
}
41
/// JSON body accepted by `POST /update_phone`.
#[derive(Deserialize)]
pub struct UpdatePhoneRequest {
    /// New phone number; replaces `users.phone_number` as-is.
    phone_number: String,
}
/// JSON body accepted by `POST /update_password`.
///
/// NOTE(review): no `Debug` derive, keeping the plaintext password out of debug output.
#[derive(Deserialize)]
pub struct UpdatePasswordRequest {
    /// New plaintext password; hashed with Argon2 before it is stored.
    password: String,
}
54
/// JSON body accepted by `POST /verify_password`.
///
/// NOTE(review): no `Debug` derive, keeping the plaintext password out of debug output.
#[derive(Deserialize)]
pub struct VerifyPasswordRequest {
    /// Plaintext password to check against the caller's stored Argon2 hash.
    password: String,
}
61
62#[post("/register", data = "<request>")]
64pub async fn register_user(
65 request: rocket::serde::json::Json<RegisterRequest>,
66 db_pool: &State<PgPool>,
67) -> Result<(), Status> {
68 let salt =
70 argon2::password_hash::SaltString::generate(&mut argon2::password_hash::rand_core::OsRng);
71 let argon2 = Argon2::default();
72 let hashed_password = argon2
73 .hash_password(request.password.as_bytes(), &salt)
74 .map_err(|_| Status::InternalServerError)?
75 .to_string();
76
77 sqlx::query(
79 "INSERT INTO users (firebase_uid, hashed_password, email, phone_number, name, created_at) VALUES ($1, $2, $3, $4, $5, NOW())"
80 )
81 .bind(&request.firebase_uid)
82 .bind(&hashed_password)
83 .bind(&request.email)
84 .bind(&request.phone_number)
85 .bind(&request.name)
86 .execute(&**db_pool)
87 .await
88 .map_err(|_| Status::InternalServerError)?;
89
90 Ok(())
91}
92
93#[post("/login", data = "<request>")]
95pub async fn login_user(
96 request: rocket::serde::json::Json<LoginRequest>,
97 db_pool: &State<PgPool>,
98) -> Result<String, Status> {
99 let row: Option<(String,)> =
101 sqlx::query_as("SELECT hashed_password FROM users WHERE firebase_uid = $1")
102 .bind(&request.firebase_uid)
103 .fetch_optional(&**db_pool)
104 .await
105 .map_err(|_| Status::InternalServerError)?;
106
107 if let Some((hashed_password,)) = row {
108 let parsed_hash = PasswordHash::new(&hashed_password);
109 if let Ok(hash) = parsed_hash {
110 let argon2 = Argon2::default();
111 if argon2
112 .verify_password(request.password.as_bytes(), &hash)
113 .is_ok()
114 {
115 let token = Uuid::new_v4().to_string();
117 sqlx::query("UPDATE users SET current_token = $1, last_login = NOW() WHERE firebase_uid = $2")
119 .bind(&token)
120 .bind(&request.firebase_uid)
121 .execute(&**db_pool)
122 .await
123 .map_err(|_| Status::InternalServerError)?;
124 return Ok(token);
125 }
126 }
127 }
128 Err(Status::Unauthorized)
129}
130
131#[post("/logout")]
133pub async fn logout_user(token: Token, db_pool: &State<PgPool>) -> Result<(), Status> {
134 sqlx::query("UPDATE users SET current_token = NULL WHERE firebase_uid = $1")
135 .bind(token.0)
136 .execute(&**db_pool)
137 .await
138 .map_err(|_| Status::InternalServerError)?;
139 Ok(())
140}
141
142#[post("/update_phone", data = "<request>")]
144pub async fn update_phone(
145 token: Token,
146 request: rocket::serde::json::Json<UpdatePhoneRequest>,
147 db_pool: &State<PgPool>,
148) -> Result<(), Status> {
149 let user_row: Option<(String,)> =
151 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
152 .bind(&token.0)
153 .fetch_optional(&**db_pool)
154 .await
155 .map_err(|_| Status::InternalServerError)?;
156 let firebase_uid = match user_row {
157 Some((uid,)) => uid,
158 None => return Err(Status::Unauthorized),
159 };
160
161 sqlx::query("UPDATE users SET phone_number = $1 WHERE firebase_uid = $2")
163 .bind(&request.phone_number)
164 .bind(&firebase_uid)
165 .execute(&**db_pool)
166 .await
167 .map_err(|_| Status::InternalServerError)?;
168
169 Ok(())
170}
171
172#[post("/update_password", data = "<request>")]
174pub async fn update_password(
175 token: Token,
176 request: rocket::serde::json::Json<UpdatePasswordRequest>,
177 db_pool: &State<PgPool>,
178) -> Result<(), Status> {
179 let user_row: Option<(String,)> =
181 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
182 .bind(&token.0)
183 .fetch_optional(&**db_pool)
184 .await
185 .map_err(|_| Status::InternalServerError)?;
186 let firebase_uid = match user_row {
187 Some((uid,)) => uid,
188 None => return Err(Status::Unauthorized),
189 };
190
191 let salt =
193 argon2::password_hash::SaltString::generate(&mut argon2::password_hash::rand_core::OsRng);
194 let argon2 = Argon2::default();
195 let hashed_password = argon2
196 .hash_password(request.password.as_bytes(), &salt)
197 .map_err(|_| Status::InternalServerError)?
198 .to_string();
199
200 sqlx::query("UPDATE users SET hashed_password = $1 WHERE firebase_uid = $2")
202 .bind(&hashed_password)
203 .bind(&firebase_uid)
204 .execute(&**db_pool)
205 .await
206 .map_err(|_| Status::InternalServerError)?;
207
208 Ok(())
209}
210
211#[post("/delete_account")]
213pub async fn delete_account(token: Token, db_pool: &State<PgPool>) -> Result<(), Status> {
214 let user_row: Option<(String,)> =
216 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
217 .bind(&token.0)
218 .fetch_optional(&**db_pool)
219 .await
220 .map_err(|_| Status::InternalServerError)?;
221 let firebase_uid = match user_row {
222 Some((uid,)) => uid,
223 None => return Err(Status::Unauthorized),
224 };
225
226 sqlx::query("UPDATE devices SET user_id = NULL WHERE user_id = $1")
228 .bind(&firebase_uid)
229 .execute(&**db_pool)
230 .await
231 .map_err(|_| Status::InternalServerError)?;
232
233 sqlx::query("DELETE FROM logs WHERE user_id = $1")
235 .bind(&firebase_uid)
236 .execute(&**db_pool)
237 .await
238 .map_err(|_| Status::InternalServerError)?;
239
240 sqlx::query("DELETE FROM invites WHERE sender_id = $1 OR receiver_id = $1")
242 .bind(&firebase_uid)
243 .execute(&**db_pool)
244 .await
245 .map_err(|_| Status::InternalServerError)?;
246
247 sqlx::query("DELETE FROM users WHERE firebase_uid = $1")
249 .bind(&firebase_uid)
250 .execute(&**db_pool)
251 .await
252 .map_err(|_| Status::InternalServerError)?;
253
254 Ok(())
255}
256
257#[post("/verify_password", data = "<request>")]
259pub async fn verify_password(
260 token: Token,
261 request: rocket::serde::json::Json<VerifyPasswordRequest>,
262 db_pool: &State<PgPool>,
263) -> Result<(), Status> {
264 let user_row: Option<(String,)> =
266 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
267 .bind(&token.0)
268 .fetch_optional(&**db_pool)
269 .await
270 .map_err(|_| Status::InternalServerError)?;
271 let firebase_uid = match user_row {
272 Some((uid,)) => uid,
273 None => return Err(Status::Unauthorized),
274 };
275
276 let password_row: Option<(String,)> =
278 sqlx::query_as("SELECT hashed_password FROM users WHERE firebase_uid = $1")
279 .bind(&firebase_uid)
280 .fetch_optional(&**db_pool)
281 .await
282 .map_err(|_| Status::InternalServerError)?;
283
284 if let Some((hashed_password,)) = password_row {
285 let parsed_hash = PasswordHash::new(&hashed_password);
286 if let Ok(hash) = parsed_hash {
287 let argon2 = Argon2::default();
288 if argon2
289 .verify_password(request.password.as_bytes(), &hash)
290 .is_ok()
291 {
292 return Ok(());
293 }
294 }
295 }
296
297 Err(Status::Unauthorized)
298}
299
300#[post("/register_voice", data = "<audio_data>")]
302pub async fn register_voice(
303 token: Token,
304 audio_data: rocket::data::Data<'_>,
305 db_pool: &State<PgPool>,
306 speechbrain_url: &State<SpeechbrainUrl>,
307) -> Result<(), Status> {
308 let user_row: Option<(String,)> =
310 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
311 .bind(&token.0)
312 .fetch_optional(&**db_pool)
313 .await
314 .map_err(|_| Status::InternalServerError)?;
315 let firebase_uid = match user_row {
316 Some((uid,)) => uid,
317 None => {
318 return Err(Status::Unauthorized);
319 }
320 };
321
322 let mut data = Vec::new();
324 audio_data
325 .open(rocket::data::ByteUnit::max_value())
326 .read_to_end(&mut data)
327 .await
328 .map_err(|_| Status::BadRequest)?;
329
330 if data.is_empty() {
331 return Err(Status::BadRequest);
332 }
333
334 let client = Client::new();
336 let base64_data = base64::engine::general_purpose::STANDARD.encode(&data);
337
338 let response = client
339 .post(format!("{}/embed", speechbrain_url.0.as_str()))
340 .header("Content-Type", "application/json")
341 .json(&serde_json::json!({
342 "pcm_base64": base64_data
343 }))
344 .send()
345 .await
346 .map_err(|_| Status::InternalServerError)?;
347
348 if !response.status().is_success() {
349 return Err(Status::InternalServerError);
350 }
351
352 let embed_response: serde_json::Value = response
353 .json()
354 .await
355 .map_err(|_| Status::InternalServerError)?;
356
357 let embedding_b64 = embed_response["embedding"]
358 .as_str()
359 .ok_or(Status::InternalServerError)?;
360
361 let embedding_bytes = base64::engine::general_purpose::STANDARD
363 .decode(embedding_b64)
364 .map_err(|_| Status::InternalServerError)?;
365
366 println!(
367 "DEBUG: Embedding binary length: {} bytes",
368 embedding_bytes.len()
369 );
370
371 println!(
373 "DEBUG: Storing embedding in database for user {}",
374 firebase_uid
375 );
376 sqlx::query("UPDATE users SET voice_embeddings = $1 WHERE firebase_uid = $2")
377 .bind(&embedding_bytes)
378 .bind(&firebase_uid)
379 .execute(&**db_pool)
380 .await
381 .map_err(|_| Status::InternalServerError)?;
382
383 Ok(())
384}
385
386#[post("/delete_voice")]
388pub async fn delete_voice(token: Token, db_pool: &State<PgPool>) -> Result<(), Status> {
389 let user_row: Option<(String,)> =
391 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
392 .bind(&token.0)
393 .fetch_optional(&**db_pool)
394 .await
395 .map_err(|_| Status::InternalServerError)?;
396 let firebase_uid = match user_row {
397 Some((uid,)) => uid,
398 None => {
399 return Err(Status::Unauthorized);
400 }
401 };
402
403 sqlx::query("UPDATE users SET voice_embeddings = NULL WHERE firebase_uid = $1")
405 .bind(&firebase_uid)
406 .execute(&**db_pool)
407 .await
408 .map_err(|_| Status::InternalServerError)?;
409
410 Ok(())
411}
412
413#[get("/voice_status")]
415pub async fn voice_status(token: Token, db_pool: &State<PgPool>) -> Result<String, Status> {
416 let user_row: Option<(String,)> =
418 sqlx::query_as("SELECT firebase_uid FROM users WHERE current_token = $1")
419 .bind(&token.0)
420 .fetch_optional(&**db_pool)
421 .await
422 .map_err(|_| Status::InternalServerError)?;
423 let firebase_uid = match user_row {
424 Some((uid,)) => uid,
425 None => {
426 return Err(Status::Unauthorized);
427 }
428 };
429
430 let voice_row: Option<(Option<Vec<u8>>,)> =
432 sqlx::query_as("SELECT voice_embeddings FROM users WHERE firebase_uid = $1")
433 .bind(&firebase_uid)
434 .fetch_optional(&**db_pool)
435 .await
436 .map_err(|_| Status::InternalServerError)?;
437
438 let has_voice = match voice_row {
439 Some((Some(_),)) => true,
440 _ => {
441 println!(
442 "DEBUG: User {} does not have voice embeddings",
443 firebase_uid
444 );
445 false
446 }
447 };
448
449 Ok(has_voice.to_string())
450}