# -*- coding: utf-8 -*-
from __future__ import annotations
import itertools
import logging
import mimetypes
import os.path
import re
from concurrent.futures import Future, as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Pattern, Tuple, Union, cast
import botocore.exceptions
from boto3 import Session
from botocore import UNSIGNED
from botocore.client import BaseClient, Config
from fsspec import AbstractFileSystem
from fsspec.callbacks import _DEFAULT_CALLBACK
from fsspec.spec import AbstractBufferedFile
from fsspec.utils import tokenize
import pyathena
from pyathena.filesystem.s3_object import (
S3CompleteMultipartUpload,
S3MultipartUpload,
S3MultipartUploadPart,
S3Object,
S3ObjectType,
S3PutObject,
S3StorageClass,
)
from pyathena.util import RetryConfig, retry_api_call
if TYPE_CHECKING:
from pyathena.connection import Connection
_logger = logging.getLogger(__name__) # type: ignore
class S3FileSystem(AbstractFileSystem):
"""A filesystem interface for Amazon S3 that implements the fsspec protocol.
This class provides a filesystem-like interface to Amazon S3, allowing you to
use familiar file operations (ls, open, cp, rm, etc.) with S3 objects. It is
designed to be compatible with s3fs while offering PyAthena-specific optimizations.
The filesystem supports standard S3 operations including:
- Listing objects and directories
- Reading and writing files
- Copying and moving objects
- Creating and removing directories
- Multipart uploads for large files
- Various S3 storage classes and encryption options
Attributes:
session: The boto3 session used for S3 operations.
client: The S3 client for direct API calls.
config: Boto3 configuration for the client.
retry_config: Configuration for retry behavior on failed operations.
Example:
>>> from pyathena.filesystem.s3 import S3FileSystem
>>> fs = S3FileSystem()
>>>
>>> # List objects in a bucket
>>> files = fs.ls('s3://my-bucket/data/')
>>>
>>> # Read a file
>>> with fs.open('s3://my-bucket/data/file.csv', 'r') as f:
... content = f.read()
>>>
>>> # Write a file
>>> with fs.open('s3://my-bucket/output/result.txt', 'w') as f:
... f.write('Hello, S3!')
>>>
>>> # Copy files
>>> fs.cp('s3://source-bucket/file.txt', 's3://dest-bucket/file.txt')
Note:
This filesystem is used internally by PyAthena for handling query results
stored in S3, but can also be used independently for S3 file operations.
"""
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html
# The minimum size of a part in a multipart upload is 5MiB.
MULTIPART_UPLOAD_MIN_PART_SIZE: int = 5 * 2**20 # 5MiB
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html
# The maximum size of a part in a multipart upload is 5GiB.
MULTIPART_UPLOAD_MAX_PART_SIZE: int = 5 * 2**30 # 5GiB
# https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html
DELETE_OBJECTS_MAX_KEYS: int = 1000
DEFAULT_BLOCK_SIZE: int = 5 * 2**20 # 5MiB
PATTERN_PATH: Pattern[str] = re.compile(
r"(^s3://|^s3a://|^)(?P<bucket>[a-zA-Z0-9.\-_]+)(/(?P<key>[^?]+)|/)?"
r"($|\?version(Id|ID|id|_id)=(?P<version_id>.+)$)"
)
protocol = ("s3", "s3a")
_extra_tokenize_attributes = ("default_block_size",)
def __init__(
self,
connection: Optional["Connection[Any]"] = None,
default_block_size: Optional[int] = None,
default_cache_type: Optional[str] = None,
max_workers: int = (cpu_count() or 1) * 5,
s3_additional_kwargs=None,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
if connection:
self._client = connection.session.client(
"s3",
region_name=connection.region_name,
config=connection.config,
**connection._client_kwargs,
)
self._retry_config = connection.retry_config
else:
self._client = self._get_client_compatible_with_s3fs(**kwargs)
self._retry_config = RetryConfig()
self.default_block_size = (
default_block_size if default_block_size else self.DEFAULT_BLOCK_SIZE
)
self.default_cache_type = default_cache_type if default_cache_type else "bytes"
self.max_workers = max_workers
self.s3_additional_kwargs = s3_additional_kwargs if s3_additional_kwargs else {}
requester_pays = kwargs.pop("requester_pays", False)
self.request_kwargs = {"RequestPayer": "requester"} if requester_pays else {}
def _get_client_compatible_with_s3fs(self, **kwargs) -> BaseClient:
"""
https://github.com/fsspec/s3fs/blob/2023.4.0/s3fs/core.py#L457-L535
"""
from pyathena.connection import Connection
config_kwargs = deepcopy(kwargs.pop("config_kwargs", {}))
user_agent_extra = config_kwargs.pop("user_agent_extra", None)
if user_agent_extra:
if pyathena.user_agent_extra not in user_agent_extra:
config_kwargs.update(
{"user_agent_extra": f"{pyathena.user_agent_extra} {user_agent_extra}"}
)
else:
config_kwargs.update({"user_agent_extra": user_agent_extra})
else:
config_kwargs.update({"user_agent_extra": pyathena.user_agent_extra})
connect_timeout = kwargs.pop("connect_timeout", None)
if connect_timeout:
config_kwargs.update({"connect_timeout": connect_timeout})
read_timeout = kwargs.pop("read_timeout", None)
if read_timeout:
config_kwargs.update({"read_timeout": read_timeout})
client_kwargs = deepcopy(kwargs.pop("client_kwargs", {}))
use_ssl = kwargs.pop("use_ssl", None)
if use_ssl:
client_kwargs.update({"use_ssl": use_ssl})
endpoint_url = kwargs.pop("endpoint_url", None)
if endpoint_url:
client_kwargs.update({"endpoint_url": endpoint_url})
anon = kwargs.pop("anon", False)
if anon:
config_kwargs.update({"signature_version": UNSIGNED})
else:
creds = {
"aws_access_key_id": kwargs.pop("key", kwargs.pop("username", None)),
"aws_secret_access_key": kwargs.pop("secret", kwargs.pop("password", None)),
"aws_session_token": kwargs.pop("token", None),
}
kwargs.update(**creds)
client_kwargs.update(**creds)
config = Config(**config_kwargs)
session = Session(
**{k: v for k, v in kwargs.items() if k in Connection._SESSION_PASSING_ARGS}
)
return session.client(
"s3",
config=config,
**{k: v for k, v in client_kwargs.items() if k in Connection._CLIENT_PASSING_ARGS},
)
@staticmethod
def parse_path(path: str) -> Tuple[str, Optional[str], Optional[str]]:
match = S3FileSystem.PATTERN_PATH.search(path)
if match:
return match.group("bucket"), match.group("key"), match.group("version_id")
raise ValueError(f"Invalid S3 path format {path}.")
def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
if bucket not in self.dircache or refresh:
try:
self._call(
self._client.head_bucket,
Bucket=bucket,
)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ["NoSuchKey", "NoSuchBucket", "404"]:
return None
raise
file = S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_BUCKET,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=None,
version_id=None,
)
self.dircache[bucket] = file
else:
file = self.dircache[bucket]
return file
def _head_object(
self, path: str, version_id: Optional[str] = None, refresh: bool = False
) -> Optional[S3Object]:
bucket, key, path_version_id = self.parse_path(path)
version_id = path_version_id if path_version_id else version_id
if path not in self.dircache or refresh:
try:
request = {
"Bucket": bucket,
"Key": key,
}
if version_id:
request.update({"VersionId": version_id})
response = self._call(
self._client.head_object,
**request,
)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ["NoSuchKey", "NoSuchBucket", "404"]:
return None
raise
file = S3Object(
init=response,
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
bucket=bucket,
key=key,
version_id=version_id,
)
self.dircache[path] = file
else:
file = self.dircache[path]
return file
def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
if "" not in self.dircache or refresh:
response = self._call(
self._client.list_buckets,
)
buckets = [
S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_BUCKET,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=b["Name"],
key=None,
version_id=None,
)
for b in response["Buckets"]
]
self.dircache[""] = buckets
else:
buckets = self.dircache[""]
return buckets
def _ls_dirs(
self,
path: str,
prefix: str = "",
delimiter: str = "/",
next_token: Optional[str] = None,
max_keys: Optional[int] = None,
refresh: bool = False,
) -> List[S3Object]:
bucket, key, version_id = self.parse_path(path)
if key:
prefix = f"{key}/{prefix if prefix else ''}"
# Create a cache key that includes the delimiter
cache_key = (path, delimiter)
if cache_key in self.dircache and not refresh:
return cast(List[S3Object], self.dircache[cache_key])
files: List[S3Object] = []
while True:
request: Dict[Any, Any] = {
"Bucket": bucket,
"Prefix": prefix,
"Delimiter": delimiter,
}
if next_token:
request.update({"ContinuationToken": next_token})
if max_keys:
request.update({"MaxKeys": max_keys})
response = self._call(
self._client.list_objects_v2,
**request,
)
files.extend(
S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=c["Prefix"][:-1].rstrip("/"),
version_id=version_id,
)
for c in response.get("CommonPrefixes", [])
)
files.extend(
S3Object(
init=c,
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
bucket=bucket,
key=c["Key"],
)
for c in response.get("Contents", [])
)
next_token = response.get("NextContinuationToken")
if not next_token:
break
if files:
self.dircache[cache_key] = files
return files
def ls(
self, path: str, detail: bool = False, refresh: bool = False, **kwargs
) -> Union[List[S3Object], List[str]]:
"""List contents of an S3 path.
Lists buckets (when path is root) or objects within a bucket/prefix.
Compatible with fsspec interface for filesystem operations.
Args:
path: S3 path to list (e.g., "s3://bucket" or "s3://bucket/prefix").
detail: If True, return S3Object instances; if False, return paths as strings.
refresh: If True, bypass cache and fetch fresh results from S3.
**kwargs: Additional arguments (ignored for S3).
Returns:
List of S3Object instances (if detail=True) or paths as strings (if detail=False).
Example:
>>> fs = S3FileSystem()
>>> fs.ls("s3://my-bucket") # List objects in bucket
>>> fs.ls("s3://my-bucket/", detail=True) # Get detailed object info
"""
path = self._strip_protocol(path).rstrip("/")
if path in ["", "/"]:
files = self._ls_buckets(refresh)
else:
files = self._ls_dirs(path, refresh=refresh)
if not files and "/" in path:
file = self._head_object(path, refresh=refresh)
if file:
files = [file]
return list(files) if detail else [f.name for f in files]
def info(self, path: str, **kwargs) -> S3Object:
refresh = kwargs.pop("refresh", False)
path = self._strip_protocol(path)
bucket, key, path_version_id = self.parse_path(path)
version_id = path_version_id if path_version_id else kwargs.pop("version_id", None)
if path in ["/", ""]:
return S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_BUCKET,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=None,
version_id=None,
)
if not refresh:
caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path)
if caches is not None:
if isinstance(caches, list):
cache = next((c for c in caches if c.name == path), None)
elif caches.name == path:
cache = caches
else:
cache = None
if cache:
return cache
return S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=key.rstrip("/") if key else None,
version_id=version_id,
)
if key:
object_info = self._head_object(path, refresh=refresh, version_id=version_id)
if object_info:
return object_info
else:
bucket_info = self._head_bucket(path, refresh=refresh)
if bucket_info:
return bucket_info
raise FileNotFoundError(path)
response = self._call(
self._client.list_objects_v2,
Bucket=bucket,
Prefix=f"{key.rstrip('/')}/" if key else "",
Delimiter="/",
MaxKeys=1,
)
if (
response.get("KeyCount", 0) > 0
or response.get("Contents", [])
or response.get("CommonPrefixes", [])
):
return S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=key.rstrip("/") if key else None,
version_id=version_id,
)
raise FileNotFoundError(path)
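# Behavior sketch: if the bucket only contains "data/part-0000.csv",
# info("s3://bucket/data") finds no object named "data" but sees a non-empty
# "data/" prefix via list_objects_v2, so a directory S3Object is returned;
# info("s3://bucket/missing") raises FileNotFoundError.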
def _extract_parent_directories(
self, files: List[S3Object], bucket: str, base_key: Optional[str]
) -> List[S3Object]:
"""Extract parent directory objects from file paths.
When listing files without delimiter, S3 doesn't return directory entries.
This method creates directory objects by analyzing file paths.
Args:
files: List of S3Object instances representing files.
bucket: S3 bucket name.
base_key: Base key path to calculate relative paths from.
Returns:
List of S3Object instances representing directories.
"""
dirs = set()
base_key = base_key.rstrip("/") if base_key else ""
for f in files:
if f.key and f.type == S3ObjectType.S3_OBJECT_TYPE_FILE:
# Extract directory paths from file paths
f_key = f.key
if base_key and f_key.startswith(base_key + "/"):
relative_path = f_key[len(base_key) + 1 :]
elif not base_key:
relative_path = f_key
else:
continue
# Get all parent directories
parts = relative_path.split("/")
for i in range(1, len(parts)):
if base_key:
dir_path = base_key + "/" + "/".join(parts[:i])
else:
dir_path = "/".join(parts[:i])
dirs.add(dir_path)
# Create S3Object instances for directories
directory_objects = []
for dir_path in dirs:
dir_obj = S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=dir_path,
version_id=None,
)
directory_objects.append(dir_obj)
return directory_objects
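# Worked example: with base_key "data" and a file at "data/a/b/file.csv",
# the relative path "a/b/file.csv" yields the parent directories "data/a"
# and "data/a/b", each returned as a zero-length directory S3Object.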
def _find(
self,
path: str,
maxdepth: Optional[int] = None,
withdirs: Optional[bool] = None,
**kwargs,
) -> List[S3Object]:
path = self._strip_protocol(path)
if path in ["", "/"]:
raise ValueError("Cannot traverse all files in S3.")
bucket, key, _ = self.parse_path(path)
prefix = kwargs.pop("prefix", "")
# When maxdepth is specified, use a recursive approach with delimiter
if maxdepth is not None:
result: List[S3Object] = []
# List files and directories at current level
current_items = self._ls_dirs(path, prefix=prefix, delimiter="/")
for item in current_items:
if item.type == S3ObjectType.S3_OBJECT_TYPE_FILE:
# Add files
result.append(item)
elif item.type == S3ObjectType.S3_OBJECT_TYPE_DIRECTORY:
# Add directory if withdirs is True
if withdirs:
result.append(item)
# Recursively explore subdirectory if depth allows
if maxdepth > 0:
sub_path = f"s3://{bucket}/{item.key}"
sub_results = self._find(
sub_path, maxdepth=maxdepth - 1, withdirs=withdirs, **kwargs
)
result.extend(sub_results)
return result
# For unlimited depth, use the original approach (get all files at once)
files = self._ls_dirs(path, prefix=prefix, delimiter="")
if not files and key:
try:
files = [self.info(path)]
except FileNotFoundError:
files = []
# If withdirs is True, we need to derive directories from file paths
if withdirs:
files.extend(self._extract_parent_directories(files, bucket, key))
# Filter directories if withdirs is False (default)
if withdirs is False or withdirs is None:
files = [f for f in files if f.type != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY]
return files
def find(
self,
path: str,
maxdepth: Optional[int] = None,
withdirs: Optional[bool] = None,
detail: bool = False,
**kwargs,
) -> Union[Dict[str, S3Object], List[str]]:
"""Find all files below a given S3 path.
Recursively searches for files under the specified path, with optional
depth limiting and directory inclusion. Uses efficient S3 list operations
with delimiter handling for performance.
Args:
path: S3 path to search under (e.g., "s3://bucket/prefix").
maxdepth: Maximum depth to recurse (None for unlimited).
withdirs: Whether to include directories in results (None = default behavior).
detail: If True, return dict of {path: S3Object}; if False, return list of paths.
**kwargs: Additional arguments.
Returns:
Dictionary mapping paths to S3Objects (if detail=True) or
list of paths (if detail=False).
Example:
>>> fs = S3FileSystem()
>>> fs.find("s3://bucket/data/", maxdepth=2) # Limit depth
>>> fs.find("s3://bucket/", withdirs=True) # Include directories
"""
files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
if detail:
return {f.name: f for f in files}
return [f.name for f in files]
def exists(self, path: str, **kwargs) -> bool:
"""Check if an S3 path exists.
Determines whether a bucket, object, or prefix exists in S3.
Uses caching and efficient head operations to minimize API calls.
Args:
path: S3 path to check (e.g., "s3://bucket" or "s3://bucket/key").
**kwargs: Additional arguments (unused).
Returns:
True if the path exists, False otherwise.
Example:
>>> fs = S3FileSystem()
>>> fs.exists("s3://my-bucket/file.txt")
>>> fs.exists("s3://my-bucket/")
"""
path = self._strip_protocol(path)
if path in ["", "/"]:
# The root always exists.
return True
bucket, key, _ = self.parse_path(path)
if key:
try:
if self._ls_from_cache(path):
return True
info = self.info(path)
return bool(info)
except FileNotFoundError:
return False
elif self.dircache.get(bucket, False):
return True
else:
try:
if self._ls_from_cache(bucket):
return True
except FileNotFoundError:
pass
file = self._head_bucket(bucket)
return bool(file)
def rm_file(self, path: str, **kwargs) -> None:
bucket, key, version_id = self.parse_path(path)
if not key:
return
self._delete_object(bucket=bucket, key=key, version_id=version_id, **kwargs)
self.invalidate_cache(path)
def rm(self, path, recursive=False, maxdepth=None, **kwargs) -> None:
bucket, key, version_id = self.parse_path(path)
if not key:
raise ValueError("Cannot delete the bucket.")
expand_path = self.expand_path(path, recursive=recursive, maxdepth=maxdepth)
self._delete_objects(bucket, expand_path, **kwargs)
for p in expand_path:
self.invalidate_cache(p)
def _delete_object(
self, bucket: str, key: str, version_id: Optional[str] = None, **kwargs
) -> None:
request = {
"Bucket": bucket,
"Key": key,
}
if version_id:
request.update({"VersionId": version_id})
_logger.debug(f"Delete object: s3://{bucket}/{key}?versionId={version_id}")
self._call(
self._client.delete_object,
**request,
)
def _delete_objects(
self, bucket: str, paths: List[str], max_workers: Optional[int] = None, **kwargs
) -> None:
if not paths:
return
max_workers = max_workers if max_workers else self.max_workers
quiet = kwargs.pop("Quiet", True)
delete_objects = []
for p in paths:
bucket, key, version_id = self.parse_path(p)
if key:
object_ = {"Key": key}
if version_id:
object_.update({"VersionId": version_id})
delete_objects.append(object_)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
fs = []
for delete in [
delete_objects[i : i + self.DELETE_OBJECTS_MAX_KEYS]
for i in range(0, len(delete_objects), self.DELETE_OBJECTS_MAX_KEYS)
]:
request = {
"Bucket": bucket,
"Delete": {
"Objects": delete,
"Quiet": quiet,
},
}
fs.append(
executor.submit(self._call, self._client.delete_objects, **request, **kwargs)
)
for f in as_completed(fs):
f.result()
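# Batching sketch: DeleteObjects accepts at most DELETE_OBJECTS_MAX_KEYS
# (1000) keys per request, so e.g. 2,500 paths are split into three requests
# of 1000, 1000 and 500 keys, each submitted to the thread pool.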
def touch(self, path: str, truncate: bool = True, **kwargs) -> Dict[str, Any]:
bucket, key, version_id = self.parse_path(path)
if version_id:
raise ValueError("Cannot touch the file with the version specified.")
if not truncate and self.exists(path):
raise ValueError("Cannot touch the existing file without specifying truncate.")
if not key:
raise ValueError("Cannot touch the bucket.")
object_ = self._put_object(bucket=bucket, key=key, body=None, **kwargs)
self.invalidate_cache(path)
return object_.to_dict()
def cp_file(
self, path1: str, path2: str, recursive=False, maxdepth=None, on_error=None, **kwargs
):
"""Copy an S3 object to another S3 location.
Performs server-side copy of S3 objects, which is more efficient than
downloading and re-uploading. Automatically chooses between simple copy
and multipart copy based on object size.
Args:
path1: Source S3 path (s3://bucket/key).
path2: Destination S3 path (s3://bucket/key).
recursive: Unused parameter for fsspec compatibility.
maxdepth: Unused parameter for fsspec compatibility.
on_error: Unused parameter for fsspec compatibility.
**kwargs: Additional S3 copy parameters (e.g., metadata, storage class).
Raises:
ValueError: If trying to copy to a versioned file or copy buckets.
Note:
Uses multipart copy for objects larger than the maximum part size
to optimize performance for large files. The copy operation is
performed entirely on the S3 service without data transfer.
"""
# TODO: Remove the handling of "onerror", which appears to be a typo of "on_error" in fsspec.
# https://github.com/fsspec/filesystem_spec/commit/346a589fef9308550ffa3d0d510f2db67281bb05
# https://github.com/fsspec/filesystem_spec/blob/2024.10.0/fsspec/spec.py#L1185
# https://github.com/fsspec/filesystem_spec/blob/2024.10.0/fsspec/spec.py#L1077
kwargs.pop("onerror", None)
bucket1, key1, version_id1 = self.parse_path(path1)
bucket2, key2, version_id2 = self.parse_path(path2)
if version_id2:
raise ValueError("Cannot copy to a versioned file.")
if not key1 or not key2:
raise ValueError("Cannot copy buckets.")
info1 = self.info(path1)
size1 = info1.get("size", 0)
if size1 <= self.MULTIPART_UPLOAD_MAX_PART_SIZE:
self._copy_object(
bucket1=bucket1,
key1=key1,
version_id1=version_id1,
bucket2=bucket2,
key2=key2,
**kwargs,
)
else:
self._copy_object_with_multipart_upload(
bucket1=bucket1,
key1=key1,
version_id1=version_id1,
size1=size1,
bucket2=bucket2,
key2=key2,
**kwargs,
)
self.invalidate_cache(path2)
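# Size-based dispatch (sizes illustrative): an object at or below
# MULTIPART_UPLOAD_MAX_PART_SIZE (5 GiB) is copied with a single CopyObject
# call, while e.g. a 12 GiB object is copied as a multipart upload of three
# UploadPartCopy parts covering 5 GiB + 5 GiB + 2 GiB.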
def _copy_object(
self,
bucket1: str,
key1: str,
version_id1: Optional[str],
bucket2: str,
key2: str,
**kwargs,
) -> None:
copy_source = {
"Bucket": bucket1,
"Key": key1,
}
if version_id1:
copy_source.update({"VersionId": version_id1})
request = {
"CopySource": copy_source,
"Bucket": bucket2,
"Key": key2,
}
_logger.debug(
f"Copy object from s3://{bucket1}/{key1}?versionId={version_id1} "
f"to s3://{bucket2}/{key2}."
)
self._call(self._client.copy_object, **request, **kwargs)
def _copy_object_with_multipart_upload(
self,
bucket1: str,
key1: str,
size1: int,
bucket2: str,
key2: str,
max_workers: Optional[int] = None,
block_size: Optional[int] = None,
version_id1: Optional[str] = None,
**kwargs,
) -> None:
max_workers = max_workers if max_workers else self.max_workers
block_size = block_size if block_size else self.MULTIPART_UPLOAD_MAX_PART_SIZE
if (
block_size < self.MULTIPART_UPLOAD_MIN_PART_SIZE
or block_size > self.MULTIPART_UPLOAD_MAX_PART_SIZE
):
raise ValueError("Block size must be greater than 5MiB and less than 5GiB.")
copy_source = {
"Bucket": bucket1,
"Key": key1,
}
if version_id1:
copy_source.update({"VersionId": version_id1})
ranges = S3File._get_ranges(
0,
size1,
max_workers,
block_size,
)
multipart_upload = self._create_multipart_upload(
bucket=bucket2,
key=key2,
**kwargs,
)
parts = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
fs = [
executor.submit(
self._upload_part_copy,
bucket=bucket2,
key=key2,
copy_source=copy_source,
upload_id=cast(str, multipart_upload.upload_id),
part_number=i + 1,
copy_source_ranges=range_,
)
for i, range_ in enumerate(ranges)
]
for f in as_completed(fs):
result = f.result()
parts.append(
{
"ETag": result.etag,
"PartNumber": result.part_number,
}
)
parts.sort(key=lambda x: x["PartNumber"]) # type: ignore
self._complete_multipart_upload(
bucket=bucket2,
key=key2,
upload_id=cast(str, multipart_upload.upload_id),
parts=parts,
)
def cat_file(
self, path: str, start: Optional[int] = None, end: Optional[int] = None, **kwargs
) -> bytes:
bucket, key, version_id = self.parse_path(path)
if start is not None or end is not None:
size = self.info(path).get("size", 0)
if start is None:
range_start = 0
elif start < 0:
range_start = size + start
else:
range_start = start
if end is None:
range_end = size
elif end < 0:
range_end = size + end
else:
range_end = end
ranges = (range_start, range_end)
else:
ranges = None
return self._get_object(
bucket=bucket,
key=cast(str, key),
ranges=ranges,
version_id=version_id,
**kwargs,
)[1]
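# Range semantics sketch: start/end follow Python slice conventions, with
# negative values taken relative to the object size (paths illustrative).
# >>> fs.cat_file("s3://bucket/key", start=0, end=100)   # first 100 bytes
# >>> fs.cat_file("s3://bucket/key", start=-100)         # last 100 bytes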
def put_file(self, lpath: str, rpath: str, callback=_DEFAULT_CALLBACK, **kwargs):
"""Upload a local file to S3.
Uploads a file from the local filesystem to an S3 location. Supports
automatic content type detection based on file extension and provides
progress callback functionality.
Args:
lpath: Local file path to upload.
rpath: S3 destination path (s3://bucket/key).
callback: Progress callback for tracking upload progress.
**kwargs: Additional S3 parameters (e.g., ContentType, StorageClass).
Note:
Directories are not supported for upload. If lpath is a directory,
the method returns without performing any operation. Bucket-only
destinations (without key) are also not supported.
"""
if os.path.isdir(lpath):
# No support for directory uploads.
return
bucket, key, _ = self.parse_path(rpath)
if not key:
# Uploads to a bucket root (no key) are not supported.
return
size = os.path.getsize(lpath)
callback.set_size(size)
if "ContentType" not in kwargs:
content_type, _ = mimetypes.guess_type(lpath)
if content_type is not None:
kwargs["ContentType"] = content_type
with (
self.open(rpath, "wb", s3_additional_kwargs=kwargs) as remote,
open(lpath, "rb") as local,
):
while data := local.read(remote.blocksize):
remote.write(data)
callback.relative_update(len(data))
self.invalidate_cache(rpath)
def get_file(self, rpath: str, lpath: str, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs):
"""Download an S3 file to local filesystem.
Downloads a file from S3 to the local filesystem with progress tracking.
Reads the file in chunks to handle large files efficiently.
Args:
rpath: S3 source path (s3://bucket/key).
lpath: Local destination file path.
callback: Progress callback for tracking download progress.
outfile: Unused parameter for fsspec compatibility.
**kwargs: Additional S3 parameters passed to open().
Note:
If lpath is a directory, the method returns without performing
any operation.
"""
if os.path.isdir(lpath):
return
with open(lpath, "wb") as local, self.open(rpath, "rb", **kwargs) as remote:
callback.set_size(remote.size)
while data := remote.read(remote.blocksize):
local.write(data)
callback.relative_update(len(data))
def checksum(self, path: str, **kwargs):
"""Get checksum for S3 object or directory.
Computes a checksum for the specified S3 path. For individual objects,
returns the ETag converted to an integer. For directories, returns a
checksum based on the directory's tokenized representation.
Args:
path: S3 path (s3://bucket/key) to get checksum for.
**kwargs: Additional arguments including:
refresh: If True, refresh cached info before computing checksum.
Returns:
Integer checksum value derived from S3 ETag or directory token.
Note:
For objects created via multipart upload, the ETag has the form
"<hex>-<part count>"; only the hex portion before the dash is used.
"""
refresh = kwargs.pop("refresh", False)
info = self.info(path, refresh=refresh)
if info.get("type") != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY:
return int(info.get("etag").strip('"').split("-")[0], 16)
return int(tokenize(info), 16)
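# Checksum sketch (assuming a non-multipart object whose ETag is a plain
# MD5 hex digest such as "9b2cf535f27731c974343645a3985328"): the returned
# value is int("9b2cf535f27731c974343645a3985328", 16). Directories fall
# back to int(tokenize(info), 16).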
def sign(self, path: str, expiration: int = 3600, **kwargs):
"""Generate a presigned URL for S3 object access.
Creates a presigned URL that allows temporary access to an S3 object
without requiring AWS credentials. Useful for sharing files or providing
time-limited access to resources.
Args:
path: S3 path (s3://bucket/key) to generate URL for.
expiration: URL expiration time in seconds. Defaults to 3600 (1 hour).
**kwargs: Additional parameters including:
client_method: S3 operation ('get_object', 'put_object', etc.).
Defaults to 'get_object'.
Additional parameters passed to the S3 operation.
Returns:
Presigned URL string that provides temporary access to the S3 object.
Example:
>>> fs = S3FileSystem()
>>> url = fs.sign("s3://my-bucket/file.txt", expiration=7200)
>>> # URL valid for 2 hours
>>>
>>> # Generate upload URL
>>> upload_url = fs.sign(
... "s3://my-bucket/upload.txt",
... client_method="put_object"
... )
"""
bucket, key, version_id = self.parse_path(path)
client_method = kwargs.pop("client_method", "get_object")
params = {"Bucket": bucket, "Key": key}
if version_id:
params.update({"VersionId": version_id})
if kwargs:
params.update(kwargs)
request = {
"ClientMethod": client_method,
"Params": params,
"ExpiresIn": expiration,
}
_logger.debug(f"Generate signed url: s3://{bucket}/{key}?versionId={version_id}")
return self._call(
self._client.generate_presigned_url,
**request,
)
def created(self, path: str) -> datetime:
return self.modified(path)
def modified(self, path: str) -> datetime:
info = self.info(path)
return cast(datetime, info.get("last_modified"))
def invalidate_cache(self, path: Optional[str] = None) -> None:
if path is None:
self.dircache.clear()
else:
path = self._strip_protocol(path)
while path:
self.dircache.pop(path, None)
path = self._parent(path)
def _open(
self,
path: str,
mode: str = "rb",
block_size: Optional[int] = None,
cache_type: Optional[str] = None,
autocommit: bool = True,
cache_options: Optional[Dict[Any, Any]] = None,
**kwargs,
) -> S3File:
if block_size is None:
block_size = self.default_block_size
if cache_type is None:
cache_type = self.default_cache_type
max_workers = kwargs.pop("max_worker", self.max_workers)
s3_additional_kwargs = kwargs.pop("s3_additional_kwargs", {})
s3_additional_kwargs.update(self.s3_additional_kwargs)
return S3File(
self,
path,
mode,
version_id=None,
max_workers=max_workers,
block_size=block_size,
cache_type=cache_type,
autocommit=autocommit,
cache_options=cache_options,
s3_additional_kwargs=s3_additional_kwargs,
**kwargs,
)
def _get_object(
self,
bucket: str,
key: str,
ranges: Optional[Tuple[int, int]] = None,
version_id: Optional[str] = None,
**kwargs,
) -> Tuple[int, bytes]:
request = {"Bucket": bucket, "Key": key}
if ranges:
range_ = S3File._format_ranges(ranges)
request.update({"Range": range_})
else:
ranges = (0, 0)
range_ = "bytes=0-"
if version_id:
request.update({"VersionId": version_id})
_logger.debug(f"Get object: s3://{bucket}/{key}?versionId={version_id}&range={range_}")
response = self._call(
self._client.get_object,
**request,
**kwargs,
)
return ranges[0], cast(bytes, response["Body"].read())
def _put_object(self, bucket: str, key: str, body: Optional[bytes], **kwargs) -> S3PutObject:
request: Dict[str, Any] = {"Bucket": bucket, "Key": key}
if body:
request.update({"Body": body})
_logger.debug(f"Put object: s3://{bucket}/{key}")
response = self._call(
self._client.put_object,
**request,
**kwargs,
)
return S3PutObject(response)
def _create_multipart_upload(self, bucket: str, key: str, **kwargs) -> S3MultipartUpload:
request = {
"Bucket": bucket,
"Key": key,
}
_logger.debug(f"Create multipart upload to s3://{bucket}/{key}.")
response = self._call(
self._client.create_multipart_upload,
**request,
**kwargs,
)
return S3MultipartUpload(response)
def _upload_part_copy(
self,
bucket: str,
key: str,
copy_source: Union[str, Dict[str, Any]],
upload_id: str,
part_number: int,
copy_source_ranges: Optional[Tuple[int, int]] = None,
**kwargs,
) -> S3MultipartUploadPart:
request = {
"Bucket": bucket,
"Key": key,
"CopySource": copy_source,
"UploadId": upload_id,
"PartNumber": part_number,
}
if copy_source_ranges:
range_ = S3File._format_ranges(copy_source_ranges)
request.update({"CopySourceRange": range_})
_logger.debug(
f"Upload part copy from {copy_source} to s3://{bucket}/{key} as part {part_number}."
)
response = self._call(
self._client.upload_part_copy,
**request,
**kwargs,
)
return S3MultipartUploadPart(part_number, response)
def _upload_part(
self,
bucket: str,
key: str,
upload_id: str,
part_number: int,
body: bytes,
**kwargs,
) -> S3MultipartUploadPart:
request = {
"Bucket": bucket,
"Key": key,
"UploadId": upload_id,
"PartNumber": part_number,
"Body": body,
}
_logger.debug(f"Upload part of {upload_id} to s3://{bucket}/{key} as part {part_number}.")
response = self._call(
self._client.upload_part,
**request,
**kwargs,
)
return S3MultipartUploadPart(part_number, response)
def _complete_multipart_upload(
self, bucket: str, key: str, upload_id: str, parts: List[Dict[str, Any]], **kwargs
) -> S3CompleteMultipartUpload:
request = {
"Bucket": bucket,
"Key": key,
"UploadId": upload_id,
"MultipartUpload": {"Parts": parts},
}
_logger.debug(f"Complete multipart upload {upload_id} to s3://{bucket}/{key}.")
response = self._call(
self._client.complete_multipart_upload,
**request,
**kwargs,
)
return S3CompleteMultipartUpload(response)
def _call(self, method: Union[str, Callable[..., Any]], **kwargs) -> Dict[str, Any]:
func = getattr(self._client, method) if isinstance(method, str) else method
response = retry_api_call(
func, config=self._retry_config, logger=_logger, **kwargs, **self.request_kwargs
)
return cast(Dict[str, Any], response)
class S3File(AbstractBufferedFile):
def __init__(
self,
fs: S3FileSystem,
path: str,
mode: str = "rb",
version_id: Optional[str] = None,
max_workers: int = (cpu_count() or 1) * 5,
block_size: int = S3FileSystem.DEFAULT_BLOCK_SIZE,
cache_type: str = "bytes",
autocommit: bool = True,
cache_options: Optional[Dict[Any, Any]] = None,
size: Optional[int] = None,
s3_additional_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
self.max_workers = max_workers
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self.s3_additional_kwargs = s3_additional_kwargs if s3_additional_kwargs else {}
super().__init__(
fs=fs,
path=path,
mode=mode,
block_size=block_size,
autocommit=autocommit,
cache_type=cache_type,
cache_options=cache_options,
size=size,
)
bucket, key, path_version_id = S3FileSystem.parse_path(path)
self.bucket = bucket
if not key:
raise ValueError("The path does not contain a key.")
self.key = key
if version_id and path_version_id:
if version_id != path_version_id:
raise ValueError(
f"The version_id: {version_id} specified in the argument and "
f"the version_id: {path_version_id} specified in the path do not match."
)
self.version_id: Optional[str] = version_id
elif path_version_id:
self.version_id = path_version_id
else:
self.version_id = version_id
if "r" not in mode and block_size < self.fs.MULTIPART_UPLOAD_MIN_PART_SIZE:
# When writing, the block size must be at least the minimum
# part size for a multipart upload.
raise ValueError(f"Block size must be >= {self.fs.MULTIPART_UPLOAD_MIN_PART_SIZE} bytes.")
self.append_block = False
if "r" in mode:
info = self.fs.info(self.path, version_id=self.version_id)
if etag := info.get("etag"):
self.s3_additional_kwargs.update({"IfMatch": etag})
self._details = info
elif "a" in mode and self.fs.exists(path):
self.append_block = True
info = self.fs.info(self.path, version_id=self.version_id)
loc = info.get("size", 0)
if loc < self.fs.MULTIPART_UPLOAD_MIN_PART_SIZE:
self.write(self.fs.cat(self.path))
self.loc = loc
self.s3_additional_kwargs.update(info.to_api_repr())
self._details = info
else:
self._details = {}
self.multipart_upload: Optional[S3MultipartUpload] = None
self.multipart_upload_parts: List[Future[S3MultipartUploadPart]] = []
def close(self) -> None:
super().close()
self._executor.shutdown()
def _initiate_upload(self) -> None:
if self.tell() < self.blocksize:
# Data smaller than the block size is not multipart uploaded;
# it is written with a single PutObject in commit().
return
self.multipart_upload = self.fs._create_multipart_upload(
bucket=self.bucket,
key=self.key,
**self.s3_additional_kwargs,
)
if self.append_block:
if self.tell() > S3FileSystem.MULTIPART_UPLOAD_MAX_PART_SIZE:
info = self.fs.info(self.path, version_id=self.version_id)
ranges = self._get_ranges(
0,
# Use the byte size of the copy-source object.
info.get("size", 0),
self.max_workers,
S3FileSystem.MULTIPART_UPLOAD_MAX_PART_SIZE,
)
for i, range_ in enumerate(ranges):
self.multipart_upload_parts.append(
self._executor.submit(
self.fs._upload_part_copy,
bucket=self.bucket,
key=self.key,
copy_source=self.path,
upload_id=cast(
str, cast(S3MultipartUpload, self.multipart_upload).upload_id
),
part_number=i + 1,
copy_source_ranges=range_,
)
)
else:
self.multipart_upload_parts.append(
self._executor.submit(
self.fs._upload_part_copy,
bucket=self.bucket,
key=self.key,
copy_source=self.path,
upload_id=cast(
str, cast(S3MultipartUpload, self.multipart_upload).upload_id
),
part_number=1,
)
)
def _upload_chunk(self, final: bool = False) -> bool:
if self.tell() < self.blocksize:
# Data smaller than the block size is not multipart uploaded.
if self.autocommit and final:
self.commit()
return True
if not self.multipart_upload:
raise RuntimeError("Multipart upload is not initialized.")
part_number = len(self.multipart_upload_parts)
self.buffer.seek(0)
while data := self.buffer.read(self.blocksize):
# If the remaining data would leave a trailing part smaller than the
# minimum part size, merge it into the current part (splitting the
# merged data if it would exceed the maximum part size).
next_data = self.buffer.read(self.blocksize)
next_data_size = len(next_data)
if 0 < next_data_size < self.fs.MULTIPART_UPLOAD_MIN_PART_SIZE:
upload_data = data + next_data
upload_data_size = len(upload_data)
if upload_data_size < self.fs.MULTIPART_UPLOAD_MAX_PART_SIZE:
uploads = [upload_data]
else:
split_size = upload_data_size // 2
uploads = [upload_data[:split_size], upload_data[split_size:]]
else:
uploads = [data]
if next_data:
uploads.append(next_data)
for upload in uploads:
part_number += 1
self.multipart_upload_parts.append(
self._executor.submit(
self.fs._upload_part,
bucket=self.bucket,
key=self.key,
upload_id=cast(str, self.multipart_upload.upload_id),
part_number=part_number,
body=upload,
)
)
if not next_data:
break
if self.autocommit and final:
self.commit()
return True
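# Part-merging sketch: with a 5 MiB block size, a final 8 MiB buffer is read
# as a 5 MiB chunk followed by a 3 MiB remainder; because the remainder is
# below MULTIPART_UPLOAD_MIN_PART_SIZE, the two are merged and uploaded as a
# single 8 MiB part instead of leaving an undersized trailing part.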
def commit(self) -> None:
if self.tell() == 0:
if self.buffer is not None:
self.discard()
self.fs.touch(self.path, **self.s3_additional_kwargs)
elif not self.multipart_upload_parts:
if self.buffer is not None:
# Upload files smaller than block size.
self.buffer.seek(0)
data = self.buffer.read()
self.fs._put_object(
bucket=self.bucket,
key=self.key,
body=data,
**self.s3_additional_kwargs,
)
else:
if not self.multipart_upload:
raise RuntimeError("Multipart upload is not initialized.")
parts: List[Dict[str, Any]] = []
for f in as_completed(self.multipart_upload_parts):
result = f.result()
parts.append(
{
"ETag": result.etag,
"PartNumber": result.part_number,
}
)
parts.sort(key=lambda x: x["PartNumber"])
self.fs._complete_multipart_upload(
bucket=self.bucket,
key=self.key,
upload_id=cast(str, self.multipart_upload.upload_id),
parts=parts,
)
self.fs.invalidate_cache(self.path)
def discard(self) -> None:
if self.multipart_upload:
for f in self.multipart_upload_parts:
f.cancel()
self.fs._call(
"abort_multipart_upload",
Bucket=self.bucket,
Key=self.key,
UploadId=self.multipart_upload.upload_id,
**self.s3_additional_kwargs,
)
self.multipart_upload = None
self.multipart_upload_parts = []
def _fetch_range(self, start: int, end: int) -> bytes:
ranges = self._get_ranges(
start, end, max_workers=self.max_workers, worker_block_size=self.blocksize
)
if len(ranges) > 1:
object_ = self._merge_objects(
list(
self._executor.map(
lambda bucket, key, ranges, version_id, kwargs: self.fs._get_object(
bucket=bucket,
key=key,
ranges=ranges,
version_id=version_id,
**kwargs,
),
itertools.repeat(self.bucket),
itertools.repeat(self.key),
ranges,
itertools.repeat(self.version_id),
itertools.repeat(self.s3_additional_kwargs),
)
)
)
else:
object_ = self.fs._get_object(
self.bucket,
self.key,
ranges[0],
self.version_id,
**self.s3_additional_kwargs,
)[1]
return object_
@staticmethod
def _format_ranges(ranges: Tuple[int, int]):
return f"bytes={ranges[0]}-{ranges[1] - 1}"
@staticmethod
def _get_ranges(
start: int, end: int, max_workers: int, worker_block_size: int
) -> List[Tuple[int, int]]:
ranges = []
range_size = end - start
if max_workers > 1 and range_size > worker_block_size:
range_start = start
while True:
range_end = range_start + worker_block_size
if range_end > end:
ranges.append((range_start, end))
break
ranges.append((range_start, range_end))
range_start += worker_block_size
else:
ranges.append((start, end))
return ranges
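# Splitting sketch: _get_ranges(0, 12, max_workers=4, worker_block_size=5)
# returns [(0, 5), (5, 10), (10, 12)]; with max_workers=1 the whole range is
# returned as a single tuple, [(0, 12)].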
@staticmethod
def _merge_objects(objects: List[Tuple[int, bytes]]) -> bytes:
objects.sort(key=lambda x: x[0])
return b"".join([obj for start, obj in objects])