from typing import IO, Generator, Tuple, Union, overload from pathlib import Path, PosixPath, PurePosixPath import io import os import re import requests import fnmatch from azure.identity import DefaultAzureCredential from azure.storage.blob import ContainerClient, BlobClient import requests.adapters import requests.packages from urllib3.util.retry import Retry __all__ = [ 'download_blob', 'upload_blob', 'download_blob_with_cache', 'open_blob', 'open_blob_with_cache', 'blob_file_exists', 'AzureBlobPath','SmartPath' ] DEFAULT_CREDENTIAL = DefaultAzureCredential() BLOB_CACHE_DIR = './.blobcache' def download_blob(blob: Union[str, BlobClient]) -> bytes: if isinstance(blob, str): blob_client = BlobClient.from_blob_url(blob_client) else: blob_client = blob return blob_client.download_blob().read() def upload_blob(blob: Union[str, BlobClient], data: Union[str, bytes]): if isinstance(blob, str): blob_client = BlobClient.from_blob_url(blob) else: blob_client = blob blob_client.upload_blob(data, overwrite=True) def download_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> bytes: """ Download a blob file from a container and return its content as bytes. If the file is already present in the cache, it is read from there. """ cache_path = Path(cache_dir) / blob_name if cache_path.exists(): return cache_path.read_bytes() data = download_blob(container, blob_name) cache_path.parent.mkdir(parents=True, exist_ok=True) cache_path.write_bytes(data) return data def open_blob(container: Union[str, ContainerClient], blob_name: str) -> io.BytesIO: """ Open a blob file for reading from a container and return its content as a BytesIO object. """ return io.BytesIO(download_blob(container, blob_name)) def open_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> io.BytesIO: """ Open a blob file for reading from a container and return its content as a BytesIO object. If the file is already present in the cache, it is read from there. """ return io.BytesIO(download_blob_with_cache(container, blob_name, cache_dir=cache_dir)) def blob_file_exists(container: Union[str, ContainerClient], blob_name: str) -> bool: """ Check if a blob file exists in a container. """ if isinstance(container, str): container = ContainerClient.from_container_url(container) blob_client = container.get_blob_client(blob_name) return blob_client.exists() def is_blob_url(url: str) -> bool: return re.match(r'https://[^/]+blob.core.windows.net/+', url) is not None def split_blob_url(url: str) -> Tuple[str, str, str]: match = re.match(r'(https://[^/]+blob.core.windows.net/[^/?]+)(/([^\?]*))?(\?.+)?', url) if match: container, _, path, sas = match.groups() return container, path or '', sas or '' raise ValueError(f'Not a valid blob URL: {url}') def join_blob_path(url: str, *others: str) -> str: container, path, sas = split_blob_url(url) return container + '/' + os.path.join(path, *others) + sas class AzureBlobStringWriter(io.StringIO): def __init__(self, blob_client: BlobClient, encoding: str = 'utf-8', **kwargs): self._encoding = encoding self.blob_client = blob_client self.kwargs = kwargs super().__init__() def close(self): self.blob_client.upload_blob(self.getvalue().encode(self._encoding), blob_type='BlockBlob', overwrite=True, **self.kwargs) class AzureBlobBytesWriter(io.BytesIO): def __init__(self, blob_client: BlobClient, **kwargs): super().__init__() self.blob_client = blob_client self.kwargs = kwargs def close(self): self.blob_client.upload_blob(self.getvalue(), blob_type='BlockBlob', overwrite=True, **self.kwargs) def open_azure_blob(blob: Union[str, BlobClient], mode: str = 'r', encoding: str = 'utf-8', newline: str = None, cache_blob: bool = False, **kwargs) -> IO: if isinstance(blob, str): blob_client = BlobClient.from_blob_url(blob) elif isinstance(blob, BlobClient): blob_client = blob else: raise ValueError(f'Must be a blob URL or a BlobClient object: {blob}') if cache_blob: cache_path = Path(BLOB_CACHE_DIR, blob_client.account_name, blob_client.container_name, blob_client.blob_name) if mode == 'r' or mode == 'rb': if cache_blob: if cache_path.exists(): data = cache_path.read_bytes() else: data = blob_client.download_blob(**kwargs).read() cache_path.parent.mkdir(parents=True, exist_ok=True) cache_path.write_bytes(data) else: data = blob_client.download_blob(**kwargs).read() if mode == 'r': return io.StringIO(data.decode(encoding), newline=newline) else: return io.BytesIO(data) elif mode == 'w': return AzureBlobStringWriter(blob_client, **kwargs) elif mode == 'wb': return AzureBlobBytesWriter(blob_client, **kwargs) else: raise ValueError(f'Unsupported mode: {mode}') def smart_open(path_or_url: Union[Path, str], mode: str = 'r', encoding: str = 'utf-8') -> IO: if is_blob_url(str(path_or_url)): return open_azure_blob(str(path_or_url), mode, encoding) return open(path_or_url, mode, encoding) class AzureBlobPath(PurePosixPath): """ Implementation of pathlib.Path like interface for Azure Blob Storage. """ container_client: ContainerClient _parse_path = PurePosixPath._parse_args if hasattr(PurePosixPath, '_parse_args') else PurePosixPath._parse_path def __new__(cls, *args, **kwargs): """Override the old __new__ method. Parts are parsed in __init__""" return object.__new__(cls) def __init__(self, root: Union[str, 'AzureBlobPath', ContainerClient], *others: Union[str, PurePosixPath], pool_maxsize: int = 256, retries: int = 3): if isinstance(root, AzureBlobPath): self.container_client = root.container_client parts = root.parts + others elif isinstance(root, str): url = root container, path, sas = split_blob_url(url) session = self._get_session(pool_maxsize=pool_maxsize, retries=retries) if sas: self.container_client = ContainerClient.from_container_url(container + sas, session=session) else: self.container_client = ContainerClient.from_container_url(container, credential=DEFAULT_CREDENTIAL, session=session) parts = (path, *others) elif isinstance(root, ContainerClient): self.container_client = root parts = others else: raise ValueError(f'Invalid root: {root}') if hasattr(PurePosixPath, '_parse_args'): # For compatibility with Python 3.10 drv, root, parts = PurePosixPath._parse_args(parts) self._drv = drv self._root = root self._parts = parts else: super().__init__(*parts) def _get_session(self, pool_maxsize: int = 1024, retries: int = 3) -> requests.Session: session = requests.Session() retry_strategy = Retry( total=retries, status_forcelist=[429, 500, 502, 503, 504], allowed_methods=["HEAD", "GET", "PUT", "DELETE"], backoff_factor=1, raise_on_status=False, read=retries, connect=retries, redirect=retries, ) adapter = requests.adapters.HTTPAdapter(pool_connections=pool_maxsize, pool_maxsize=pool_maxsize, max_retries=retry_strategy) session.mount('http://', adapter) session.mount('https://', adapter) return session def _from_parsed_parts(self, drv, root, parts): "For compatibility with Python 3.10" return AzureBlobPath(self.container_client, drv, root, *parts) def with_segments(self, *pathsegments): return AzureBlobPath(self.container_client, *pathsegments) @property def path(self) -> str: return '/'.join(self.parts) @property def blob_client(self) -> BlobClient: return self.container_client.get_blob_client(self.path) @property def url(self) -> str: if len(self.parts) == 0: return self.container_client.url return self.container_client.get_blob_client(self.path).url @property def container_name(self) -> str: return self.container_client.container_name @property def account_name(self) -> str: return self.container_client.account_name def __str__(self): return self.url def __repr__(self): return self.url def open(self, mode: str = 'r', encoding: str = 'utf-8', cache_blob: bool = False, **kwargs) -> IO: return open_azure_blob(self.blob_client, mode, encoding, cache_blob=cache_blob, **kwargs) def __truediv__(self, other: Union[str, Path]) -> 'AzureBlobPath': return self.joinpath(other) def mkdir(self, parents: bool = False, exist_ok: bool = False): pass def iterdir(self) -> Generator['AzureBlobPath', None, None]: path = self.path if not path.endswith('/'): path += '/' for item in self.container_client.walk_blobs(self.path): yield AzureBlobPath(self.container_client, item.name) def glob(self, pattern: str) -> Generator['AzureBlobPath', None, None]: special_chars = ".^$+{}[]()|/" for char in special_chars: pattern = pattern.replace(char, "\\" + char) pattern = pattern.replace('**', './/.') pattern = pattern.replace('*', '[^/]*') pattern = pattern.replace('.//.', '.*') pattern = "^" + pattern + "$" reg = re.compile(pattern) for item in self.container_client.list_blobs(self.path): if reg.match(os.path.relpath(item.name, self.path)): yield AzureBlobPath(self.container_client, item.name) def exists(self) -> bool: return self.blob_client.exists() def read_bytes(self, cache_blob: bool = False) -> bytes: with self.open('rb', cache_blob=cache_blob) as f: return f.read() def read_text(self, encoding: str = 'utf-8', cache_blob: bool = False) -> str: with self.open('r', encoding=encoding, cache_blob=cache_blob) as f: return f.read() def write_bytes(self, data: bytes): self.blob_client.upload_blob(data, overwrite=True) def write_text(self, data: str, encoding: str = 'utf-8'): self.blob_client.upload_blob(data.encode(encoding), overwrite=True) def unlink(self): self.blob_client.delete_blob() def new_client(self) -> 'AzureBlobPath': return AzureBlobPath(self.container_client.url, self.path) class SmartPath(Path, AzureBlobPath): """ Supports both local file paths and Azure Blob Storage URLs. """ def __new__(cls, first: Union[Path, str], *others: Union[str, PurePosixPath]) -> Union[Path, AzureBlobPath]: if is_blob_url(str(first)): return AzureBlobPath(str(first), *others) return Path(first, *others)