Return dotdict instead of list from copy_file, copy_dir.

This commit is contained in:
voussoir 2021-01-18 12:12:27 -08:00
parent 5c0482032d
commit 32d8137201
No known key was found for this signature in the database (GPG key ID: 5F7554F8C26DACCB).

View file

@@ -6,6 +6,7 @@ import sys
 # pip install voussoirkit
 from voussoirkit import bytestring
+from voussoirkit import dotdict
 from voussoirkit import pathclass
 from voussoirkit import ratelimiter
 from voussoirkit import sentinel
@@ -194,7 +195,7 @@ def copy_dir(
     validate_hash:
         Passed directly into each `copy_file`.

-    Returns: [destination path, number of bytes written to destination]
+    Returns a dotdict containing at least `destination` and `written_bytes`.
     (Written bytes is 0 if all files already existed.)
     '''
     # Prepare parameters
@@ -233,7 +234,6 @@ def copy_dir(
     files_per_second = limiter_or_none(files_per_second)

     # Copy
-    written_bytes = 0
     walker = walk_generator(
         source,
         callback_exclusion=callback_exclusion,
@ -275,6 +275,7 @@ def copy_dir(
yield (source_file, destination_file) yield (source_file, destination_file)
walker = denester(walker) walker = denester(walker)
written_bytes = 0
for (source_file, destination_file) in walker: for (source_file, destination_file) in walker:
if stop_event and stop_event.is_set(): if stop_event and stop_event.is_set():
@@ -302,18 +303,21 @@ def copy_dir(
             validate_hash=validate_hash,
         )
-        copiedname = copied[0]
-        written_bytes += copied[1]
+        written_bytes += copied.written_bytes

         if precalcsize is False:
-            callback_directory_progress(copiedname, written_bytes, written_bytes)
+            callback_directory_progress(copied.destination, written_bytes, written_bytes)
         else:
-            callback_directory_progress(copiedname, written_bytes, total_bytes)
+            callback_directory_progress(copied.destination, written_bytes, total_bytes)

         if files_per_second is not None:
             files_per_second.limit(1)

-    return [destination, written_bytes]
+    results = dotdict.DotDict({
+        'destination': destination,
+        'written_bytes': written_bytes,
+    })
+    return results

 def copy_file(
     source,
@@ -327,6 +331,7 @@ def copy_file(
     callback_validate_hash=None,
     chunk_size=CHUNK_SIZE,
     dry_run=False,
+    hash_class=None,
     overwrite_old=True,
     validate_hash=False,
     ):
@@ -376,16 +381,24 @@ def copy_file(
     dry_run:
         Do everything except the actual file copying.

+    hash_class:
+        If provided, should be a hashlib class. The hash will be computed while
+        the file is being copied, and returned in the dotdict as `hash`.
+        Note that if the function returns early due to dry_run or file not
+        needing overwrite, this won't be set, so be prepared to handle None.
+        If None, the hash will not be calculated.
+
     overwrite_old:
         If True, overwrite the destination file if the source file
         has a more recent "last modified" timestamp.
         If False, existing files will be skipped no matter what.

     validate_hash:
-        If True, verify the file hash of the resulting file, using the
-        `HASH_CLASS` global.
+        If True, the copied file will be read back after the copy is complete,
+        and its hash will be compared against the hash of the source file.
+        If hash_class is None, then the global HASH_CLASS is used.

-    Returns: [destination filename, number of bytes written to destination]
+    Returns a dotdict containing at least `destination` and `written_bytes`.
     (Written bytes is 0 if the file already existed.)
     '''
     # Prepare parameters
@@ -412,26 +425,31 @@ def copy_file(
     bytes_per_second = limiter_or_none(bytes_per_second)

+    results = dotdict.DotDict({
+        'destination': destination,
+        'written_bytes': 0,
+    }, default=None)
+
     # Determine overwrite
     if destination.exists:
-        if overwrite_old is False:
-            return [destination, 0]
+        if not overwrite_old:
+            return results

         source_modtime = source.stat.st_mtime
         destination_modtime = destination.stat.st_mtime
         if source_modtime == destination_modtime:
-            return [destination, 0]
+            return results

     # Copy
     if dry_run:
         if callback_progress is not None:
             callback_progress(destination, 0, 0)
-        return [destination, 0]
+        return results

     source_bytes = source.size

     if callback_pre_copy(source, destination, dry_run=dry_run) is BAIL:
-        return [destination, 0]
+        return results

     destination.parent.makedirs(exist_ok=True)
@ -452,43 +470,45 @@ def copy_file(
if source_handle is None and destination_handle: if source_handle is None and destination_handle:
destination_handle.close() destination_handle.close()
return [destination, 0] return results
if destination_handle is None: if destination_handle is None:
source_handle.close() source_handle.close()
return [destination, 0] return results
if validate_hash: if hash_class is not None:
hasher = HASH_CLASS() results.hash = hash_class()
elif validate_hash:
hash_class = HASH_CLASS
results.hash = HASH_CLASS()
written_bytes = 0
while True: while True:
try: try:
data_chunk = source_handle.read(chunk_size) data_chunk = source_handle.read(chunk_size)
except PermissionError as exception: except PermissionError as exception:
if callback_permission_denied is not None: if callback_permission_denied is not None:
callback_permission_denied(source, exception) callback_permission_denied(source, exception)
return [destination, 0] return results
else: else:
raise raise
data_bytes = len(data_chunk) data_bytes = len(data_chunk)
if data_bytes == 0: if data_bytes == 0:
break break
if validate_hash: if results.hash:
hasher.update(data_chunk) results.hash.update(data_chunk)
destination_handle.write(data_chunk) destination_handle.write(data_chunk)
written_bytes += data_bytes results.written_bytes += data_bytes
if bytes_per_second is not None: if bytes_per_second is not None:
bytes_per_second.limit(data_bytes) bytes_per_second.limit(data_bytes)
callback_progress(destination, written_bytes, source_bytes) callback_progress(destination, results.written_bytes, source_bytes)
if written_bytes == 0: if results.written_bytes == 0:
# For zero-length files, we want to get at least one call in there. # For zero-length files, we want to get at least one call in there.
callback_progress(destination, written_bytes, source_bytes) callback_progress(destination, results.written_bytes, source_bytes)
# Fin # Fin
log.debug('Closing source handle.') log.debug('Closing source handle.')
@@ -502,11 +522,12 @@ def copy_file(
         verify_hash(
             destination,
             callback_progress=callback_validate_hash,
+            hash_class=hash_class,
+            known_hash=results.hash.hexdigest(),
             known_size=source_bytes,
-            known_hash=hasher.hexdigest(),
         )

-    return [destination, written_bytes]
+    return results

 def do_nothing(*args, **kwargs):
     '''
@@ -534,7 +555,7 @@ def hash_file(
     path,
     hash_class=HASH_CLASS,
     *,
-    callback_progress=do_nothing,
+    callback_progress=None,
     chunk_size=CHUNK_SIZE,
     ):
     '''
@ -547,6 +568,8 @@ def hash_file(
checked_bytes = 0 checked_bytes = 0
file_size = os.path.getsize(path.absolute_path) file_size = os.path.getsize(path.absolute_path)
callback_progress = callback_progress or do_nothing
handle = path.open('rb') handle = path.open('rb')
with handle: with handle:
while True: while True:
@@ -611,7 +634,7 @@ def verify_hash(
     **hash_kwargs,
     ):
     path = pathclass.Path(path)
-    log.debug('Validating hash for "%s" against %s', path.absolute_path, known_hash)
+    log.debug('Validating hash for "%s" against %s.', path.absolute_path, known_hash)
     if known_size is not None:
         file_size = os.path.getsize(path.absolute_path)