import os from argparse import Namespace from concurrent.futures import ThreadPoolExecutor from typing import Any from pydantic.dataclasses import dataclass from mdrsclient.api import FilesApi, FoldersApi from mdrsclient.commands.base import BaseCommand from mdrsclient.connection import MDRSConnection from mdrsclient.exceptions import IllegalArgumentException, UnexpectedException from mdrsclient.models import File from mdrsclient.settings import CONCURRENT @dataclass(frozen=True) class DownloadFileInfo: file: File path: str @dataclass class DownloadContext: hasError: bool files: list[DownloadFileInfo] class DownloadCommand(BaseCommand): @classmethod def register(cls, parsers: Any) -> None: download_parser = parsers.add_parser("download", help="download the file or folder") download_parser.add_argument( "-r", "--recursive", help="download folders and their contents recursive", action="store_true" ) download_parser.add_argument("-p", "--password", help="password to use when open locked folder") download_parser.add_argument("remote_path", help="remote file path (remote:/lab/path/file)") download_parser.add_argument("local_path", help="local folder path (/foo/bar/)") download_parser.set_defaults(func=cls.func) @classmethod def func(cls, args: Namespace) -> None: remote_path = str(args.remote_path) local_path = str(args.local_path) is_recursive = bool(args.recursive) password = str(args.password) if args.password else None cls.download(remote_path, local_path, is_recursive, password) @classmethod def download(cls, remote_path: str, local_path: str, is_recursive: bool, password: str | None) -> None: (remote, laboratory_name, r_path) = cls._parse_remote_host_with_path(remote_path) r_path = r_path.rstrip("/") r_dirname = os.path.dirname(r_path) r_basename = os.path.basename(r_path) connection = cls._create_connection(remote) l_dirname = os.path.realpath(local_path) if not os.path.isdir(l_dirname): raise IllegalArgumentException(f"Local directory `{local_path}` not found.") laboratory = cls._find_laboratory(connection, laboratory_name) r_parent_folder = cls._find_folder(connection, laboratory, r_dirname, password) file = r_parent_folder.find_file(r_basename) if file is not None: context = DownloadContext(False, []) l_path = os.path.join(l_dirname, r_basename) context.files.append(DownloadFileInfo(file, l_path)) cls.__multiple_download(connection, context) else: folder = r_parent_folder.find_sub_folder(r_basename) if folder is None: raise IllegalArgumentException(f"File or folder `{r_path}` not found.") if not is_recursive: raise IllegalArgumentException(f"Cannot download `{r_path}`: Is a folder.") folder_api = FoldersApi(connection) cls.__multiple_download_pickup_recursive_files(connection, folder_api, folder.id, l_dirname) @classmethod def __multiple_download_pickup_recursive_files( cls, connection: MDRSConnection, folder_api: FoldersApi, folder_id: str, basedir: str ) -> None: context = DownloadContext(False, []) folder = folder_api.retrieve(folder_id) dirname = os.path.join(basedir, folder.name) if not os.path.exists(dirname): os.makedirs(dirname) print(dirname) for file in folder.files: path = os.path.join(dirname, file.name) context.files.append(DownloadFileInfo(file, path)) cls.__multiple_download(connection, context) if context.hasError: raise UnexpectedException("Some files failed to download.") for sub_folder in folder.sub_folders: cls.__multiple_download_pickup_recursive_files(connection, folder_api, sub_folder.id, dirname) @classmethod def __multiple_download(cls, connection: MDRSConnection, context: DownloadContext) -> None: file_api = FilesApi(connection) with ThreadPoolExecutor(max_workers=CONCURRENT) as pool: results = pool.map(lambda x: cls.__multiple_download_worker(file_api, x), context.files) hasError = next(filter(lambda x: x is False, results), None) if hasError is not None: context.hasError = True @classmethod def __multiple_download_worker(cls, file_api: FilesApi, info: DownloadFileInfo) -> bool: try: file_api.download(info.file, info.path) except Exception: return False print(info.path) return True