Skip to content
Snippets Groups Projects
Commit 6fb52060 authored by Morian Sonnet's avatar Morian Sonnet
Browse files

Add custom progress callbacks for RDS3 resource

parent 7ca18e55
No related branches found
No related tags found
1 merge request!23Add support for custom progress callbacks for RDSS3 storage
......@@ -78,20 +78,6 @@ class ResourceQuota:
def serialize(self) -> dict:
return self._data
def progress_callback(
progress_bar: tqdm,
bytes_read: int,
fn: Callable[[int], None] | None = None
) -> None:
"""
Updates the progress bar and calls a callback if one has been specified.
"""
progress_bar.update(bytes_read - progress_bar.n)
if fn:
fn(bytes_read)
class ResourceTypeOptions:
"""
Options and settings regarding the resource type.
......@@ -621,7 +607,7 @@ class Resource:
self.post_metadata(metadata)
assert isinstance(handle, IOBase)
if self.type.general_type == "rdss3" and self.client.native:
self._upload_blob_s3(path, handle)
self._upload_blob_s3(path, handle, progress)
else:
self._upload_blob(path, handle, progress)
......@@ -700,10 +686,16 @@ class Resource:
unit="B", unit_scale=True, ascii=True,
disable=not self.client.verbose
)
def progress_callback(mon):
nonlocal progress_bar, progress
progress_bar.update(mon.bytes_read - progress_bar.n)
if progress:
progress(mon.bytes_read)
monitor = MultipartEncoderMonitor(
encoder,
lambda mon:
progress_callback(progress_bar, mon.bytes_read, progress)
progress_callback
)
headers = {"Content-Type": monitor.content_type}
if use_put:
......@@ -711,7 +703,7 @@ class Resource:
else:
self.client.post(uri, data=monitor, headers=headers)
def _upload_blob_s3(self, path: str, handle: BinaryIO) -> None:
def _upload_blob_s3(self, path: str, handle: BinaryIO, progress: Callable[[int], None] | None = None) -> None:
"""
Works only on rdss3 resources and should not be called
on other resource types! Bypasses Coscine and uploads
......@@ -727,11 +719,20 @@ class Resource:
aws_secret_access_key=self.type.options.secret_key_write,
endpoint_url=self.type.options.endpoint
)
bytes_read_abs = 0
def progress_callback(bytes_read_inc):
nonlocal progress_bar, progress, bytes_read_abs
progress_bar.update(bytes_read_inc)
bytes_read_abs += bytes_read_inc
if progress:
progress(bytes_read_abs)
s3.upload_fileobj(
handle,
self.type.options.bucket_name,
path,
Callback=progress_bar.update
Callback=progress_callback
)
def _fetch_files_recursively(self, path: str = "") -> list[FileObject]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment