diff --git a/src/api/files.rs b/src/api/files.rs
index 6ed8338..d6692ec 100644
--- a/src/api/files.rs
+++ b/src/api/files.rs
@@ -1,4 +1,4 @@
-use crate::connection::MDRSConnection;
+use crate::connection::{ApiRequestLimiter, MDRSConnection};
 pub use crate::models::file::File;
 use anyhow::bail;
 use unicode_normalization::UnicodeNormalization;
@@ -34,8 +34,43 @@ impl MDRSConnection {
         Ok(all_files)
     }
 
-    /// Upload a local file into the given remote folder.
-    pub async fn upload_file(&self, folder_id: &str, file_path: &str) -> Result<(), anyhow::Error> {
+    /// List all files in a folder while consuming the shared API concurrency budget.
+    pub async fn list_all_files_limited(
+        &self,
+        folder_id: &str,
+        limiter: &ApiRequestLimiter,
+    ) -> Result<Vec<File>, anyhow::Error> {
+        let mut all_files = Vec::new();
+        let mut page: u32 = 1;
+        loop {
+            let params = [
+                ("folder_id", folder_id.to_string()),
+                ("page", page.to_string()),
+            ];
+            let _permit = limiter.acquire().await?;
+            let resp = self.get_with_query("v3/files/", &params).await?;
+            if !resp.status().is_success() {
+                anyhow::bail!("List files failed: {}", resp.status());
+            }
+            let list: FileListResponse = resp.json().await?;
+            let has_next = list.next.is_some();
+            all_files.extend(list.results);
+            if !has_next {
+                break;
+            }
+            page += 1;
+        }
+        Ok(all_files)
+    }
+
+    /// Upload a local file into the given remote folder while consuming the
+    /// shared API concurrency budget.
+    pub async fn upload_file_limited(
+        &self,
+        folder_id: &str,
+        file_path: &str,
+        limiter: &ApiRequestLimiter,
+    ) -> Result<(), anyhow::Error> {
         use anyhow::{anyhow, bail};
         use reqwest::multipart;
         let file_name: String = std::path::Path::new(file_path)
@@ -49,6 +84,7 @@ impl MDRSConnection {
         let form = multipart::Form::new()
             .text("folder_id", folder_id.to_string())
             .part("file", part);
+        let _permit = limiter.acquire().await?;
         let resp = self.post_multipart("v3/files/", form).await?;
         if !resp.status().is_success() {
             bail!("Upload failed: {}", resp.status());
@@ -56,13 +92,20 @@ impl MDRSConnection {
         Ok(())
     }
 
-    /// Download a file from `url` and write it to `dest`.
-    pub async fn download_file(&self, url: &str, dest: &str) -> Result<(), anyhow::Error> {
+    /// Download a file while consuming the shared API concurrency budget.
+    pub async fn download_file_limited(
+        &self,
+        url: &str,
+        dest: &str,
+        limiter: &ApiRequestLimiter,
+    ) -> Result<(), anyhow::Error> {
+        let _permit = limiter.acquire().await?;
         let resp = self.get_url(url).await?;
         if !resp.status().is_success() {
             bail!("Download failed: {}", resp.status());
         }
         let bytes = resp.bytes().await?;
+        drop(_permit);
         tokio::fs::write(dest, &bytes).await?;
         Ok(())
     }
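
Note (annotation, not part of the patch): in `download_file_limited` the permit returned by `limiter.acquire()` is an RAII guard, so the API slot is held across the request and body read, and the explicit `drop(_permit)` releases it before the disk write, which needs no API budget. A minimal self-contained sketch of that permit-scoping pattern — the semaphore size and task bodies here are illustrative only:

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // At most 2 simulated requests in flight at once.
    let semaphore = Arc::new(Semaphore::new(2));
    let mut handles = Vec::new();
    for i in 0..5 {
        let semaphore = semaphore.clone();
        handles.push(tokio::spawn(async move {
            // Owned permit: an RAII guard tied to the shared semaphore.
            let permit = semaphore.acquire_owned().await.unwrap();
            // ... the rate-limited network call would happen here ...
            drop(permit); // release the slot early, like `drop(_permit)` above
            // ... local follow-up work (e.g. writing to disk) holds no slot ...
            println!("task {} finished", i);
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
}
```
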
diff --git a/src/api/folders.rs b/src/api/folders.rs
index de7182d..49a4d26 100644
--- a/src/api/folders.rs
+++ b/src/api/folders.rs
@@ -1,6 +1,6 @@
-use crate::connection::MDRSConnection;
+use crate::connection::{ApiRequestLimiter, MDRSConnection};
 pub use crate::models::folder::{FolderDetail, FolderSimple};
-use anyhow::bail;
+use anyhow::{anyhow, bail};
 
 impl MDRSConnection {
     /// List folders matching the given path under a laboratory (GET v3/folders/?path=...&laboratory_id=...)
@@ -20,6 +20,25 @@ impl MDRSConnection {
         Ok(resp.json::<Vec<FolderSimple>>().await?)
     }
 
+    /// List folders by path while consuming the shared API concurrency budget.
+    pub async fn list_folders_by_path_limited(
+        &self,
+        lab_id: u32,
+        path: &str,
+        limiter: &ApiRequestLimiter,
+    ) -> Result<Vec<FolderSimple>, anyhow::Error> {
+        let params = [
+            ("laboratory_id", lab_id.to_string()),
+            ("path", path.to_string()),
+        ];
+        let _permit = limiter.acquire().await?;
+        let resp = self.get_with_query("v3/folders/", &params).await?;
+        if !resp.status().is_success() {
+            bail!("List folders failed: {}", resp.status());
+        }
+        Ok(resp.json::<Vec<FolderSimple>>().await?)
+    }
+
     /// Retrieve full folder details including sub_folders (GET v3/folders/{id}/)
     pub async fn retrieve_folder(&self, id: &str) -> Result<FolderDetail, anyhow::Error> {
         let resp = self.get(&format!("v3/folders/{}/", id)).await?;
@@ -29,6 +48,20 @@ impl MDRSConnection {
         Ok(resp.json::<FolderDetail>().await?)
     }
 
+    /// Retrieve folder details while consuming the shared API concurrency budget.
+    pub async fn retrieve_folder_limited(
+        &self,
+        id: &str,
+        limiter: &ApiRequestLimiter,
+    ) -> Result<FolderDetail, anyhow::Error> {
+        let _permit = limiter.acquire().await?;
+        let resp = self.get(&format!("v3/folders/{}/", id)).await?;
+        if !resp.status().is_success() {
+            bail!("Retrieve folder failed: {}", resp.status());
+        }
+        Ok(resp.json::<FolderDetail>().await?)
+    }
+
     /// Create a new folder under `parent_id` (POST v3/folders/).
     pub async fn create_folder(
         &self,
@@ -44,6 +77,32 @@ impl MDRSConnection {
         self.post_json("v3/folders/", &body).await
     }
 
+    /// Create a new folder under `parent_id` and return its ID while consuming
+    /// the shared API concurrency budget.
+    pub async fn create_folder_id_limited(
+        &self,
+        parent_id: &str,
+        folder_name: &str,
+        limiter: &ApiRequestLimiter,
+    ) -> Result<String, anyhow::Error> {
+        let body = serde_json::json!({
+            "name": folder_name,
+            "parent_id": parent_id,
+            "description": "",
+            "template_id": -1,
+        });
+        let _permit = limiter.acquire().await?;
+        let resp = self.post_json("v3/folders/", &body).await?;
+        if !resp.status().is_success() {
+            bail!("Failed to create remote folder: {}", folder_name);
+        }
+        let json: serde_json::Value = resp.json().await?;
+        json["id"]
+            .as_str()
+            .ok_or_else(|| anyhow!("No id in create_folder response for {}", folder_name))
+            .map(|s| s.to_string())
+    }
+
     /// Authenticate against a password-locked folder (POST v3/folders/{id}/auth/).
     /// Returns `Err` if the password is incorrect or the request fails.
     pub async fn folder_auth(&self, folder_id: &str, password: &str) -> Result<(), anyhow::Error> {
@@ -61,4 +120,28 @@ impl MDRSConnection {
         }
         Ok(())
     }
+
+    /// Authenticate against a locked folder while consuming the shared API
+    /// concurrency budget.
+    pub async fn folder_auth_limited(
+        &self,
+        folder_id: &str,
+        password: &str,
+        limiter: &ApiRequestLimiter,
+    ) -> Result<(), anyhow::Error> {
+        let _permit = limiter.acquire().await?;
+        let resp = self
+            .post_json(
+                &format!("v3/folders/{}/auth/", folder_id),
+                &serde_json::json!({"password": password}),
+            )
+            .await?;
+        if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
+            bail!("Password is incorrect.");
+        }
+        if !resp.status().is_success() {
+            bail!("Folder auth failed: {}", resp.status());
+        }
+        Ok(())
+    }
 }
diff --git a/src/commands/download.rs b/src/commands/download.rs
index 6194f4e..d7a202d 100644
--- a/src/commands/download.rs
+++ b/src/commands/download.rs
@@ -1,12 +1,13 @@
 use crate::cache::{create_authenticated_conn, load_cache_with_token_refresh};
 use crate::commands::shared::{
-    find_file_by_name, find_folder, find_lab_in_cache, find_subfolder_by_name, parse_remote_path,
+    find_file_by_name, find_folder_limited, find_lab_in_cache, find_subfolder_by_name,
+    parse_remote_path,
 };
-use crate::connection::MDRSConnection;
+use crate::connection::{ApiRequestLimiter, MDRSConnection};
 use anyhow::{anyhow, bail};
-use futures::stream::{FuturesUnordered, StreamExt};
 use std::path::PathBuf;
 use std::sync::Arc;
+use tokio::task::JoinSet;
 
 pub async fn download(
     remote_path: &str,
@@ -19,6 +20,7 @@ pub async fn download(
     let (remote, labname, r_path) = parse_remote_path(remote_path)?;
     let cache = load_cache_with_token_refresh(&remote).await?;
     let conn = Arc::new(create_authenticated_conn(&remote, &cache)?);
+    let limiter = ApiRequestLimiter::new(crate::settings::SETTINGS.concurrent);
     let lab = find_lab_in_cache(&cache, &labname)?;
 
     // Validate that local_path is an existing directory (matching Python's behaviour).
@@ -40,8 +42,11 @@ pub async fn download(
         None => ("/".to_string(), r_path_clean.to_string()),
     };
 
-    let parent_folder = find_folder(&conn, lab.id, &parent_path, password).await?;
-    let files = conn.list_all_files(&parent_folder.id).await?;
+    let parent_folder =
+        find_folder_limited(&conn, &limiter, lab.id, &parent_path, password).await?;
+    let files = conn
+        .list_all_files_limited(&parent_folder.id, &limiter)
+        .await?;
 
     // Case 1: basename matches a file in the parent folder.
     if let Some(file) = find_file_by_name(&files, &basename) {
@@ -61,7 +66,8 @@ pub async fn download(
             }
         }
         let url = make_absolute_url(&conn, &file.download_url);
-        conn.download_file(&url, &dest.to_string_lossy()).await?;
+        conn.download_file_limited(&url, &dest.to_string_lossy(), &limiter)
+            .await?;
         println!("{}", dest.display());
         return Ok(());
     }
@@ -76,73 +82,41 @@ pub async fn download(
         // We create that subdirectory first, then recurse into it.
         let top_local = local_real.join(&sub.name);
 
-        // Iterative DFS: each entry is (remote_folder_id, local_dir)
-        let mut stack: Vec<(String, PathBuf)> = vec![(sub.id.clone(), top_local)];
+        let mut folder_tasks: JoinSet<Result<DownloadFolderTaskResult, anyhow::Error>> =
+            JoinSet::new();
+        let mut download_tasks: JoinSet<Result<(), anyhow::Error>> = JoinSet::new();
+        let mut errors = Vec::new();
+        let excludes = Arc::new(excludes);
+        let lab_name = Arc::new(lab.name.clone());
+        let password = password.map(str::to_string);
 
-        while let Some((folder_id, local_dir)) = stack.pop() {
-            let folder = conn.retrieve_folder(&folder_id).await?;
+        spawn_download_folder_task(
+            &mut folder_tasks,
+            conn.clone(),
+            limiter.clone(),
+            lab_name.clone(),
+            excludes.clone(),
+            sub.id.clone(),
+            top_local,
+            password.clone(),
+            skip_if_exists,
+        );
 
-            if is_excluded(&excludes, &lab.name, &folder.path, None) {
-                continue;
-            }
+        drive_download_tasks(
+            &mut folder_tasks,
+            &mut download_tasks,
+            &mut errors,
+            conn.clone(),
+            limiter,
+            lab_name,
+            excludes,
+            password,
+            skip_if_exists,
+        )
+        .await;
 
-            tokio::fs::create_dir_all(&local_dir).await?;
-            println!("{}", local_dir.display());
-
-            let dir_files = conn.list_all_files(&folder_id).await?;
-
-            // Download files in this folder (up to 10 concurrent).
-            let mut futs: FuturesUnordered<tokio::task::JoinHandle<()>> = FuturesUnordered::new();
-            for f in &dir_files {
-                if is_excluded(&excludes, &lab.name, &folder.path, Some(&f.name)) {
-                    continue;
-                }
-                let dest_path = local_dir.join(&f.name);
-                if skip_if_exists {
-                    if dest_path.exists() {
-                        if let Ok(meta) = std::fs::metadata(&dest_path) {
-                            if meta.len() == f.size {
-                                println!("{}", dest_path.display());
-                                continue;
-                            }
-                        }
-                    }
-                }
-                let url = make_absolute_url(&conn, &f.download_url);
-                let conn = conn.clone();
-                futs.push(tokio::spawn(async move {
-                    let dest_str = dest_path.to_string_lossy().to_string();
-                    match conn.download_file(&url, &dest_str).await {
-                        Ok(_) => println!("{}", dest_path.display()),
-                        Err(_) => {
-                            eprintln!("Failed: {}", dest_path.display());
-                            if dest_path.is_file() {
-                                let _ = std::fs::remove_file(&dest_path);
-                            }
-                        }
-                    }
-                }));
-                if futs.len() >= crate::settings::SETTINGS.concurrent {
-                    let _ = futs.next().await;
-                }
-            }
-            while futs.next().await.is_some() {}
-
-            // Push sub-folders onto the stack for recursive processing.
-            for sf in &folder.sub_folders {
-                if sf.lock {
-                    match password {
-                        Some(pw) => {
-                            if conn.folder_auth(&sf.id, pw).await.is_err() {
-                                continue;
-                            }
-                        }
-                        None => continue,
-                    }
-                }
-                let sub_local = local_dir.join(&sf.name);
-                stack.push((sf.id.clone(), sub_local));
-            }
+        if !errors.is_empty() {
+            bail!(errors.join("\n"));
         }
         return Ok(());
     }
@@ -181,3 +155,268 @@ fn make_absolute_url(conn: &MDRSConnection, url: &str) -> String {
         )
     }
 }
+
+struct DownloadFolderTaskResult {
+    child_folders: Vec<(String, PathBuf)>,
+    download_jobs: Vec<DownloadJob>,
+}
+
+struct DownloadJob {
+    url: String,
+    dest_path: PathBuf,
+}
+
+fn spawn_download_folder_task(
+    folder_tasks: &mut JoinSet<Result<DownloadFolderTaskResult, anyhow::Error>>,
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    lab_name: Arc<String>,
+    excludes: Arc<Vec<String>>,
+    folder_id: String,
+    local_dir: PathBuf,
+    password: Option<String>,
+    skip_if_exists: bool,
+) {
+    folder_tasks.spawn(async move {
+        process_download_folder(
+            conn,
+            limiter,
+            lab_name,
+            excludes,
+            folder_id,
+            local_dir,
+            password,
+            skip_if_exists,
+        )
+        .await
+    });
+}
+
+fn spawn_download_task(
+    download_tasks: &mut JoinSet<Result<(), anyhow::Error>>,
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    job: DownloadJob,
+) {
+    download_tasks.spawn(async move {
+        let dest_str = job.dest_path.to_string_lossy().to_string();
+        match conn
+            .download_file_limited(&job.url, &dest_str, &limiter)
+            .await
+        {
+            Ok(()) => {
+                println!("{}", job.dest_path.display());
+                Ok(())
+            }
+            Err(err) => {
+                if job.dest_path.is_file() {
+                    let _ = std::fs::remove_file(&job.dest_path);
+                }
+                Err(anyhow!(
+                    "Failed to download {}: {}",
+                    job.dest_path.display(),
+                    err
+                ))
+            }
+        }
+    });
+}
+
+async fn process_download_folder(
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    lab_name: Arc<String>,
+    excludes: Arc<Vec<String>>,
+    folder_id: String,
+    local_dir: PathBuf,
+    password: Option<String>,
+    skip_if_exists: bool,
+) -> Result<DownloadFolderTaskResult, anyhow::Error> {
+    let folder = conn.retrieve_folder_limited(&folder_id, &limiter).await?;
+
+    if is_excluded(excludes.as_slice(), lab_name.as_str(), &folder.path, None) {
+        return Ok(DownloadFolderTaskResult {
+            child_folders: Vec::new(),
+            download_jobs: Vec::new(),
+        });
+    }
+
+    tokio::fs::create_dir_all(&local_dir).await?;
+    println!("{}", local_dir.display());
+
+    let dir_files = conn.list_all_files_limited(&folder_id, &limiter).await?;
+    let mut download_jobs = Vec::new();
+    for file in &dir_files {
+        if is_excluded(
+            excludes.as_slice(),
+            lab_name.as_str(),
+            &folder.path,
+            Some(&file.name),
+        ) {
+            continue;
+        }
+        let dest_path = local_dir.join(&file.name);
+        if skip_if_exists && dest_path.exists() {
+            if let Ok(meta) = std::fs::metadata(&dest_path) {
+                if meta.len() == file.size {
+                    println!("{}", dest_path.display());
+                    continue;
+                }
+            }
+        }
+        download_jobs.push(DownloadJob {
+            url: make_absolute_url(&conn, &file.download_url),
+            dest_path,
+        });
+    }
+
+    let mut child_folder_tasks: JoinSet<Result<Option<(String, PathBuf)>, anyhow::Error>> =
+        JoinSet::new();
+    for sub_folder in folder.sub_folders {
+        let conn = conn.clone();
+        let limiter = limiter.clone();
+        let password = password.clone();
+        let sub_local = local_dir.join(&sub_folder.name);
+        child_folder_tasks.spawn(async move {
+            if sub_folder.lock {
+                match password.as_deref() {
+                    Some(pw) => {
+                        if conn
+                            .folder_auth_limited(&sub_folder.id, pw, &limiter)
+                            .await
+                            .is_err()
+                        {
+                            return Ok(None);
+                        }
+                    }
+                    None => return Ok(None),
+                }
+            }
+            Ok(Some((sub_folder.id, sub_local)))
+        });
+    }
+
+    let mut child_folders = Vec::new();
+    while let Some(result) = child_folder_tasks.join_next().await {
+        if let Some(child) = flatten_join_result(result)? {
+            child_folders.push(child);
+        }
+    }
+
+    Ok(DownloadFolderTaskResult {
+        child_folders,
+        download_jobs,
+    })
+}
+
+async fn drive_download_tasks(
+    folder_tasks: &mut JoinSet<Result<DownloadFolderTaskResult, anyhow::Error>>,
+    download_tasks: &mut JoinSet<Result<(), anyhow::Error>>,
+    errors: &mut Vec<String>,
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    lab_name: Arc<String>,
+    excludes: Arc<Vec<String>>,
+    password: Option<String>,
+    skip_if_exists: bool,
+) {
+    loop {
+        match (folder_tasks.is_empty(), download_tasks.is_empty()) {
+            (true, true) => break,
+            (false, true) => {
+                if let Some(result) = folder_tasks.join_next().await {
+                    handle_download_folder_result(
+                        result,
+                        folder_tasks,
+                        download_tasks,
+                        errors,
+                        conn.clone(),
+                        limiter.clone(),
+                        lab_name.clone(),
+                        excludes.clone(),
+                        password.clone(),
+                        skip_if_exists,
+                    );
+                }
+            }
+            (true, false) => {
+                if let Some(result) = download_tasks.join_next().await {
+                    if let Err(err) = flatten_join_result(result) {
+                        errors.push(err.to_string());
+                    }
+                }
+            }
+            (false, false) => {
+                tokio::select! {
+                    result = folder_tasks.join_next() => {
+                        if let Some(result) = result {
+                            handle_download_folder_result(
+                                result,
+                                folder_tasks,
+                                download_tasks,
+                                errors,
+                                conn.clone(),
+                                limiter.clone(),
+                                lab_name.clone(),
+                                excludes.clone(),
+                                password.clone(),
+                                skip_if_exists,
+                            );
+                        }
+                    }
+                    result = download_tasks.join_next() => {
+                        if let Some(result) = result {
+                            if let Err(err) = flatten_join_result(result) {
+                                errors.push(err.to_string());
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+fn handle_download_folder_result(
+    result: Result<Result<DownloadFolderTaskResult, anyhow::Error>, tokio::task::JoinError>,
+    folder_tasks: &mut JoinSet<Result<DownloadFolderTaskResult, anyhow::Error>>,
+    download_tasks: &mut JoinSet<Result<(), anyhow::Error>>,
+    errors: &mut Vec<String>,
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    lab_name: Arc<String>,
+    excludes: Arc<Vec<String>>,
+    password: Option<String>,
+    skip_if_exists: bool,
+) {
+    match flatten_join_result(result) {
+        Ok(task_result) => {
+            for (folder_id, local_dir) in task_result.child_folders {
+                spawn_download_folder_task(
+                    folder_tasks,
+                    conn.clone(),
+                    limiter.clone(),
+                    lab_name.clone(),
+                    excludes.clone(),
+                    folder_id,
+                    local_dir,
+                    password.clone(),
+                    skip_if_exists,
+                );
+            }
+            for job in task_result.download_jobs {
+                spawn_download_task(download_tasks, conn.clone(), limiter.clone(), job);
+            }
+        }
+        Err(err) => errors.push(err.to_string()),
+    }
+}
+
+fn flatten_join_result<T>(
+    result: Result<Result<T, anyhow::Error>, tokio::task::JoinError>,
+) -> Result<T, anyhow::Error> {
+    match result {
+        Ok(inner) => inner,
+        Err(err) => Err(anyhow!("Task join failed: {}", err)),
+    }
+}
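
Note (annotation, not part of the patch): `drive_download_tasks` above is a small two-queue scheduler — folder tasks discover work, leaf tasks perform it, `tokio::select!` drains whichever set finishes first, and the match on `(is_empty, is_empty)` keeps `join_next` from being polled on an empty set. A stripped-down, self-contained sketch of the same shape; the job numbers and printouts are illustrative only:

```rust
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    // "Folder" tasks return newly discovered leaf jobs; leaf tasks just run.
    let mut folders: JoinSet<Vec<u32>> = JoinSet::new();
    let mut leaves: JoinSet<()> = JoinSet::new();

    folders.spawn(async { vec![1, 2, 3] });

    // Drain both sets until neither holds a task, spawning leaves as folder
    // tasks report them. The `if` guards keep empty sets from being polled.
    while !folders.is_empty() || !leaves.is_empty() {
        tokio::select! {
            Some(found) = folders.join_next(), if !folders.is_empty() => {
                for job in found.expect("folder task panicked") {
                    leaves.spawn(async move { println!("leaf job {}", job) });
                }
            }
            Some(done) = leaves.join_next(), if !leaves.is_empty() => {
                done.expect("leaf task panicked");
            }
        }
    }
}
```
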
diff --git a/src/commands/shared.rs b/src/commands/shared.rs
index 5c14b62..3815977 100644
--- a/src/commands/shared.rs
+++ b/src/commands/shared.rs
@@ -1,4 +1,5 @@
 use crate::cache::{Cache, CacheLaboratory};
+use crate::connection::ApiRequestLimiter;
 use crate::connection::MDRSConnection;
 use crate::models::file::File;
 use crate::models::folder::{FolderDetail, FolderSimple};
@@ -95,6 +96,49 @@ pub async fn find_folder(
     Ok(folder)
 }
 
+/// Resolve a folder by path while consuming the shared API concurrency budget.
+pub async fn find_folder_limited(
+    conn: &MDRSConnection,
+    limiter: &ApiRequestLimiter,
+    lab_id: u32,
+    path: &str,
+    password: Option<&str>,
+) -> Result<FolderDetail, anyhow::Error> {
+    let normalized_path = nfc(path);
+    let folders = conn
+        .list_folders_by_path_limited(lab_id, &normalized_path, limiter)
+        .await?;
+    if folders.is_empty() {
+        bail!("Folder `{}` not found.", path);
+    }
+    if folders.len() != 1 {
+        bail!(
+            "Ambiguous path `{}`: {} folders matched.",
+            path,
+            folders.len()
+        );
+    }
+    let folder_simple = &folders[0];
+    if folder_simple.lock {
+        match password {
+            None => {
+                bail!(
+                    "Folder `{}` is locked. Use -p/--password to provide a password.",
+                    path
+                );
+            }
+            Some(pw) => {
+                conn.folder_auth_limited(&folder_simple.id, pw, limiter)
+                    .await?
+            }
+        }
+    }
+    let folder = conn
+        .retrieve_folder_limited(&folder_simple.id, limiter)
+        .await?;
+    Ok(folder)
+}
+
 /// Find a file by name (NFC-normalized, case-insensitive) in a file list.
 pub fn find_file_by_name<'a>(files: &'a [File], name: &str) -> Option<&'a File> {
     let name_lower = nfc(name).to_lowercase();
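
Note (annotation, not part of the patch): `find_folder_limited` keeps the `nfc(path)` normalization from `find_folder` because the same visible path can arrive precomposed (NFC) or decomposed (NFD, common on macOS), and a byte comparison would fail. A self-contained illustration using the `unicode-normalization` crate this repo already imports; the example strings are made up:

```rust
use unicode_normalization::UnicodeNormalization;

fn main() {
    let decomposed = "u\u{0308}ber"; // "u" + combining diaeresis (NFD)
    let precomposed = "\u{00FC}ber"; // single "ü" code point (NFC)

    // The visible text is the same, but the bytes differ...
    assert_ne!(decomposed, precomposed);
    // ...until both sides are normalized to NFC, as `nfc()` does for paths.
    assert_eq!(decomposed.nfc().collect::<String>(), precomposed);
    println!("NFC-normalized forms compare equal");
}
```
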
diff --git a/src/commands/upload.rs b/src/commands/upload.rs
index 6365d62..f7cd486 100644
--- a/src/commands/upload.rs
+++ b/src/commands/upload.rs
@@ -1,13 +1,14 @@
 use crate::cache::{create_authenticated_conn, load_cache_with_token_refresh};
 use crate::commands::shared::{
-    find_file_by_name, find_folder, find_lab_in_cache, nfc, parse_remote_path,
+    find_file_by_name, find_folder_limited, find_lab_in_cache, nfc, parse_remote_path,
 };
+use crate::connection::{ApiRequestLimiter, MDRSConnection};
 use crate::models::folder::FolderSimple;
 use anyhow::{anyhow, bail};
-use futures::stream::{FuturesUnordered, StreamExt};
 use std::path::PathBuf;
 use std::sync::Arc;
 use tokio::fs;
+use tokio::task::JoinSet;
 
 pub async fn upload(
     local_path: &str,
@@ -18,8 +19,9 @@ pub async fn upload(
     let (remote, labname, r_path) = parse_remote_path(remote_path)?;
     let cache = load_cache_with_token_refresh(&remote).await?;
     let conn = Arc::new(create_authenticated_conn(&remote, &cache)?);
+    let limiter = ApiRequestLimiter::new(crate::settings::SETTINGS.concurrent);
     let lab = find_lab_in_cache(&cache, &labname)?;
-    let dest_folder = find_folder(&conn, lab.id, &r_path, None).await?;
+    let dest_folder = find_folder_limited(&conn, &limiter, lab.id, &r_path, None).await?;
 
     // Normalize local_path: resolve to an absolute canonical path so that
     // trailing slashes and "./" prefixes are handled consistently (matching
@@ -30,7 +32,9 @@ pub async fn upload(
 
     if local.is_file() {
         let filename = local.file_name().unwrap().to_string_lossy().to_string();
-        let remote_files = conn.list_all_files(&dest_folder.id).await?;
+        let remote_files = conn
+            .list_all_files_limited(&dest_folder.id, &limiter)
+            .await?;
         if skip_if_exists {
             if let Some(rf) = find_file_by_name(&remote_files, &filename) {
                 if rf.size == std::fs::metadata(local)?.len() {
@@ -39,7 +43,7 @@ pub async fn upload(
                 }
             }
         }
-        conn.upload_file(&dest_folder.id, &local.to_string_lossy())
+        conn.upload_file_limited(&dest_folder.id, &local.to_string_lossy(), &limiter)
             .await?;
         println!("{}{}", dest_folder.path, filename);
     } else if local.is_dir() {
@@ -52,75 +56,43 @@ pub async fn upload(
         let local_basename = local.file_name().unwrap().to_string_lossy().to_string();
         let top_remote_id = find_or_create_folder(
             &conn,
+            &limiter,
             &dest_folder.id,
             &dest_folder.sub_folders,
             &local_basename,
         )
         .await?;
-        let top_folder = conn.retrieve_folder(&top_remote_id).await?;
-        println!("{}", top_folder.path.trim_end_matches('/'));
+        println!(
+            "{}",
+            format!("{}{}", dest_folder.path, local_basename).trim_end_matches('/')
+        );
 
-        // Iterative depth-first walk: each entry is (local_dir, remote_folder_id)
-        let mut stack: Vec<(PathBuf, String)> = vec![(local.to_path_buf(), top_remote_id)];
+        let mut folder_tasks: JoinSet<Result<UploadFolderTaskResult, anyhow::Error>> =
+            JoinSet::new();
+        let mut upload_tasks: JoinSet<Result<(), anyhow::Error>> = JoinSet::new();
+        let mut errors = Vec::new();
 
-        while let Some((local_dir, remote_id)) = stack.pop() {
-            let folder_detail = conn.retrieve_folder(&remote_id).await?;
-            let remote_files = conn.list_all_files(&remote_id).await?;
+        spawn_upload_folder_task(
+            &mut folder_tasks,
+            conn.clone(),
+            limiter.clone(),
+            local.to_path_buf(),
+            top_remote_id,
+            skip_if_exists,
+        );
 
-            let mut entries = fs::read_dir(&local_dir).await?;
-            let mut subdirs: Vec<PathBuf> = Vec::new();
-            let mut files: Vec<PathBuf> = Vec::new();
-            while let Some(entry) = entries.next_entry().await? {
-                let p = entry.path();
-                if p.is_dir() {
-                    subdirs.push(p);
-                } else {
-                    files.push(p);
-                }
-            }
+        drive_upload_tasks(
+            &mut folder_tasks,
+            &mut upload_tasks,
+            &mut errors,
+            conn.clone(),
+            limiter,
+            skip_if_exists,
+        )
+        .await;
 
-            // Ensure each local sub-directory exists on the remote side
-            for subdir in subdirs {
-                let dirname = subdir.file_name().unwrap().to_string_lossy().to_string();
-                let sub_remote_id =
-                    find_or_create_folder(&conn, &remote_id, &folder_detail.sub_folders, &dirname)
-                        .await?;
-                let sub_folder = conn.retrieve_folder(&sub_remote_id).await?;
-                println!("{}", sub_folder.path.trim_end_matches('/'));
-                stack.push((subdir, sub_remote_id));
-            }
-
-            // Upload files in this directory (up to 10 concurrent)
-            let mut futs: FuturesUnordered<tokio::task::JoinHandle<()>> = FuturesUnordered::new();
-            for file_path in files {
-                let filename = file_path.file_name().unwrap().to_string_lossy().to_string();
-                let file_path_str = file_path.to_string_lossy().to_string();
-                if skip_if_exists {
-                    if let Some(rf) = find_file_by_name(&remote_files, &filename) {
-                        if let Ok(meta) = std::fs::metadata(&file_path) {
-                            if rf.size == meta.len() {
-                                let remote_path_prefix = folder_detail.path.clone();
-                                println!("{}{}", remote_path_prefix, filename);
-                                continue;
-                            }
-                        }
-                    }
-                }
-                let conn = conn.clone();
-                let folder_id = remote_id.clone();
-                let remote_path_prefix = folder_detail.path.clone();
-                let fname = filename.clone();
-                futs.push(tokio::spawn(async move {
-                    match conn.upload_file(&folder_id, &file_path_str).await {
-                        Ok(_) => println!("{}{}", remote_path_prefix, fname),
-                        Err(e) => eprintln!("Error: {}", e),
-                    }
-                }));
-                if futs.len() >= crate::settings::SETTINGS.concurrent {
-                    let _ = futs.next().await;
-                }
-            }
-            while futs.next().await.is_some() {}
+        if !errors.is_empty() {
+            bail!(errors.join("\n"));
         }
     } else {
         bail!("File or directory `{}` not found.", local_path);
@@ -131,7 +103,8 @@ pub async fn upload(
 
 /// Find an existing sub-folder by name or create it, returning its ID.
 async fn find_or_create_folder(
-    conn: &crate::connection::MDRSConnection,
+    conn: &MDRSConnection,
+    limiter: &ApiRequestLimiter,
     parent_id: &str,
     existing: &[FolderSimple],
     name: &str,
@@ -142,13 +115,226 @@ async fn find_or_create_folder(
     {
         return Ok(sf.id.clone());
    }
-    let resp = conn.create_folder(parent_id, &nfc(name)).await?;
-    if !resp.status().is_success() {
-        bail!("Failed to create remote folder: {}", name);
-    }
-    let json: serde_json::Value = resp.json().await?;
-    json["id"]
-        .as_str()
-        .ok_or_else(|| anyhow!("No id in create_folder response for {}", name))
-        .map(|s| s.to_string())
+    conn.create_folder_id_limited(parent_id, &nfc(name), limiter)
+        .await
+}
+
+struct UploadFolderTaskResult {
+    child_folders: Vec<(PathBuf, String)>,
+    upload_jobs: Vec<UploadJob>,
+}
+
+struct UploadJob {
+    folder_id: String,
+    file_path: String,
+    remote_path: String,
+}
+
+fn spawn_upload_folder_task(
+    folder_tasks: &mut JoinSet<Result<UploadFolderTaskResult, anyhow::Error>>,
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    local_dir: PathBuf,
+    remote_id: String,
+    skip_if_exists: bool,
+) {
+    folder_tasks.spawn(async move {
+        process_upload_folder(conn, limiter, local_dir, remote_id, skip_if_exists).await
+    });
+}
+
+fn spawn_upload_task(
+    upload_tasks: &mut JoinSet<Result<(), anyhow::Error>>,
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    job: UploadJob,
+) {
+    upload_tasks.spawn(async move {
+        conn.upload_file_limited(&job.folder_id, &job.file_path, &limiter)
+            .await?;
+        println!("{}", job.remote_path);
+        Ok(())
+    });
+}
+
+async fn process_upload_folder(
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    local_dir: PathBuf,
+    remote_id: String,
+    skip_if_exists: bool,
+) -> Result<UploadFolderTaskResult, anyhow::Error> {
+    let (folder_detail, remote_files) = tokio::try_join!(
+        conn.retrieve_folder_limited(&remote_id, &limiter),
+        conn.list_all_files_limited(&remote_id, &limiter),
+    )?;
+
+    let mut entries = fs::read_dir(&local_dir).await?;
+    let mut subdirs: Vec<PathBuf> = Vec::new();
+    let mut files: Vec<PathBuf> = Vec::new();
+    while let Some(entry) = entries.next_entry().await? {
+        let path = entry.path();
+        if path.is_dir() {
+            subdirs.push(path);
+        } else {
+            files.push(path);
+        }
+    }
+
+    let mut subdir_tasks: JoinSet<Result<(PathBuf, String, String), anyhow::Error>> =
+        JoinSet::new();
+    let existing_subfolders = Arc::new(folder_detail.sub_folders.clone());
+    let folder_path_prefix = folder_detail.path.clone();
+    for subdir in subdirs {
+        let conn = conn.clone();
+        let limiter = limiter.clone();
+        let remote_id = remote_id.clone();
+        let existing_subfolders = existing_subfolders.clone();
+        let folder_path_prefix = folder_path_prefix.clone();
+        subdir_tasks.spawn(async move {
+            let dirname = subdir.file_name().unwrap().to_string_lossy().to_string();
+            let sub_remote_id = find_or_create_folder(
+                &conn,
+                &limiter,
+                &remote_id,
+                existing_subfolders.as_slice(),
+                &dirname,
+            )
+            .await?;
+            Ok((
+                subdir,
+                sub_remote_id,
+                format!("{}{}", folder_path_prefix, dirname),
+            ))
+        });
+    }
+
+    let mut child_folders = Vec::new();
+    while let Some(result) = subdir_tasks.join_next().await {
+        let (subdir, sub_remote_id, remote_path) = flatten_join_result(result)?;
+        println!("{}", remote_path.trim_end_matches('/'));
+        child_folders.push((subdir, sub_remote_id));
+    }
+
+    let mut upload_jobs = Vec::new();
+    for file_path in files {
+        let filename = file_path.file_name().unwrap().to_string_lossy().to_string();
+        if skip_if_exists {
+            if let Some(rf) = find_file_by_name(&remote_files, &filename) {
+                if let Ok(meta) = std::fs::metadata(&file_path) {
+                    if rf.size == meta.len() {
+                        println!("{}{}", folder_detail.path, filename);
+                        continue;
+                    }
+                }
+            }
+        }
+        upload_jobs.push(UploadJob {
+            folder_id: remote_id.clone(),
+            file_path: file_path.to_string_lossy().to_string(),
+            remote_path: format!("{}{}", folder_detail.path, filename),
+        });
+    }
+
+    Ok(UploadFolderTaskResult {
+        child_folders,
+        upload_jobs,
+    })
+}
+
+async fn drive_upload_tasks(
+    folder_tasks: &mut JoinSet<Result<UploadFolderTaskResult, anyhow::Error>>,
+    upload_tasks: &mut JoinSet<Result<(), anyhow::Error>>,
+    errors: &mut Vec<String>,
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    skip_if_exists: bool,
+) {
+    loop {
+        match (folder_tasks.is_empty(), upload_tasks.is_empty()) {
+            (true, true) => break,
+            (false, true) => {
+                if let Some(result) = folder_tasks.join_next().await {
+                    handle_upload_folder_result(
+                        result,
+                        folder_tasks,
+                        upload_tasks,
+                        errors,
+                        conn.clone(),
+                        limiter.clone(),
+                        skip_if_exists,
+                    );
+                }
+            }
+            (true, false) => {
+                if let Some(result) = upload_tasks.join_next().await {
+                    if let Err(err) = flatten_join_result(result) {
+                        errors.push(err.to_string());
+                    }
+                }
+            }
+            (false, false) => {
+                tokio::select! {
+                    result = folder_tasks.join_next() => {
+                        if let Some(result) = result {
+                            handle_upload_folder_result(
+                                result,
+                                folder_tasks,
+                                upload_tasks,
+                                errors,
+                                conn.clone(),
+                                limiter.clone(),
+                                skip_if_exists,
+                            );
+                        }
+                    }
+                    result = upload_tasks.join_next() => {
+                        if let Some(result) = result {
+                            if let Err(err) = flatten_join_result(result) {
+                                errors.push(err.to_string());
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+fn handle_upload_folder_result(
+    result: Result<Result<UploadFolderTaskResult, anyhow::Error>, tokio::task::JoinError>,
+    folder_tasks: &mut JoinSet<Result<UploadFolderTaskResult, anyhow::Error>>,
+    upload_tasks: &mut JoinSet<Result<(), anyhow::Error>>,
+    errors: &mut Vec<String>,
+    conn: Arc<MDRSConnection>,
+    limiter: ApiRequestLimiter,
+    skip_if_exists: bool,
+) {
+    match flatten_join_result(result) {
+        Ok(task_result) => {
+            for (local_dir, remote_id) in task_result.child_folders {
+                spawn_upload_folder_task(
+                    folder_tasks,
+                    conn.clone(),
+                    limiter.clone(),
+                    local_dir,
+                    remote_id,
+                    skip_if_exists,
+                );
+            }
+            for job in task_result.upload_jobs {
+                spawn_upload_task(upload_tasks, conn.clone(), limiter.clone(), job);
+            }
+        }
+        Err(err) => errors.push(err.to_string()),
+    }
+}
+
+fn flatten_join_result<T>(
+    result: Result<Result<T, anyhow::Error>, tokio::task::JoinError>,
+) -> Result<T, anyhow::Error> {
+    match result {
+        Ok(inner) => inner,
+        Err(err) => Err(anyhow!("Task join failed: {}", err)),
+    }
 }
diff --git a/src/connection.rs b/src/connection.rs
index e265c06..3f1726a 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -1,6 +1,8 @@
 use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
 use reqwest::{Client, Response};
 use serde::Serialize;
+use std::sync::Arc;
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 
 fn build_user_agent() -> String {
     let info = os_info::get();
@@ -36,6 +38,27 @@ pub struct MDRSConnection {
     pub token: Option<String>,
 }
 
+#[derive(Clone)]
+pub struct ApiRequestLimiter {
+    semaphore: Arc<Semaphore>,
+}
+
+impl ApiRequestLimiter {
+    pub fn new(limit: usize) -> Self {
+        Self {
+            semaphore: Arc::new(Semaphore::new(limit.max(1))),
+        }
+    }
+
+    pub async fn acquire(&self) -> Result<OwnedSemaphorePermit, anyhow::Error> {
+        self.semaphore
+            .clone()
+            .acquire_owned()
+            .await
+            .map_err(|_| anyhow::anyhow!("API request limiter was closed."))
+    }
+}
+
 impl MDRSConnection {
     pub fn new(url: &str) -> Self {
         MDRSConnection {
diff --git a/src/settings.rs b/src/settings.rs
index 5690c7b..652aeaa 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -4,7 +4,7 @@ pub struct Settings {
     /// Base directory for config and cache files.
     /// Controlled by `MDRS_CLIENT_CONFIG_DIRNAME` env var (default: `~/.mdrs-client`).
     pub config_dirname: std::path::PathBuf,
-    /// Maximum number of concurrent upload/download workers.
+    /// Maximum number of concurrent MDRS API requests used by upload/download.
     /// Controlled by `MDRS_CLIENT_CONCURRENT` env var (default: 10).
     pub concurrent: usize,
 }
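
Note (annotation, not part of the patch): `ApiRequestLimiter` is a thin `Clone`-able handle over one shared `tokio::sync::Semaphore`, so every clone draws from the same budget, and `limit.max(1)` keeps a misconfigured limit of 0 from deadlocking. A self-contained sketch of the capping behaviour — the type is re-declared inline so the snippet compiles on its own, and the assertions are illustrative, not a test from this patch:

```rust
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

#[derive(Clone)]
pub struct ApiRequestLimiter {
    semaphore: Arc<Semaphore>,
}

impl ApiRequestLimiter {
    pub fn new(limit: usize) -> Self {
        // `max(1)` guarantees at least one permit even for a limit of 0.
        Self { semaphore: Arc::new(Semaphore::new(limit.max(1))) }
    }

    pub async fn acquire(&self) -> Result<OwnedSemaphorePermit, anyhow::Error> {
        self.semaphore
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| anyhow::anyhow!("API request limiter was closed."))
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let limiter = ApiRequestLimiter::new(2);
    let clone = limiter.clone(); // clones share the same semaphore

    let first = limiter.acquire().await?;
    let _second = clone.acquire().await?;

    // Both permits are taken, so a third acquire would have to wait.
    assert!(limiter.semaphore.try_acquire().is_err());

    drop(first); // dropping a permit returns its slot to the shared budget
    assert!(limiter.semaphore.try_acquire().is_ok());
    Ok(())
}
```
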