use rocket::Request; use rocket::{State, fs::TempFile, form::Form, response::Responder, http::ContentType, response::Response}; use rocket::serde::json::Json; use rocket::tokio::fs; use rocket::http::Status; use serde::Deserialize; use tokio::io::AsyncReadExt; use uuid::Uuid; use std::io::Cursor; use chrono::{Utc, Duration}; use sqlx::PgPool; use crate::auth::AuthenticatedUser; use crate::encryption::{encrypt_data, decrypt_data}; use crate::models::{File, Share}; #[derive(FromForm)] pub struct Upload<'f> { pub file: TempFile<'f>, pub filename: Option, } #[post("/upload", data = "", format = "multipart/form-data")] pub async fn upload_file( pool: &State, user: AuthenticatedUser, upload: Form> ) -> Result { let mut buffer = Vec::new(); let mut file = upload.file.open().await.map_err(|_| Status::BadRequest)?; file.read_to_end(&mut buffer).await.map_err(|_| Status::BadRequest)?; let nonce = rand::random::<[u8; 12]>(); let key = std::env::var("ENCRYPTION_KEY").expect("ENCRYPTION_KEY must be set"); let encrypted = encrypt_data(&buffer, key.as_bytes(), &nonce); let file_id = Uuid::new_v4().to_string(); let user_id = user.user_id.ok_or(Status::BadRequest)?; let user_dir = format!("/wdblue/litecloud-store/{}", user_id); let storage_path = format!("{}/{}", &user_dir, file_id); fs::create_dir_all(&user_dir).await.map_err(|_| Status::InternalServerError)?; fs::write(&storage_path, &encrypted).await.map_err(|_| Status::InternalServerError)?; fs::write(format!("{}.nonce", &storage_path), &nonce).await.map_err(|_| Status::InternalServerError)?; let original_name = upload.filename.as_deref().ok_or(Status::BadRequest)?; let size = buffer.len() as i64; sqlx::query!( "INSERT INTO files (user_id, original_name, storage_path, uploaded_at, size) VALUES ($1, $2, $3, $4, $5)", user_id, original_name, storage_path, Utc::now(), size, ) .execute(pool.inner()) .await .map_err(|_| Status::InternalServerError)?; Ok(Status::Created) } #[get("/files")] pub async fn list_user_files(pool: &State, user: AuthenticatedUser) -> Json> { let Some(user_id) = user.user_id else { return Json(Vec::new()); }; let files = sqlx::query_as!( File, "SELECT * FROM files WHERE user_id = $1", user_id ) .fetch_all(pool.inner()) .await .unwrap_or_default(); Json(files) } pub struct FileDownload { pub data: Vec, pub filename: String, } impl<'r> Responder<'r, 'static> for FileDownload { fn respond_to(self, _: &'r Request<'_>) -> rocket::response::Result<'static> { Response::build() .header(ContentType::Binary) .raw_header("Content-Disposition", format!("attachment; filename=\"{}\"", self.filename)) .sized_body(self.data.len(), Cursor::new(self.data)) .ok() } } #[get("/files/")] pub async fn download_file(pool: &State, user: AuthenticatedUser, id: i32) -> Option { let file = sqlx::query_as!( File, "SELECT * FROM files WHERE id = $1 AND user_id = $2", id, user.user_id ) .fetch_optional(pool.inner()) .await.ok()??; let data = fs::read(&file.storage_path).await.ok()?; let nonce = fs::read(format!("{}.nonce", &file.storage_path)).await.ok()?; let key = std::env::var("ENCRYPTION_KEY").unwrap().into_bytes(); let nonce_array: [u8; 12] = nonce.try_into().map_err(|_| "Invalid nonce size").ok()?; let decrypted = decrypt_data(&data, &key, &nonce_array); Some(FileDownload { data: decrypted, filename: file.original_name, }) } #[delete("/files/")] pub async fn delete_file(pool: &State, user: AuthenticatedUser, id: i32) -> Status { let file = sqlx::query_as!( File, "SELECT * FROM files WHERE id = $1 AND user_id = $2", id, user.user_id ) .fetch_optional(pool.inner()) .await .ok() .flatten(); if let Some(file) = file { let _ = fs::remove_file(&file.storage_path).await; let _ = fs::remove_file(format!("{}.nonce", &file.storage_path)).await; sqlx::query!("DELETE FROM files WHERE id = $1", file.id) .execute(pool.inner()) .await .ok(); Status::NoContent } else { Status::NotFound } } #[derive(Deserialize)] pub struct ShareRequest { pub file_id: i32, pub expires_in_days: Option, } #[post("/share", data = "")] pub async fn share_file(pool: &State, user: AuthenticatedUser, req: Json) -> Result, Status> { let file = sqlx::query!( "SELECT id FROM files WHERE id = $1 AND user_id = $2", req.file_id, user.user_id ) .fetch_optional(pool.inner()) .await .map_err(|_| Status::InternalServerError)?; if file.is_none() { return Err(Status::NotFound); } let id = Uuid::new_v4(); let expires = req.expires_in_days.map(|d| Utc::now() + Duration::days(d)); sqlx::query!( "INSERT INTO shares (id, file_id, shared_by, created_at, expires_at) VALUES ($1, $2, $3, $4, $5)", id, req.file_id, user.user_id, Utc::now(), expires, ) .execute(pool.inner()) .await .map_err(|_| Status::InternalServerError)?; Ok(Json(id.to_string())) } #[get("/shared/")] pub async fn download_shared(pool: &State, link: &str) -> Option { let uuid = Uuid::parse_str(link).ok()?; let share_record = sqlx::query!( "SELECT id, file_id, shared_by, created_at, expires_at FROM shares WHERE id = $1", uuid ) .fetch_optional(pool.inner()) .await .ok()??; let share = Share { id: share_record.id, file_id: share_record.file_id, shared_by: share_record.shared_by?, created_at: share_record.created_at, expires_at: share_record.expires_at.map(|dt| dt.naive_utc()), }; if let Some(expiry) = share.expires_at { if expiry < Utc::now().naive_utc() { return None; } } let file_id = share.file_id?; let file = sqlx::query_as!( File, "SELECT * FROM files WHERE id = $1", file_id ) .fetch_optional(pool.inner()) .await .ok()??; let data = fs::read(&file.storage_path).await.ok()?; let nonce = fs::read(format!("{}.nonce", &file.storage_path)).await.ok()?; let key = std::env::var("ENCRYPTION_KEY").unwrap().into_bytes(); let nonce_array: [u8; 12] = nonce.try_into().map_err(|_| "Invalid nonce size").ok()?; let decrypted = decrypt_data(&data, &key, &nonce_array); Some(FileDownload { data: decrypted, filename: file.original_name, }) } #[delete("/shares/")] pub async fn delete_share( pool: &State, user: AuthenticatedUser, share_id: &str ) -> Status { let Some(user_id) = user.user_id else { return Status::Unauthorized; }; // Parse the UUID let uuid = match Uuid::parse_str(share_id) { Ok(id) => id, Err(_) => return Status::BadRequest, }; // Check if the share exists and belongs to the user let share = sqlx::query!( "SELECT id FROM shares WHERE id = $1 AND shared_by = $2", uuid, user_id ) .fetch_optional(pool.inner()) .await; match share { Ok(Some(_)) => { // Delete the share match sqlx::query!("DELETE FROM shares WHERE id = $1", uuid) .execute(pool.inner()) .await { Ok(_) => Status::NoContent, Err(_) => Status::InternalServerError, } } Ok(None) => Status::NotFound, Err(_) => Status::InternalServerError, } } #[derive(serde::Serialize)] pub struct ShareInfo { pub id: String, pub file_id: i32, pub file_name: String, pub created_at: chrono::DateTime, pub expires_at: Option, pub is_expired: bool, } #[get("/shares")] pub async fn list_user_shares(pool: &State, user: AuthenticatedUser) -> Json> { let Some(user_id) = user.user_id else { return Json(Vec::new()); }; let shares = sqlx::query!( r#" SELECT s.id, s.file_id, s.created_at, s.expires_at, f.original_name as file_name FROM shares s JOIN files f ON s.file_id = f.id WHERE s.shared_by = $1 ORDER BY s.created_at DESC "#, user_id ) .fetch_all(pool.inner()) .await .unwrap_or_default(); let now = Utc::now().naive_utc(); let share_infos = shares .into_iter() .map(|record| { let expires_at = record.expires_at.map(|dt| dt.naive_utc()); let is_expired = expires_at.map_or(false, |exp| exp < now); ShareInfo { id: record.id.to_string(), file_id: record.file_id.unwrap_or(0), file_name: record.file_name, created_at: record.created_at, expires_at, is_expired, } }) .collect(); Json(share_infos) }