Return dotdict instead of list from copy_file, copy_dir.

This commit is contained in:
voussoir 2021-01-18 12:12:27 -08:00
parent 5c0482032d
commit 32d8137201
No known key found for this signature in database
GPG key ID: 5F7554F8C26DACCB

View file

@@ -6,6 +6,7 @@ import sys
# pip install voussoirkit
from voussoirkit import bytestring
from voussoirkit import dotdict
from voussoirkit import pathclass
from voussoirkit import ratelimiter
from voussoirkit import sentinel
@@ -194,7 +195,7 @@ def copy_dir(
validate_hash:
Passed directly into each `copy_file`.
Returns: [destination path, number of bytes written to destination]
Returns a dotdict containing at least `destination` and `written_bytes`.
(Written bytes is 0 if all files already existed.)
'''
# Prepare parameters
@@ -233,7 +234,6 @@ def copy_dir(
files_per_second = limiter_or_none(files_per_second)
# Copy
written_bytes = 0
walker = walk_generator(
source,
callback_exclusion=callback_exclusion,
@@ -275,6 +275,7 @@ def copy_dir(
yield (source_file, destination_file)
walker = denester(walker)
written_bytes = 0
for (source_file, destination_file) in walker:
if stop_event and stop_event.is_set():
@@ -302,18 +303,21 @@ def copy_dir(
validate_hash=validate_hash,
)
copiedname = copied[0]
written_bytes += copied[1]
written_bytes += copied.written_bytes
if precalcsize is False:
callback_directory_progress(copiedname, written_bytes, written_bytes)
callback_directory_progress(copied.destination, written_bytes, written_bytes)
else:
callback_directory_progress(copiedname, written_bytes, total_bytes)
callback_directory_progress(copied.destination, written_bytes, total_bytes)
if files_per_second is not None:
files_per_second.limit(1)
return [destination, written_bytes]
results = dotdict.DotDict({
'destination': destination,
'written_bytes': written_bytes,
})
return results
def copy_file(
source,
@@ -327,6 +331,7 @@ def copy_file(
callback_validate_hash=None,
chunk_size=CHUNK_SIZE,
dry_run=False,
hash_class=None,
overwrite_old=True,
validate_hash=False,
):
@@ -376,16 +381,24 @@ def copy_file(
dry_run:
Do everything except the actual file copying.
hash_class:
If provided, should be a hashlib class. The hash will be computed while
the file is being copied, and returned in the dotdict as `hash`.
Note that if the function returns early due to dry_run or file not
needing overwrite, this won't be set, so be prepared to handle None.
If None, the hash will not be calculated.
overwrite_old:
If True, overwrite the destination file if the source file
has a more recent "last modified" timestamp.
If False, existing files will be skipped no matter what.
validate_hash:
If True, verify the file hash of the resulting file, using the
`HASH_CLASS` global.
If True, the copied file will be read back after the copy is complete,
and its hash will be compared against the hash of the source file.
If hash_class is None, then the global HASH_CLASS is used.
Returns: [destination filename, number of bytes written to destination]
Returns a dotdict containing at least `destination` and `written_bytes`.
(Written bytes is 0 if the file already existed.)
'''
# Prepare parameters
@@ -412,26 +425,31 @@ def copy_file(
bytes_per_second = limiter_or_none(bytes_per_second)
results = dotdict.DotDict({
'destination': destination,
'written_bytes': 0,
}, default=None)
# Determine overwrite
if destination.exists:
if overwrite_old is False:
return [destination, 0]
if not overwrite_old:
return results
source_modtime = source.stat.st_mtime
destination_modtime = destination.stat.st_mtime
if source_modtime == destination_modtime:
return [destination, 0]
return results
# Copy
if dry_run:
if callback_progress is not None:
callback_progress(destination, 0, 0)
return [destination, 0]
return results
source_bytes = source.size
if callback_pre_copy(source, destination, dry_run=dry_run) is BAIL:
return [destination, 0]
return results
destination.parent.makedirs(exist_ok=True)
@@ -452,43 +470,45 @@ def copy_file(
if source_handle is None and destination_handle:
destination_handle.close()
return [destination, 0]
return results
if destination_handle is None:
source_handle.close()
return [destination, 0]
return results
if validate_hash:
hasher = HASH_CLASS()
if hash_class is not None:
results.hash = hash_class()
elif validate_hash:
hash_class = HASH_CLASS
results.hash = HASH_CLASS()
written_bytes = 0
while True:
try:
data_chunk = source_handle.read(chunk_size)
except PermissionError as exception:
if callback_permission_denied is not None:
callback_permission_denied(source, exception)
return [destination, 0]
return results
else:
raise
data_bytes = len(data_chunk)
if data_bytes == 0:
break
if validate_hash:
hasher.update(data_chunk)
if results.hash:
results.hash.update(data_chunk)
destination_handle.write(data_chunk)
written_bytes += data_bytes
results.written_bytes += data_bytes
if bytes_per_second is not None:
bytes_per_second.limit(data_bytes)
callback_progress(destination, written_bytes, source_bytes)
callback_progress(destination, results.written_bytes, source_bytes)
if written_bytes == 0:
if results.written_bytes == 0:
# For zero-length files, we want to get at least one call in there.
callback_progress(destination, written_bytes, source_bytes)
callback_progress(destination, results.written_bytes, source_bytes)
# Fin
log.debug('Closing source handle.')
@@ -502,11 +522,12 @@ def copy_file(
verify_hash(
destination,
callback_progress=callback_validate_hash,
hash_class=hash_class,
known_hash=results.hash.hexdigest(),
known_size=source_bytes,
known_hash=hasher.hexdigest(),
)
return [destination, written_bytes]
return results
def do_nothing(*args, **kwargs):
'''
@@ -534,7 +555,7 @@ def hash_file(
path,
hash_class=HASH_CLASS,
*,
callback_progress=do_nothing,
callback_progress=None,
chunk_size=CHUNK_SIZE,
):
'''
@@ -547,6 +568,8 @@ def hash_file(
checked_bytes = 0
file_size = os.path.getsize(path.absolute_path)
callback_progress = callback_progress or do_nothing
handle = path.open('rb')
with handle:
while True:
@@ -611,7 +634,7 @@ def verify_hash(
**hash_kwargs,
):
path = pathclass.Path(path)
log.debug('Validating hash for "%s" against %s', path.absolute_path, known_hash)
log.debug('Validating hash for "%s" against %s.', path.absolute_path, known_hash)
if known_size is not None:
file_size = os.path.getsize(path.absolute_path)