330 lines
9.4 KiB
Rust
330 lines
9.4 KiB
Rust
use rocket::Request;
|
|
use rocket::{State, fs::TempFile, form::Form, response::Responder, http::ContentType, response::Response};
|
|
use rocket::serde::json::Json;
|
|
use rocket::tokio::fs;
|
|
use rocket::http::Status;
|
|
use serde::Deserialize;
|
|
use tokio::io::AsyncReadExt;
|
|
use uuid::Uuid;
|
|
use std::io::Cursor;
|
|
use chrono::{Utc, Duration};
|
|
use sqlx::PgPool;
|
|
|
|
use crate::auth::AuthenticatedUser;
|
|
use crate::encryption::{encrypt_data, decrypt_data};
|
|
use crate::models::{File, Share};
|
|
|
|
#[derive(FromForm)]
|
|
pub struct Upload<'f> {
|
|
pub file: TempFile<'f>,
|
|
pub filename: Option<String>,
|
|
}
|
|
|
|
#[post("/upload", data = "<upload>", format = "multipart/form-data")]
|
|
pub async fn upload_file(
|
|
pool: &State<PgPool>,
|
|
user: AuthenticatedUser,
|
|
upload: Form<Upload<'_>>
|
|
) -> Result<Status, Status> {
|
|
let mut buffer = Vec::new();
|
|
let mut file = upload.file.open().await.map_err(|_| Status::BadRequest)?;
|
|
file.read_to_end(&mut buffer).await.map_err(|_| Status::BadRequest)?;
|
|
|
|
let nonce = rand::random::<[u8; 12]>();
|
|
let key = std::env::var("ENCRYPTION_KEY").expect("ENCRYPTION_KEY must be set");
|
|
let encrypted = encrypt_data(&buffer, key.as_bytes(), &nonce);
|
|
|
|
let file_id = Uuid::new_v4().to_string();
|
|
let user_id = user.user_id.ok_or(Status::BadRequest)?;
|
|
let user_dir = format!("./data/{}", user_id);
|
|
let storage_path = format!("{}/{}", &user_dir, file_id);
|
|
|
|
fs::create_dir_all(&user_dir).await.map_err(|_| Status::InternalServerError)?;
|
|
fs::write(&storage_path, &encrypted).await.map_err(|_| Status::InternalServerError)?;
|
|
fs::write(format!("{}.nonce", &storage_path), &nonce).await.map_err(|_| Status::InternalServerError)?;
|
|
|
|
let original_name = upload.filename.as_deref().ok_or(Status::BadRequest)?;
|
|
let size = buffer.len() as i64;
|
|
|
|
sqlx::query!(
|
|
"INSERT INTO files (user_id, original_name, storage_path, uploaded_at, size) VALUES ($1, $2, $3, $4, $5)",
|
|
user_id,
|
|
original_name,
|
|
storage_path,
|
|
Utc::now(),
|
|
size,
|
|
)
|
|
.execute(pool.inner())
|
|
.await
|
|
.map_err(|_| Status::InternalServerError)?;
|
|
|
|
Ok(Status::Created)
|
|
}
|
|
|
|
|
|
#[get("/files")]
|
|
pub async fn list_user_files(pool: &State<PgPool>, user: AuthenticatedUser) -> Json<Vec<File>> {
|
|
let Some(user_id) = user.user_id else {
|
|
return Json(Vec::new());
|
|
};
|
|
|
|
let files = sqlx::query_as!(
|
|
File,
|
|
"SELECT * FROM files WHERE user_id = $1",
|
|
user_id
|
|
)
|
|
.fetch_all(pool.inner())
|
|
.await
|
|
.unwrap_or_default();
|
|
|
|
Json(files)
|
|
}
|
|
|
|
/// An in-memory file to be sent to the client as a download attachment.
pub struct FileDownload {
    /// Bytes sent verbatim as the response body.
    pub data: Vec<u8>,
    /// Name suggested to the client via the `Content-Disposition` header.
    pub filename: String,
}
|
|
|
|
impl<'r> Responder<'r, 'static> for FileDownload {
|
|
fn respond_to(self, _: &'r Request<'_>) -> rocket::response::Result<'static> {
|
|
Response::build()
|
|
.header(ContentType::Binary)
|
|
.raw_header("Content-Disposition", format!("attachment; filename=\"{}\"", self.filename))
|
|
.sized_body(self.data.len(), Cursor::new(self.data))
|
|
.ok()
|
|
}
|
|
}
|
|
|
|
#[get("/files/<id>")]
|
|
pub async fn download_file(pool: &State<PgPool>, user: AuthenticatedUser, id: i32) -> Option<FileDownload> {
|
|
let file = sqlx::query_as!(
|
|
File,
|
|
"SELECT * FROM files WHERE id = $1 AND user_id = $2",
|
|
id,
|
|
user.user_id
|
|
)
|
|
.fetch_optional(pool.inner())
|
|
.await.ok()??;
|
|
|
|
let data = fs::read(&file.storage_path).await.ok()?;
|
|
let nonce = fs::read(format!("{}.nonce", &file.storage_path)).await.ok()?;
|
|
let key = std::env::var("ENCRYPTION_KEY").unwrap().into_bytes();
|
|
let nonce_array: [u8; 12] = nonce.try_into().map_err(|_| "Invalid nonce size").ok()?;
|
|
let decrypted = decrypt_data(&data, &key, &nonce_array);
|
|
|
|
Some(FileDownload {
|
|
data: decrypted,
|
|
filename: file.original_name,
|
|
})
|
|
}
|
|
|
|
#[delete("/files/<id>")]
|
|
pub async fn delete_file(pool: &State<PgPool>, user: AuthenticatedUser, id: i32) -> Status {
|
|
let file = sqlx::query_as!(
|
|
File,
|
|
"SELECT * FROM files WHERE id = $1 AND user_id = $2",
|
|
id,
|
|
user.user_id
|
|
)
|
|
.fetch_optional(pool.inner())
|
|
.await
|
|
.ok()
|
|
.flatten();
|
|
|
|
if let Some(file) = file {
|
|
let _ = fs::remove_file(&file.storage_path).await;
|
|
let _ = fs::remove_file(format!("{}.nonce", &file.storage_path)).await;
|
|
|
|
sqlx::query!("DELETE FROM files WHERE id = $1", file.id)
|
|
.execute(pool.inner())
|
|
.await
|
|
.ok();
|
|
|
|
Status::NoContent
|
|
} else {
|
|
Status::NotFound
|
|
}
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct ShareRequest {
|
|
pub file_id: i32,
|
|
pub expires_in_days: Option<i64>,
|
|
}
|
|
|
|
#[post("/share", data = "<req>")]
|
|
pub async fn share_file(pool: &State<PgPool>, user: AuthenticatedUser, req: Json<ShareRequest>) -> Result<Json<String>, Status> {
|
|
let file = sqlx::query!(
|
|
"SELECT id FROM files WHERE id = $1 AND user_id = $2",
|
|
req.file_id,
|
|
user.user_id
|
|
)
|
|
.fetch_optional(pool.inner())
|
|
.await
|
|
.map_err(|_| Status::InternalServerError)?;
|
|
|
|
if file.is_none() {
|
|
return Err(Status::NotFound);
|
|
}
|
|
|
|
let id = Uuid::new_v4();
|
|
let expires = req.expires_in_days.map(|d| Utc::now() + Duration::days(d));
|
|
|
|
sqlx::query!(
|
|
"INSERT INTO shares (id, file_id, shared_by, created_at, expires_at) VALUES ($1, $2, $3, $4, $5)",
|
|
id,
|
|
req.file_id,
|
|
user.user_id,
|
|
Utc::now(),
|
|
expires,
|
|
)
|
|
.execute(pool.inner())
|
|
.await
|
|
.map_err(|_| Status::InternalServerError)?;
|
|
|
|
Ok(Json(id.to_string()))
|
|
}
|
|
|
|
#[get("/shared/<link>")]
|
|
pub async fn download_shared(pool: &State<PgPool>, link: &str) -> Option<FileDownload> {
|
|
let uuid = Uuid::parse_str(link).ok()?;
|
|
let share_record = sqlx::query!(
|
|
"SELECT id, file_id, shared_by, created_at, expires_at FROM shares WHERE id = $1",
|
|
uuid
|
|
)
|
|
.fetch_optional(pool.inner())
|
|
.await
|
|
.ok()??;
|
|
|
|
let share = Share {
|
|
id: share_record.id,
|
|
file_id: share_record.file_id,
|
|
shared_by: share_record.shared_by?,
|
|
created_at: share_record.created_at,
|
|
expires_at: share_record.expires_at.map(|dt| dt.naive_utc()),
|
|
};
|
|
|
|
if let Some(expiry) = share.expires_at {
|
|
if expiry < Utc::now().naive_utc() {
|
|
return None;
|
|
}
|
|
}
|
|
|
|
let file_id = share.file_id?;
|
|
let file = sqlx::query_as!(
|
|
File,
|
|
"SELECT * FROM files WHERE id = $1",
|
|
file_id
|
|
)
|
|
.fetch_optional(pool.inner())
|
|
.await
|
|
.ok()??;
|
|
|
|
let data = fs::read(&file.storage_path).await.ok()?;
|
|
let nonce = fs::read(format!("{}.nonce", &file.storage_path)).await.ok()?;
|
|
let key = std::env::var("ENCRYPTION_KEY").unwrap().into_bytes();
|
|
let nonce_array: [u8; 12] = nonce.try_into().map_err(|_| "Invalid nonce size").ok()?;
|
|
let decrypted = decrypt_data(&data, &key, &nonce_array);
|
|
|
|
Some(FileDownload {
|
|
data: decrypted,
|
|
filename: file.original_name,
|
|
})
|
|
}
|
|
|
|
#[delete("/shares/<share_id>")]
|
|
pub async fn delete_share(
|
|
pool: &State<PgPool>,
|
|
user: AuthenticatedUser,
|
|
share_id: &str
|
|
) -> Status {
|
|
let Some(user_id) = user.user_id else {
|
|
return Status::Unauthorized;
|
|
};
|
|
|
|
// Parse the UUID
|
|
let uuid = match Uuid::parse_str(share_id) {
|
|
Ok(id) => id,
|
|
Err(_) => return Status::BadRequest,
|
|
};
|
|
|
|
// Check if the share exists and belongs to the user
|
|
let share = sqlx::query!(
|
|
"SELECT id FROM shares WHERE id = $1 AND shared_by = $2",
|
|
uuid,
|
|
user_id
|
|
)
|
|
.fetch_optional(pool.inner())
|
|
.await;
|
|
|
|
match share {
|
|
Ok(Some(_)) => {
|
|
// Delete the share
|
|
match sqlx::query!("DELETE FROM shares WHERE id = $1", uuid)
|
|
.execute(pool.inner())
|
|
.await
|
|
{
|
|
Ok(_) => Status::NoContent,
|
|
Err(_) => Status::InternalServerError,
|
|
}
|
|
}
|
|
Ok(None) => Status::NotFound,
|
|
Err(_) => Status::InternalServerError,
|
|
}
|
|
}
|
|
|
|
#[derive(serde::Serialize)]
|
|
pub struct ShareInfo {
|
|
pub id: String,
|
|
pub file_id: i32,
|
|
pub file_name: String,
|
|
pub created_at: chrono::DateTime<chrono::Utc>,
|
|
pub expires_at: Option<chrono::NaiveDateTime>,
|
|
pub is_expired: bool,
|
|
}
|
|
|
|
#[get("/shares")]
|
|
pub async fn list_user_shares(pool: &State<PgPool>, user: AuthenticatedUser) -> Json<Vec<ShareInfo>> {
|
|
let Some(user_id) = user.user_id else {
|
|
return Json(Vec::new());
|
|
};
|
|
|
|
let shares = sqlx::query!(
|
|
r#"
|
|
SELECT
|
|
s.id,
|
|
s.file_id,
|
|
s.created_at,
|
|
s.expires_at,
|
|
f.original_name as file_name
|
|
FROM shares s
|
|
JOIN files f ON s.file_id = f.id
|
|
WHERE s.shared_by = $1
|
|
ORDER BY s.created_at DESC
|
|
"#,
|
|
user_id
|
|
)
|
|
.fetch_all(pool.inner())
|
|
.await
|
|
.unwrap_or_default();
|
|
|
|
let now = Utc::now().naive_utc();
|
|
let share_infos = shares
|
|
.into_iter()
|
|
.map(|record| {
|
|
let expires_at = record.expires_at.map(|dt| dt.naive_utc());
|
|
let is_expired = expires_at.map_or(false, |exp| exp < now);
|
|
|
|
ShareInfo {
|
|
id: record.id.to_string(),
|
|
file_id: record.file_id.unwrap_or(0),
|
|
file_name: record.file_name,
|
|
created_at: record.created_at,
|
|
expires_at,
|
|
is_expired,
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
Json(share_infos)
|
|
} |