diff --git a/src/api/files.rs b/src/api/files.rs index 7c67e81..6ed8338 100644 --- a/src/api/files.rs +++ b/src/api/files.rs @@ -1,5 +1,6 @@ use crate::connection::MDRSConnection; pub use crate::models::file::File; +use anyhow::bail; use unicode_normalization::UnicodeNormalization; #[derive(serde::Deserialize)] @@ -14,13 +15,14 @@ impl MDRSConnection { let mut all_files = Vec::new(); let mut page: u32 = 1; loop { - let resp = self - .client - .get(self.build_url("v3/files/")) - .headers(self.prepare_headers()) - .query(&[("folder_id", folder_id), ("page", &page.to_string())]) - .send() - .await?; + let params = [ + ("folder_id", folder_id.to_string()), + ("page", page.to_string()), + ]; + let resp = self.get_with_query("v3/files/", ¶ms).await?; + if !resp.status().is_success() { + anyhow::bail!("List files failed: {}", resp.status()); + } let list: FileListResponse = resp.json().await?; let has_next = list.next.is_some(); all_files.extend(list.results); @@ -56,12 +58,10 @@ impl MDRSConnection { /// Download a file from `url` and write it to `dest`. pub async fn download_file(&self, url: &str, dest: &str) -> Result<(), anyhow::Error> { - let resp = self - .client - .get(url) - .headers(self.prepare_headers()) - .send() - .await?; + let resp = self.get_url(url).await?; + if !resp.status().is_success() { + bail!("Download failed: {}", resp.status()); + } let bytes = resp.bytes().await?; tokio::fs::write(dest, &bytes).await?; Ok(()) diff --git a/src/api/folders.rs b/src/api/folders.rs index 5d23214..de7182d 100644 --- a/src/api/folders.rs +++ b/src/api/folders.rs @@ -8,24 +8,25 @@ impl MDRSConnection { &self, lab_id: u32, path: &str, - ) -> Result, reqwest::Error> { - let resp = self - .client - .get(self.build_url("v3/folders/")) - .headers(self.prepare_headers()) - .query(&[ - ("laboratory_id", lab_id.to_string()), - ("path", path.to_string()), - ]) - .send() - .await?; - resp.json::>().await + ) -> Result, anyhow::Error> { + let params = [ + ("laboratory_id", lab_id.to_string()), + ("path", path.to_string()), + ]; + let resp = self.get_with_query("v3/folders/", ¶ms).await?; + if !resp.status().is_success() { + bail!("List folders failed: {}", resp.status()); + } + Ok(resp.json::>().await?) } /// Retrieve full folder details including sub_folders (GET v3/folders/{id}/) - pub async fn retrieve_folder(&self, id: &str) -> Result { + pub async fn retrieve_folder(&self, id: &str) -> Result { let resp = self.get(&format!("v3/folders/{}/", id)).await?; - resp.json::().await + if !resp.status().is_success() { + bail!("Retrieve folder failed: {}", resp.status()); + } + Ok(resp.json::().await?) } /// Create a new folder under `parent_id` (POST v3/folders/). @@ -33,30 +34,24 @@ impl MDRSConnection { &self, parent_id: &str, folder_name: &str, - ) -> reqwest::Result { + ) -> Result { let body = serde_json::json!({ "name": folder_name, "parent_id": parent_id, "description": "", "template_id": -1, }); - self.client - .post(self.build_url("v3/folders/")) - .headers(self.prepare_headers()) - .json(&body) - .send() - .await + self.post_json("v3/folders/", &body).await } /// Authenticate against a password-locked folder (POST v3/folders/{id}/auth/). /// Returns `Err` if the password is incorrect or the request fails. pub async fn folder_auth(&self, folder_id: &str, password: &str) -> Result<(), anyhow::Error> { let resp = self - .client - .post(self.build_url(&format!("v3/folders/{}/auth/", folder_id))) - .headers(self.prepare_headers()) - .json(&serde_json::json!({"password": password})) - .send() + .post_json( + &format!("v3/folders/{}/auth/", folder_id), + &serde_json::json!({"password": password}), + ) .await?; if resp.status() == reqwest::StatusCode::UNAUTHORIZED { bail!("Password is incorrect."); diff --git a/src/api/laboratories.rs b/src/api/laboratories.rs index a0de2dc..ed3dc47 100644 --- a/src/api/laboratories.rs +++ b/src/api/laboratories.rs @@ -9,8 +9,11 @@ struct LabListResponse { } impl MDRSConnection { - pub async fn list_laboratories(&self) -> Result { + pub async fn list_laboratories(&self) -> Result { let resp = self.get("v3/laboratories/").await?; + if !resp.status().is_success() { + anyhow::bail!("List laboratories failed: {}", resp.status()); + } // The API may return a paginated object or a direct array let text = resp.text().await?; let items: Vec = diff --git a/src/api/users.rs b/src/api/users.rs index d083647..42afb2f 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -24,8 +24,11 @@ struct TokenRefreshResponse { impl MDRSConnection { /// Fetch current user and return the slim 4-field model matching the Python cache format. - pub async fn get_current_user(&self) -> Result { + pub async fn get_current_user(&self) -> Result { let resp = self.get("v3/users/current/").await?; + if !resp.status().is_success() { + bail!("Get current user failed: {}", resp.status()); + } let obj = resp.json::().await?; let laboratory_ids = obj.laboratories.into_iter().map(|l| l.id).collect(); Ok(ModelUser { diff --git a/src/cache/mod.rs b/src/cache/mod.rs index 89fc32b..f4d7b32 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -168,5 +168,7 @@ pub fn create_authenticated_conn( ) -> Result { let url = crate::commands::config::get_remote_url(remote)? .ok_or_else(|| anyhow!("Remote `{}` is not configured.", remote))?; - Ok(MDRSConnection::new(&url).with_token(cache.token.access.clone())) + Ok(MDRSConnection::new(&url) + .with_remote(remote) + .with_token(cache.token.access.clone())) } diff --git a/src/commands/chacl.rs b/src/commands/chacl.rs index 88cdebb..a335a78 100644 --- a/src/commands/chacl.rs +++ b/src/commands/chacl.rs @@ -38,11 +38,10 @@ pub async fn chacl( data.insert("password".to_string(), serde_json::json!(pw)); } let resp = conn - .client - .post(conn.build_url(&format!("v3/folders/{}/acl/", folder.id))) - .headers(conn.prepare_headers()) - .json(&serde_json::Value::Object(data)) - .send() + .post_json( + &format!("v3/folders/{}/acl/", folder.id), + &serde_json::Value::Object(data), + ) .await?; if !resp.status().is_success() { diff --git a/src/commands/cp.rs b/src/commands/cp.rs index 8fa7775..5bbe43e 100644 --- a/src/commands/cp.rs +++ b/src/commands/cp.rs @@ -59,11 +59,7 @@ pub async fn cp(src_path: &str, dest_path: &str, recursive: bool) -> Result<(), } let body = serde_json::json!({"folder": d_parent_folder.id, "name": d_basename}); let resp = conn - .client - .post(conn.build_url(&format!("v3/files/{}/copy/", src_file_id))) - .headers(conn.prepare_headers()) - .json(&body) - .send() + .post_json(&format!("v3/files/{}/copy/", src_file_id), &body) .await?; if !resp.status().is_success() { bail!("Copy failed: {}", resp.status()); @@ -103,11 +99,7 @@ pub async fn cp(src_path: &str, dest_path: &str, recursive: bool) -> Result<(), } let body = serde_json::json!({"parent": d_parent_folder.id, "name": d_basename}); let resp = conn - .client - .post(conn.build_url(&format!("v3/folders/{}/copy/", src_folder_id))) - .headers(conn.prepare_headers()) - .json(&body) - .send() + .post_json(&format!("v3/folders/{}/copy/", src_folder_id), &body) .await?; if !resp.status().is_success() { bail!("Copy failed: {}", resp.status()); diff --git a/src/commands/download.rs b/src/commands/download.rs index e1802f9..6194f4e 100644 --- a/src/commands/download.rs +++ b/src/commands/download.rs @@ -110,20 +110,9 @@ pub async fn download( } let url = make_absolute_url(&conn, &f.download_url); let conn = conn.clone(); - let remote = remote.clone(); futs.push(tokio::spawn(async move { - // Refresh the access token if it has expired or is about to - // expire. conn.with_token() reuses the shared HTTP client - // (connection pool) while supplying a fresh Bearer token. - let task_conn = match load_cache_with_token_refresh(&remote).await { - Ok(c) => conn.with_token(c.token.access), - Err(e) => { - eprintln!("Error: {}", e); - return; - } - }; let dest_str = dest_path.to_string_lossy().to_string(); - match task_conn.download_file(&url, &dest_str).await { + match conn.download_file(&url, &dest_str).await { Ok(_) => println!("{}", dest_path.display()), Err(_) => { eprintln!("Failed: {}", dest_path.display()); diff --git a/src/commands/mv.rs b/src/commands/mv.rs index 75165ad..215b3dd 100644 --- a/src/commands/mv.rs +++ b/src/commands/mv.rs @@ -59,11 +59,7 @@ pub async fn mv(src_path: &str, dest_path: &str) -> Result<(), anyhow::Error> { } let body = serde_json::json!({"folder": d_parent_folder.id, "name": d_basename}); let resp = conn - .client - .post(conn.build_url(&format!("v3/files/{}/move/", src_file_id))) - .headers(conn.prepare_headers()) - .json(&body) - .send() + .post_json(&format!("v3/files/{}/move/", src_file_id), &body) .await?; if !resp.status().is_success() { bail!("Move failed: {}", resp.status()); @@ -100,11 +96,7 @@ pub async fn mv(src_path: &str, dest_path: &str) -> Result<(), anyhow::Error> { } let body = serde_json::json!({"parent": d_parent_folder.id, "name": d_basename}); let resp = conn - .client - .post(conn.build_url(&format!("v3/folders/{}/move/", src_folder_id))) - .headers(conn.prepare_headers()) - .json(&body) - .send() + .post_json(&format!("v3/folders/{}/move/", src_folder_id), &body) .await?; if !resp.status().is_success() { bail!("Move failed: {}", resp.status()); diff --git a/src/commands/rm.rs b/src/commands/rm.rs index d592d65..8c09338 100644 --- a/src/commands/rm.rs +++ b/src/commands/rm.rs @@ -28,12 +28,7 @@ pub async fn rm(remote_path: &str, recursive: bool) -> Result<(), anyhow::Error> // Check if target is a file let files = conn.list_all_files(&parent_folder.id).await?; if let Some(file) = find_file_by_name(&files, target_name) { - let resp = conn - .client - .delete(conn.build_url(&format!("v3/files/{}/", file.id))) - .headers(conn.prepare_headers()) - .send() - .await?; + let resp = conn.delete(&format!("v3/files/{}/", file.id)).await?; if !resp.status().is_success() { bail!("Failed to delete file: {}", resp.status()); } @@ -46,11 +41,10 @@ pub async fn rm(remote_path: &str, recursive: bool) -> Result<(), anyhow::Error> bail!("Cannot remove `{}`: Is a folder.", path); } let resp = conn - .client - .delete(conn.build_url(&format!("v3/folders/{}/", subfolder.id))) - .headers(conn.prepare_headers()) - .query(&[("recursive", "true")]) - .send() + .delete_with_query( + &format!("v3/folders/{}/", subfolder.id), + &[("recursive", "true")], + ) .await?; if !resp.status().is_success() { bail!("Failed to delete folder: {}", resp.status()); diff --git a/src/commands/upload.rs b/src/commands/upload.rs index 2985b90..6365d62 100644 --- a/src/commands/upload.rs +++ b/src/commands/upload.rs @@ -110,19 +110,8 @@ pub async fn upload( let folder_id = remote_id.clone(); let remote_path_prefix = folder_detail.path.clone(); let fname = filename.clone(); - let remote = remote.clone(); futs.push(tokio::spawn(async move { - // Refresh the access token if it has expired or is about to - // expire. conn.with_token() reuses the shared HTTP client - // (connection pool) while supplying a fresh Bearer token. - let task_conn = match load_cache_with_token_refresh(&remote).await { - Ok(c) => conn.with_token(c.token.access), - Err(e) => { - eprintln!("Error: {}", e); - return; - } - }; - match task_conn.upload_file(&folder_id, &file_path_str).await { + match conn.upload_file(&folder_id, &file_path_str).await { Ok(_) => println!("{}{}", remote_path_prefix, fname), Err(e) => eprintln!("Error: {}", e), } diff --git a/src/connection.rs b/src/connection.rs index 904c694..e265c06 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,5 +1,6 @@ -use reqwest::Client; use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT}; +use reqwest::{Client, Response}; +use serde::Serialize; fn build_user_agent() -> String { let info = os_info::get(); @@ -26,8 +27,10 @@ fn build_user_agent() -> String { ) } +#[derive(Clone)] /// HTTP transport layer for MDRS API calls. pub struct MDRSConnection { + pub remote: Option, pub url: String, pub client: Client, pub token: Option, @@ -36,12 +39,18 @@ pub struct MDRSConnection { impl MDRSConnection { pub fn new(url: &str) -> Self { MDRSConnection { + remote: None, url: url.to_string(), client: Client::new(), token: None, } } + pub fn with_remote(mut self, remote: &str) -> Self { + self.remote = Some(remote.to_string()); + self + } + /// Create a new connection that shares the HTTP client (and its connection /// pool) with the receiver but uses a fresh access token. Useful for /// spawning per-task connections without allocating a new connection pool @@ -51,12 +60,23 @@ impl MDRSConnection { /// keeps the shared pool intact. pub fn with_token(&self, access_token: String) -> Self { MDRSConnection { + remote: self.remote.clone(), url: self.url.clone(), client: self.client.clone(), token: Some(access_token), } } + async fn connection_with_fresh_token(&self) -> Result { + match (&self.remote, &self.token) { + (Some(remote), Some(_)) => { + let cache = crate::cache::load_cache_with_token_refresh(remote).await?; + Ok(self.with_token(cache.token.access)) + } + _ => Ok(self.clone()), + } + } + pub fn build_url(&self, path: &str) -> String { format!("{}/{}", self.url.trim_end_matches('/'), path) } @@ -79,24 +99,98 @@ impl MDRSConnection { headers } - pub async fn get(&self, path: &str) -> reqwest::Result { - self.client - .get(self.build_url(path)) - .headers(self.prepare_headers()) + pub async fn get(&self, path: &str) -> Result { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .get(conn.build_url(path)) + .headers(conn.prepare_headers()) .send() - .await + .await?) + } + + pub async fn get_with_query(&self, path: &str, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .get(conn.build_url(path)) + .headers(conn.prepare_headers()) + .query(query) + .send() + .await?) + } + + pub async fn get_url(&self, url: &str) -> Result { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .get(if url.starts_with("http") { + url.to_string() + } else { + conn.build_url(url) + }) + .headers(conn.prepare_headers()) + .send() + .await?) + } + + pub async fn post_json(&self, path: &str, body: &B) -> Result + where + B: Serialize + ?Sized, + { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .post(conn.build_url(path)) + .headers(conn.prepare_headers()) + .json(body) + .send() + .await?) } pub async fn post_multipart( &self, path: &str, form: reqwest::multipart::Form, - ) -> reqwest::Result { - self.client - .post(self.build_url(path)) - .headers(self.prepare_headers()) + ) -> Result { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .post(conn.build_url(path)) + .headers(conn.prepare_headers()) .multipart(form) .send() - .await + .await?) + } + + pub async fn delete(&self, path: &str) -> Result { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .delete(conn.build_url(path)) + .headers(conn.prepare_headers()) + .send() + .await?) + } + + pub async fn delete_with_query( + &self, + path: &str, + query: &Q, + ) -> Result + where + Q: Serialize + ?Sized, + { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .delete(conn.build_url(path)) + .headers(conn.prepare_headers()) + .query(query) + .send() + .await?) } }