From a67f9a72a6cddff741d28070407e4dfbb382fe62 Mon Sep 17 00:00:00 2001 From: Yoshihiro OKUMURA Date: Mon, 20 Apr 2026 16:26:47 +0900 Subject: [PATCH] fix(auth): refresh tokens before authenticated requests Move token refresh checks into the shared Rust connection/API path so long-running authenticated operations stop reusing stale access tokens. This covers recursive download and upload traversal, recursive ls via the shared APIs, and direct authenticated commands such as cp, mv, rm, and chacl. Also surface HTTP failures earlier in the affected API methods instead of failing later during response parsing. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/api/files.rs | 26 ++++----- src/api/folders.rs | 47 +++++++--------- src/api/laboratories.rs | 5 +- src/api/users.rs | 5 +- src/cache/mod.rs | 4 +- src/commands/chacl.rs | 9 ++- src/commands/cp.rs | 12 +--- src/commands/download.rs | 13 +---- src/commands/mv.rs | 12 +--- src/commands/rm.rs | 16 ++---- src/commands/upload.rs | 13 +---- src/connection.rs | 116 +++++++++++++++++++++++++++++++++++---- 12 files changed, 165 insertions(+), 113 deletions(-) diff --git a/src/api/files.rs b/src/api/files.rs index 7c67e81..6ed8338 100644 --- a/src/api/files.rs +++ b/src/api/files.rs @@ -1,5 +1,6 @@ use crate::connection::MDRSConnection; pub use crate::models::file::File; +use anyhow::bail; use unicode_normalization::UnicodeNormalization; #[derive(serde::Deserialize)] @@ -14,13 +15,14 @@ impl MDRSConnection { let mut all_files = Vec::new(); let mut page: u32 = 1; loop { - let resp = self - .client - .get(self.build_url("v3/files/")) - .headers(self.prepare_headers()) - .query(&[("folder_id", folder_id), ("page", &page.to_string())]) - .send() - .await?; + let params = [ + ("folder_id", folder_id.to_string()), + ("page", page.to_string()), + ]; + let resp = self.get_with_query("v3/files/", ¶ms).await?; + if !resp.status().is_success() { + anyhow::bail!("List files failed: {}", resp.status()); + } let list: FileListResponse = resp.json().await?; let has_next = list.next.is_some(); all_files.extend(list.results); @@ -56,12 +58,10 @@ impl MDRSConnection { /// Download a file from `url` and write it to `dest`. pub async fn download_file(&self, url: &str, dest: &str) -> Result<(), anyhow::Error> { - let resp = self - .client - .get(url) - .headers(self.prepare_headers()) - .send() - .await?; + let resp = self.get_url(url).await?; + if !resp.status().is_success() { + bail!("Download failed: {}", resp.status()); + } let bytes = resp.bytes().await?; tokio::fs::write(dest, &bytes).await?; Ok(()) diff --git a/src/api/folders.rs b/src/api/folders.rs index 5d23214..de7182d 100644 --- a/src/api/folders.rs +++ b/src/api/folders.rs @@ -8,24 +8,25 @@ impl MDRSConnection { &self, lab_id: u32, path: &str, - ) -> Result, reqwest::Error> { - let resp = self - .client - .get(self.build_url("v3/folders/")) - .headers(self.prepare_headers()) - .query(&[ - ("laboratory_id", lab_id.to_string()), - ("path", path.to_string()), - ]) - .send() - .await?; - resp.json::>().await + ) -> Result, anyhow::Error> { + let params = [ + ("laboratory_id", lab_id.to_string()), + ("path", path.to_string()), + ]; + let resp = self.get_with_query("v3/folders/", ¶ms).await?; + if !resp.status().is_success() { + bail!("List folders failed: {}", resp.status()); + } + Ok(resp.json::>().await?) } /// Retrieve full folder details including sub_folders (GET v3/folders/{id}/) - pub async fn retrieve_folder(&self, id: &str) -> Result { + pub async fn retrieve_folder(&self, id: &str) -> Result { let resp = self.get(&format!("v3/folders/{}/", id)).await?; - resp.json::().await + if !resp.status().is_success() { + bail!("Retrieve folder failed: {}", resp.status()); + } + Ok(resp.json::().await?) } /// Create a new folder under `parent_id` (POST v3/folders/). @@ -33,30 +34,24 @@ impl MDRSConnection { &self, parent_id: &str, folder_name: &str, - ) -> reqwest::Result { + ) -> Result { let body = serde_json::json!({ "name": folder_name, "parent_id": parent_id, "description": "", "template_id": -1, }); - self.client - .post(self.build_url("v3/folders/")) - .headers(self.prepare_headers()) - .json(&body) - .send() - .await + self.post_json("v3/folders/", &body).await } /// Authenticate against a password-locked folder (POST v3/folders/{id}/auth/). /// Returns `Err` if the password is incorrect or the request fails. pub async fn folder_auth(&self, folder_id: &str, password: &str) -> Result<(), anyhow::Error> { let resp = self - .client - .post(self.build_url(&format!("v3/folders/{}/auth/", folder_id))) - .headers(self.prepare_headers()) - .json(&serde_json::json!({"password": password})) - .send() + .post_json( + &format!("v3/folders/{}/auth/", folder_id), + &serde_json::json!({"password": password}), + ) .await?; if resp.status() == reqwest::StatusCode::UNAUTHORIZED { bail!("Password is incorrect."); diff --git a/src/api/laboratories.rs b/src/api/laboratories.rs index a0de2dc..ed3dc47 100644 --- a/src/api/laboratories.rs +++ b/src/api/laboratories.rs @@ -9,8 +9,11 @@ struct LabListResponse { } impl MDRSConnection { - pub async fn list_laboratories(&self) -> Result { + pub async fn list_laboratories(&self) -> Result { let resp = self.get("v3/laboratories/").await?; + if !resp.status().is_success() { + anyhow::bail!("List laboratories failed: {}", resp.status()); + } // The API may return a paginated object or a direct array let text = resp.text().await?; let items: Vec = diff --git a/src/api/users.rs b/src/api/users.rs index d083647..42afb2f 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -24,8 +24,11 @@ struct TokenRefreshResponse { impl MDRSConnection { /// Fetch current user and return the slim 4-field model matching the Python cache format. - pub async fn get_current_user(&self) -> Result { + pub async fn get_current_user(&self) -> Result { let resp = self.get("v3/users/current/").await?; + if !resp.status().is_success() { + bail!("Get current user failed: {}", resp.status()); + } let obj = resp.json::().await?; let laboratory_ids = obj.laboratories.into_iter().map(|l| l.id).collect(); Ok(ModelUser { diff --git a/src/cache/mod.rs b/src/cache/mod.rs index 89fc32b..f4d7b32 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -168,5 +168,7 @@ pub fn create_authenticated_conn( ) -> Result { let url = crate::commands::config::get_remote_url(remote)? .ok_or_else(|| anyhow!("Remote `{}` is not configured.", remote))?; - Ok(MDRSConnection::new(&url).with_token(cache.token.access.clone())) + Ok(MDRSConnection::new(&url) + .with_remote(remote) + .with_token(cache.token.access.clone())) } diff --git a/src/commands/chacl.rs b/src/commands/chacl.rs index 88cdebb..a335a78 100644 --- a/src/commands/chacl.rs +++ b/src/commands/chacl.rs @@ -38,11 +38,10 @@ pub async fn chacl( data.insert("password".to_string(), serde_json::json!(pw)); } let resp = conn - .client - .post(conn.build_url(&format!("v3/folders/{}/acl/", folder.id))) - .headers(conn.prepare_headers()) - .json(&serde_json::Value::Object(data)) - .send() + .post_json( + &format!("v3/folders/{}/acl/", folder.id), + &serde_json::Value::Object(data), + ) .await?; if !resp.status().is_success() { diff --git a/src/commands/cp.rs b/src/commands/cp.rs index 8fa7775..5bbe43e 100644 --- a/src/commands/cp.rs +++ b/src/commands/cp.rs @@ -59,11 +59,7 @@ pub async fn cp(src_path: &str, dest_path: &str, recursive: bool) -> Result<(), } let body = serde_json::json!({"folder": d_parent_folder.id, "name": d_basename}); let resp = conn - .client - .post(conn.build_url(&format!("v3/files/{}/copy/", src_file_id))) - .headers(conn.prepare_headers()) - .json(&body) - .send() + .post_json(&format!("v3/files/{}/copy/", src_file_id), &body) .await?; if !resp.status().is_success() { bail!("Copy failed: {}", resp.status()); @@ -103,11 +99,7 @@ pub async fn cp(src_path: &str, dest_path: &str, recursive: bool) -> Result<(), } let body = serde_json::json!({"parent": d_parent_folder.id, "name": d_basename}); let resp = conn - .client - .post(conn.build_url(&format!("v3/folders/{}/copy/", src_folder_id))) - .headers(conn.prepare_headers()) - .json(&body) - .send() + .post_json(&format!("v3/folders/{}/copy/", src_folder_id), &body) .await?; if !resp.status().is_success() { bail!("Copy failed: {}", resp.status()); diff --git a/src/commands/download.rs b/src/commands/download.rs index e1802f9..6194f4e 100644 --- a/src/commands/download.rs +++ b/src/commands/download.rs @@ -110,20 +110,9 @@ pub async fn download( } let url = make_absolute_url(&conn, &f.download_url); let conn = conn.clone(); - let remote = remote.clone(); futs.push(tokio::spawn(async move { - // Refresh the access token if it has expired or is about to - // expire. conn.with_token() reuses the shared HTTP client - // (connection pool) while supplying a fresh Bearer token. - let task_conn = match load_cache_with_token_refresh(&remote).await { - Ok(c) => conn.with_token(c.token.access), - Err(e) => { - eprintln!("Error: {}", e); - return; - } - }; let dest_str = dest_path.to_string_lossy().to_string(); - match task_conn.download_file(&url, &dest_str).await { + match conn.download_file(&url, &dest_str).await { Ok(_) => println!("{}", dest_path.display()), Err(_) => { eprintln!("Failed: {}", dest_path.display()); diff --git a/src/commands/mv.rs b/src/commands/mv.rs index 75165ad..215b3dd 100644 --- a/src/commands/mv.rs +++ b/src/commands/mv.rs @@ -59,11 +59,7 @@ pub async fn mv(src_path: &str, dest_path: &str) -> Result<(), anyhow::Error> { } let body = serde_json::json!({"folder": d_parent_folder.id, "name": d_basename}); let resp = conn - .client - .post(conn.build_url(&format!("v3/files/{}/move/", src_file_id))) - .headers(conn.prepare_headers()) - .json(&body) - .send() + .post_json(&format!("v3/files/{}/move/", src_file_id), &body) .await?; if !resp.status().is_success() { bail!("Move failed: {}", resp.status()); @@ -100,11 +96,7 @@ pub async fn mv(src_path: &str, dest_path: &str) -> Result<(), anyhow::Error> { } let body = serde_json::json!({"parent": d_parent_folder.id, "name": d_basename}); let resp = conn - .client - .post(conn.build_url(&format!("v3/folders/{}/move/", src_folder_id))) - .headers(conn.prepare_headers()) - .json(&body) - .send() + .post_json(&format!("v3/folders/{}/move/", src_folder_id), &body) .await?; if !resp.status().is_success() { bail!("Move failed: {}", resp.status()); diff --git a/src/commands/rm.rs b/src/commands/rm.rs index d592d65..8c09338 100644 --- a/src/commands/rm.rs +++ b/src/commands/rm.rs @@ -28,12 +28,7 @@ pub async fn rm(remote_path: &str, recursive: bool) -> Result<(), anyhow::Error> // Check if target is a file let files = conn.list_all_files(&parent_folder.id).await?; if let Some(file) = find_file_by_name(&files, target_name) { - let resp = conn - .client - .delete(conn.build_url(&format!("v3/files/{}/", file.id))) - .headers(conn.prepare_headers()) - .send() - .await?; + let resp = conn.delete(&format!("v3/files/{}/", file.id)).await?; if !resp.status().is_success() { bail!("Failed to delete file: {}", resp.status()); } @@ -46,11 +41,10 @@ pub async fn rm(remote_path: &str, recursive: bool) -> Result<(), anyhow::Error> bail!("Cannot remove `{}`: Is a folder.", path); } let resp = conn - .client - .delete(conn.build_url(&format!("v3/folders/{}/", subfolder.id))) - .headers(conn.prepare_headers()) - .query(&[("recursive", "true")]) - .send() + .delete_with_query( + &format!("v3/folders/{}/", subfolder.id), + &[("recursive", "true")], + ) .await?; if !resp.status().is_success() { bail!("Failed to delete folder: {}", resp.status()); diff --git a/src/commands/upload.rs b/src/commands/upload.rs index 2985b90..6365d62 100644 --- a/src/commands/upload.rs +++ b/src/commands/upload.rs @@ -110,19 +110,8 @@ pub async fn upload( let folder_id = remote_id.clone(); let remote_path_prefix = folder_detail.path.clone(); let fname = filename.clone(); - let remote = remote.clone(); futs.push(tokio::spawn(async move { - // Refresh the access token if it has expired or is about to - // expire. conn.with_token() reuses the shared HTTP client - // (connection pool) while supplying a fresh Bearer token. - let task_conn = match load_cache_with_token_refresh(&remote).await { - Ok(c) => conn.with_token(c.token.access), - Err(e) => { - eprintln!("Error: {}", e); - return; - } - }; - match task_conn.upload_file(&folder_id, &file_path_str).await { + match conn.upload_file(&folder_id, &file_path_str).await { Ok(_) => println!("{}{}", remote_path_prefix, fname), Err(e) => eprintln!("Error: {}", e), } diff --git a/src/connection.rs b/src/connection.rs index 904c694..e265c06 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,5 +1,6 @@ -use reqwest::Client; use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT}; +use reqwest::{Client, Response}; +use serde::Serialize; fn build_user_agent() -> String { let info = os_info::get(); @@ -26,8 +27,10 @@ fn build_user_agent() -> String { ) } +#[derive(Clone)] /// HTTP transport layer for MDRS API calls. pub struct MDRSConnection { + pub remote: Option, pub url: String, pub client: Client, pub token: Option, @@ -36,12 +39,18 @@ pub struct MDRSConnection { impl MDRSConnection { pub fn new(url: &str) -> Self { MDRSConnection { + remote: None, url: url.to_string(), client: Client::new(), token: None, } } + pub fn with_remote(mut self, remote: &str) -> Self { + self.remote = Some(remote.to_string()); + self + } + /// Create a new connection that shares the HTTP client (and its connection /// pool) with the receiver but uses a fresh access token. Useful for /// spawning per-task connections without allocating a new connection pool @@ -51,12 +60,23 @@ impl MDRSConnection { /// keeps the shared pool intact. pub fn with_token(&self, access_token: String) -> Self { MDRSConnection { + remote: self.remote.clone(), url: self.url.clone(), client: self.client.clone(), token: Some(access_token), } } + async fn connection_with_fresh_token(&self) -> Result { + match (&self.remote, &self.token) { + (Some(remote), Some(_)) => { + let cache = crate::cache::load_cache_with_token_refresh(remote).await?; + Ok(self.with_token(cache.token.access)) + } + _ => Ok(self.clone()), + } + } + pub fn build_url(&self, path: &str) -> String { format!("{}/{}", self.url.trim_end_matches('/'), path) } @@ -79,24 +99,98 @@ impl MDRSConnection { headers } - pub async fn get(&self, path: &str) -> reqwest::Result { - self.client - .get(self.build_url(path)) - .headers(self.prepare_headers()) + pub async fn get(&self, path: &str) -> Result { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .get(conn.build_url(path)) + .headers(conn.prepare_headers()) .send() - .await + .await?) + } + + pub async fn get_with_query(&self, path: &str, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .get(conn.build_url(path)) + .headers(conn.prepare_headers()) + .query(query) + .send() + .await?) + } + + pub async fn get_url(&self, url: &str) -> Result { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .get(if url.starts_with("http") { + url.to_string() + } else { + conn.build_url(url) + }) + .headers(conn.prepare_headers()) + .send() + .await?) + } + + pub async fn post_json(&self, path: &str, body: &B) -> Result + where + B: Serialize + ?Sized, + { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .post(conn.build_url(path)) + .headers(conn.prepare_headers()) + .json(body) + .send() + .await?) } pub async fn post_multipart( &self, path: &str, form: reqwest::multipart::Form, - ) -> reqwest::Result { - self.client - .post(self.build_url(path)) - .headers(self.prepare_headers()) + ) -> Result { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .post(conn.build_url(path)) + .headers(conn.prepare_headers()) .multipart(form) .send() - .await + .await?) + } + + pub async fn delete(&self, path: &str) -> Result { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .delete(conn.build_url(path)) + .headers(conn.prepare_headers()) + .send() + .await?) + } + + pub async fn delete_with_query( + &self, + path: &str, + query: &Q, + ) -> Result + where + Q: Serialize + ?Sized, + { + let conn = self.connection_with_fresh_token().await?; + Ok(conn + .client + .delete(conn.build_url(path)) + .headers(conn.prepare_headers()) + .query(query) + .send() + .await?) } }