Add some type annotations, document some exceptions.

master
voussoir 2021-08-31 19:25:42 -07:00
parent 2e0b4cfa14
commit e883409daf
No known key found for this signature in database
GPG Key ID: 5F7554F8C26DACCB
3 changed files with 355 additions and 208 deletions

View File

@ -7,6 +7,7 @@ import hashlib
import mimetypes import mimetypes
import os import os
import PIL.Image import PIL.Image
import typing
import zipstream import zipstream
from voussoirkit import bytestring from voussoirkit import bytestring
@ -123,7 +124,7 @@ def album_photos_as_filename_map(
return arcnames return arcnames
def checkerboard_image(color_1, color_2, image_size, checker_size): def checkerboard_image(color_1, color_2, image_size, checker_size) -> PIL.Image:
''' '''
Generate a PIL Image with a checkerboard pattern. Generate a PIL Image with a checkerboard pattern.
@ -200,10 +201,10 @@ def decollide_names(things, namer):
final[thing] = myname final[thing] = myname
return final return final
def dict_to_tuple(d): def dict_to_tuple(d) -> tuple:
return tuple(sorted(d.items())) return tuple(sorted(d.items()))
def generate_image_thumbnail(filepath, width, height): def generate_image_thumbnail(filepath, width, height) -> PIL.Image:
if not os.path.isfile(filepath): if not os.path.isfile(filepath):
raise FileNotFoundError(filepath) raise FileNotFoundError(filepath)
image = PIL.Image.open(filepath) image = PIL.Image.open(filepath)
@ -234,7 +235,7 @@ def generate_image_thumbnail(filepath, width, height):
image = image.convert('RGB') image = image.convert('RGB')
return image return image
def generate_video_thumbnail(filepath, outfile, width, height, **special): def generate_video_thumbnail(filepath, outfile, width, height, **special) -> PIL.Image:
if not os.path.isfile(filepath): if not os.path.isfile(filepath):
raise FileNotFoundError(filepath) raise FileNotFoundError(filepath)
probe = constants.ffmpeg.probe(filepath) probe = constants.ffmpeg.probe(filepath)
@ -267,7 +268,7 @@ def generate_video_thumbnail(filepath, outfile, width, height, **special):
) )
return True return True
def get_mimetype(filepath): def get_mimetype(filepath) -> typing.Optional[str]:
''' '''
Extension to mimetypes.guess_type which uses my Extension to mimetypes.guess_type which uses my
constants.ADDITIONAL_MIMETYPES. constants.ADDITIONAL_MIMETYPES.
@ -278,7 +279,7 @@ def get_mimetype(filepath):
mimetype = mimetypes.guess_type(filepath)[0] mimetype = mimetypes.guess_type(filepath)[0]
return mimetype return mimetype
def hash_photoset(photos): def hash_photoset(photos) -> str:
''' '''
Given some photos, return a fingerprint string for that particular set. Given some photos, return a fingerprint string for that particular set.
''' '''
@ -290,7 +291,7 @@ def hash_photoset(photos):
return hasher.hexdigest() return hasher.hexdigest()
def hyphen_range(s): def hyphen_range(s) -> tuple:
''' '''
Given a string like '1-3', return numbers (1, 3) representing lower Given a string like '1-3', return numbers (1, 3) representing lower
and upper bounds. and upper bounds.
@ -319,9 +320,9 @@ def hyphen_range(s):
if low is not None and high is not None and low > high: if low is not None and high is not None and low > high:
raise exceptions.OutOfOrder(range=s, min=low, max=high) raise exceptions.OutOfOrder(range=s, min=low, max=high)
return low, high return (low, high)
def is_xor(*args): def is_xor(*args) -> bool:
''' '''
Return True if and only if one arg is truthy. Return True if and only if one arg is truthy.
''' '''
@ -336,7 +337,7 @@ def now(timestamp=True):
return n.timestamp() return n.timestamp()
return n return n
def parse_unit_string(s): def parse_unit_string(s) -> typing.Union[int, float, None]:
''' '''
Try to parse the string as an int, float, or bytestring, or hms. Try to parse the string as an int, float, or bytestring, or hms.
''' '''
@ -357,7 +358,12 @@ def parse_unit_string(s):
else: else:
return bytestring.parsebytes(s) return bytestring.parsebytes(s)
def read_filebytes(filepath, range_min=0, range_max=None, chunk_size=bytestring.MIBIBYTE): def read_filebytes(
filepath,
range_min=0,
range_max=None,
chunk_size=bytestring.MIBIBYTE,
) -> typing.Iterable[bytes]:
''' '''
Yield chunks of bytes from the file between the endpoints. Yield chunks of bytes from the file between the endpoints.
''' '''
@ -373,19 +379,15 @@ def read_filebytes(filepath, range_min=0, range_max=None, chunk_size=bytestring.
with f: with f:
f.seek(range_min) f.seek(range_min)
while sent_amount < range_span: while sent_amount < range_span:
chunk = f.read(chunk_size)
if len(chunk) == 0:
break
needed = range_span - sent_amount needed = range_span - sent_amount
if len(chunk) >= needed: chunk = f.read(min(needed, chunk_size))
yield chunk[:needed] if len(chunk) == 0:
break break
yield chunk yield chunk
sent_amount += len(chunk) sent_amount += len(chunk)
def remove_path_badchars(filepath, allowed=''): def remove_path_badchars(filepath, allowed='') -> str:
''' '''
Remove the bad characters seen in constants.FILENAME_BADCHARS, except Remove the bad characters seen in constants.FILENAME_BADCHARS, except
those which you explicitly permit. those which you explicitly permit.
@ -409,15 +411,18 @@ def slice_before(li, item):
index = li.index(item) index = li.index(item)
return li[:index] return li[:index]
def split_easybake_string(ebstring): def split_easybake_string(ebstring) -> tuple[str, str, str]:
''' '''
Given an easybake string, return (tagname, synonym, rename_to), where Given an easybake string, return (tagname, synonym, rename_to), where
tagname may be a full qualified name, and at least one of tagname may be a full qualified name, and at least one of
synonym or rename_to will be None since both are not posible at once. synonym or rename_to will be None since both are not posible at once.
'languages.python' -> ('languages.python', None, None) >>> split_easybake_string('languages.python')
'languages.python+py' -> ('languages.python', 'py', None) ('languages.python', None, None)
'languages.python=bestlang' -> ('languages.python', None, 'bestlang') >>> split_easybake_string('languages.python+py')
('languages.python', 'py', None)
>>> split_easybake_string('languages.python=bestlang')
('languages.python', None, 'bestlang')
''' '''
ebstring = ebstring.strip() ebstring = ebstring.strip()
ebstring = ebstring.strip('.+=') ebstring = ebstring.strip('.+=')
@ -454,7 +459,7 @@ def split_easybake_string(ebstring):
tagname = tagname.strip('.') tagname = tagname.strip('.')
return (tagname, synonym, rename_to) return (tagname, synonym, rename_to)
def truthystring(s, fallback=False): def truthystring(s, fallback=False) -> typing.Union[bool, None]:
''' '''
If s is already a boolean, int, or None, return a boolean or None. If s is already a boolean, int, or None, return a boolean or None.
If s is a string, return True, False, or None based on the options presented If s is a string, return True, False, or None based on the options presented
@ -480,7 +485,7 @@ def truthystring(s, fallback=False):
return None return None
return False return False
def zip_album(album, recursive=True): def zip_album(album, recursive=True) -> zipstream.ZipFile:
''' '''
Given an album, return a zipstream zipfile that contains the album's Given an album, return a zipstream zipfile that contains the album's
photos (recursive = include children's photos) organized into folders photos (recursive = include children's photos) organized into folders
@ -521,7 +526,7 @@ def zip_album(album, recursive=True):
return zipfile return zipfile
def zip_photos(photos): def zip_photos(photos) -> zipstream.ZipFile:
''' '''
Given some photos, return a zipstream zipfile that contains the files. Given some photos, return a zipstream zipfile that contains the files.
''' '''

File diff suppressed because it is too large Load Diff

View File

@ -8,6 +8,7 @@ import sqlite3
import tempfile import tempfile
import time import time
import types import types
import typing
from voussoirkit import cacheclass from voussoirkit import cacheclass
from voussoirkit import configlayers from voussoirkit import configlayers
@ -34,19 +35,19 @@ class PDBAlbumMixin:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def get_album(self, id): def get_album(self, id) -> objects.Album:
return self.get_thing_by_id('album', id) return self.get_thing_by_id('album', id)
def get_album_count(self): def get_album_count(self) -> int:
return self.sql_select_one('SELECT COUNT(id) FROM albums')[0] return self.sql_select_one('SELECT COUNT(id) FROM albums')[0]
def get_albums(self): def get_albums(self) -> typing.Iterable[objects.Album]:
return self.get_things(thing_type='album') return self.get_things(thing_type='album')
def get_albums_by_id(self, ids): def get_albums_by_id(self, ids) -> typing.Iterable[objects.Album]:
return self.get_things_by_id('album', ids) return self.get_things_by_id('album', ids)
def get_albums_by_path(self, directory): def get_albums_by_path(self, directory) -> typing.Iterable[objects.Album]:
''' '''
Yield Albums with the `associated_directory` of this value, Yield Albums with the `associated_directory` of this value,
NOT case-sensitive. NOT case-sensitive.
@ -58,10 +59,10 @@ class PDBAlbumMixin:
album_ids = (album_id for (album_id,) in album_rows) album_ids = (album_id for (album_id,) in album_rows)
return self.get_albums_by_id(album_ids) return self.get_albums_by_id(album_ids)
def get_albums_by_sql(self, query, bindings=None): def get_albums_by_sql(self, query, bindings=None) -> typing.Iterable[objects.Album]:
return self.get_things_by_sql('album', query, bindings) return self.get_things_by_sql('album', query, bindings)
def get_albums_within_directory(self, directory): def get_albums_within_directory(self, directory) -> typing.Iterable[objects.Album]:
# This function is something of a stopgap measure since `search` only # This function is something of a stopgap measure since `search` only
# searches for photos and then yields their containing albums. Thus it # searches for photos and then yields their containing albums. Thus it
# is not possible for search to find albums that contain no photos. # is not possible for search to find albums that contain no photos.
@ -78,7 +79,7 @@ class PDBAlbumMixin:
albums = self.get_albums_by_id(album_ids) albums = self.get_albums_by_id(album_ids)
return albums return albums
def get_root_albums(self): def get_root_albums(self) -> typing.Iterable[objects.Album]:
''' '''
Yield Albums that have no parent. Yield Albums that have no parent.
''' '''
@ -94,7 +95,7 @@ class PDBAlbumMixin:
associated_directories=None, associated_directories=None,
author=None, author=None,
photos=None, photos=None,
): ) -> objects.Album:
''' '''
Create a new album. Photos can be added now or later. Create a new album. Photos can be added now or later.
''' '''
@ -131,7 +132,7 @@ class PDBAlbumMixin:
return album return album
@decorators.transaction @decorators.transaction
def purge_deleted_associated_directories(self, albums=None): def purge_deleted_associated_directories(self, albums=None) -> typing.Iterable[pathclass.Path]:
directories = self.sql_select('SELECT DISTINCT directory FROM album_associated_directories') directories = self.sql_select('SELECT DISTINCT directory FROM album_associated_directories')
directories = (pathclass.Path(directory) for (directory,) in directories) directories = (pathclass.Path(directory) for (directory,) in directories)
directories = [directory for directory in directories if not directory.is_dir] directories = [directory for directory in directories if not directory.is_dir]
@ -148,7 +149,7 @@ class PDBAlbumMixin:
yield from directories yield from directories
@decorators.transaction @decorators.transaction
def purge_empty_albums(self, albums=None): def purge_empty_albums(self, albums=None) -> typing.Iterable[objects.Album]:
if albums is None: if albums is None:
to_check = set(self.get_albums()) to_check = set(self.get_albums())
else: else:
@ -171,24 +172,24 @@ class PDBBookmarkMixin:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def get_bookmark(self, id): def get_bookmark(self, id) -> objects.Bookmark:
return self.get_thing_by_id('bookmark', id) return self.get_thing_by_id('bookmark', id)
def get_bookmark_count(self): def get_bookmark_count(self) -> int:
return self.sql_select_one('SELECT COUNT(id) FROM bookmarks')[0] return self.sql_select_one('SELECT COUNT(id) FROM bookmarks')[0]
def get_bookmarks(self): def get_bookmarks(self) -> typing.Iterable[objects.Bookmark]:
return self.get_things(thing_type='bookmark') return self.get_things(thing_type='bookmark')
def get_bookmarks_by_id(self, ids): def get_bookmarks_by_id(self, ids) -> typing.Iterable[objects.Bookmark]:
return self.get_things_by_id('bookmark', ids) return self.get_things_by_id('bookmark', ids)
def get_bookmarks_by_sql(self, query, bindings=None): def get_bookmarks_by_sql(self, query, bindings=None) -> typing.Iterable[objects.Bookmark]:
return self.get_things_by_sql('bookmark', query, bindings) return self.get_things_by_sql('bookmark', query, bindings)
@decorators.required_feature('bookmark.new') @decorators.required_feature('bookmark.new')
@decorators.transaction @decorators.transaction
def new_bookmark(self, url, title=None, *, author=None): def new_bookmark(self, url, title=None, *, author=None) -> objects.Bookmark:
# These might raise exceptions. # These might raise exceptions.
title = objects.Bookmark.normalize_title(title) title = objects.Bookmark.normalize_title(title)
url = objects.Bookmark.normalize_url(url) url = objects.Bookmark.normalize_url(url)
@ -245,7 +246,7 @@ class PDBCacheManagerMixin:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def clear_all_caches(self): def clear_all_caches(self) -> None:
self.caches['album'].clear() self.caches['album'].clear()
self.caches['bookmark'].clear() self.caches['bookmark'].clear()
self.caches['photo'].clear() self.caches['photo'].clear()
@ -430,7 +431,7 @@ class PDBPhotoMixin:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def assert_no_such_photo_by_path(self, filepath): def assert_no_such_photo_by_path(self, filepath) -> None:
try: try:
existing = self.get_photo_by_path(filepath) existing = self.get_photo_by_path(filepath)
except exceptions.NoSuchPhoto: except exceptions.NoSuchPhoto:
@ -438,10 +439,10 @@ class PDBPhotoMixin:
else: else:
raise exceptions.PhotoExists(existing) raise exceptions.PhotoExists(existing)
def get_photo(self, id): def get_photo(self, id) -> objects.Photo:
return self.get_thing_by_id('photo', id) return self.get_thing_by_id('photo', id)
def get_photo_by_path(self, filepath): def get_photo_by_path(self, filepath) -> objects.Photo:
filepath = pathclass.Path(filepath) filepath = pathclass.Path(filepath)
query = 'SELECT * FROM photos WHERE filepath == ?' query = 'SELECT * FROM photos WHERE filepath == ?'
bindings = [filepath.absolute_path] bindings = [filepath.absolute_path]
@ -451,13 +452,13 @@ class PDBPhotoMixin:
photo = self.get_cached_instance('photo', photo_row) photo = self.get_cached_instance('photo', photo_row)
return photo return photo
def get_photo_count(self): def get_photo_count(self) -> int:
return self.sql_select_one('SELECT COUNT(id) FROM photos')[0] return self.sql_select_one('SELECT COUNT(id) FROM photos')[0]
def get_photos_by_id(self, ids): def get_photos_by_id(self, ids) -> typing.Iterable[objects.Photo]:
return self.get_things_by_id('photo', ids) return self.get_things_by_id('photo', ids)
def get_photos_by_recent(self, count=None): def get_photos_by_recent(self, count=None) -> typing.Iterable[objects.Photo]:
''' '''
Yield photo objects in order of creation time. Yield photo objects in order of creation time.
''' '''
@ -476,7 +477,7 @@ class PDBPhotoMixin:
if count <= 0: if count <= 0:
break break
def get_photos_by_hash(self, sha256): def get_photos_by_hash(self, sha256) -> typing.Iterable[objects.Photo]:
if not isinstance(sha256, str) or len(sha256) != 64: if not isinstance(sha256, str) or len(sha256) != 64:
raise TypeError(f'sha256 shoulbe the 64-character hexdigest string.') raise TypeError(f'sha256 shoulbe the 64-character hexdigest string.')
@ -484,7 +485,7 @@ class PDBPhotoMixin:
bindings = [sha256] bindings = [sha256]
yield from self.get_photos_by_sql(query, bindings) yield from self.get_photos_by_sql(query, bindings)
def get_photos_by_sql(self, query, bindings=None): def get_photos_by_sql(self, query, bindings=None) -> typing.Iterable[objects.Photo]:
return self.get_things_by_sql('photo', query, bindings) return self.get_things_by_sql('photo', query, bindings)
@decorators.required_feature('photo.new') @decorators.required_feature('photo.new')
@ -500,7 +501,7 @@ class PDBPhotoMixin:
known_hash=None, known_hash=None,
searchhidden=False, searchhidden=False,
tags=None, tags=None,
): ) -> objects.Photo:
''' '''
Given a filepath, determine its attributes and create a new Photo object Given a filepath, determine its attributes and create a new Photo object
in the database. Tags may be applied now or later. in the database. Tags may be applied now or later.
@ -574,7 +575,7 @@ class PDBPhotoMixin:
return photo return photo
@decorators.transaction @decorators.transaction
def purge_deleted_files(self, photos=None): def purge_deleted_files(self, photos=None) -> typing.Iterable[objects.Photo]:
''' '''
Delete Photos whose corresponding file on disk is missing. Delete Photos whose corresponding file on disk is missing.
@ -1020,13 +1021,13 @@ class PDBSQLMixin:
self.savepoints = [] self.savepoints = []
self._cached_sql_tables = None self._cached_sql_tables = None
def assert_table_exists(self, table): def assert_table_exists(self, table) -> None:
if not self._cached_sql_tables: if not self._cached_sql_tables:
self._cached_sql_tables = self.get_sql_tables() self._cached_sql_tables = self.get_sql_tables()
if table not in self._cached_sql_tables: if table not in self._cached_sql_tables:
raise exceptions.BadTable(table) raise exceptions.BadTable(table)
def commit(self, message=None): def commit(self, message=None) -> None:
if message is not None: if message is not None:
self.log.debug('Committing - %s.', message) self.log.debug('Committing - %s.', message)
@ -1048,13 +1049,13 @@ class PDBSQLMixin:
self.savepoints.clear() self.savepoints.clear()
self.sql.commit() self.sql.commit()
def get_sql_tables(self): def get_sql_tables(self) -> set[str]:
query = 'SELECT name FROM sqlite_master WHERE type = "table"' query = 'SELECT name FROM sqlite_master WHERE type = "table"'
table_rows = self.sql_select(query) table_rows = self.sql_select(query)
tables = set(name for (name,) in table_rows) tables = set(name for (name,) in table_rows)
return tables return tables
def release_savepoint(self, savepoint, allow_commit=False): def release_savepoint(self, savepoint, allow_commit=False) -> None:
''' '''
Releasing a savepoint removes that savepoint from the timeline, so that Releasing a savepoint removes that savepoint from the timeline, so that
you can no longer roll back to it. Then your choices are to commit you can no longer roll back to it. Then your choices are to commit
@ -1078,7 +1079,7 @@ class PDBSQLMixin:
self.sql_execute(f'RELEASE "{savepoint}"') self.sql_execute(f'RELEASE "{savepoint}"')
self.savepoints = helpers.slice_before(self.savepoints, savepoint) self.savepoints = helpers.slice_before(self.savepoints, savepoint)
def rollback(self, savepoint=None): def rollback(self, savepoint=None) -> None:
''' '''
Given a savepoint, roll the database back to the moment before that Given a savepoint, roll the database back to the moment before that
savepoint was created. Keep in mind that a @transaction savepoint is savepoint was created. Keep in mind that a @transaction savepoint is
@ -1117,7 +1118,7 @@ class PDBSQLMixin:
self.savepoints.clear() self.savepoints.clear()
self.on_commit_queue.clear() self.on_commit_queue.clear()
def savepoint(self, message=None): def savepoint(self, message=None) -> str:
savepoint_id = passwordy.random_hex(length=16) savepoint_id = passwordy.random_hex(length=16)
if message: if message:
self.log.log(5, 'Savepoint %s for %s.', savepoint_id, message) self.log.log(5, 'Savepoint %s for %s.', savepoint_id, message)
@ -1130,13 +1131,13 @@ class PDBSQLMixin:
self.on_rollback_queue.append(savepoint_id) self.on_rollback_queue.append(savepoint_id)
return savepoint_id return savepoint_id
def sql_delete(self, table, pairs): def sql_delete(self, table, pairs) -> None:
self.assert_table_exists(table) self.assert_table_exists(table)
(qmarks, bindings) = sqlhelpers.delete_filler(pairs) (qmarks, bindings) = sqlhelpers.delete_filler(pairs)
query = f'DELETE FROM {table} {qmarks}' query = f'DELETE FROM {table} {qmarks}'
self.sql_execute(query, bindings) self.sql_execute(query, bindings)
def sql_execute(self, query, bindings=[]): def sql_execute(self, query, bindings=[]) -> sqlite3.Cursor:
if bindings is None: if bindings is None:
bindings = [] bindings = []
cur = self.sql.cursor() cur = self.sql.cursor()
@ -1144,7 +1145,7 @@ class PDBSQLMixin:
cur.execute(query, bindings) cur.execute(query, bindings)
return cur return cur
def sql_executescript(self, script): def sql_executescript(self, script) -> None:
''' '''
The problem with Python's default executescript is that it executes a The problem with Python's default executescript is that it executes a
COMMIT before running your script. If I wanted a commit I'd write one! COMMIT before running your script. If I wanted a commit I'd write one!
@ -1157,7 +1158,7 @@ class PDBSQLMixin:
self.log.loud(line) self.log.loud(line)
cur.execute(line) cur.execute(line)
def sql_insert(self, table, data): def sql_insert(self, table, data) -> None:
self.assert_table_exists(table) self.assert_table_exists(table)
column_names = constants.SQL_COLUMNS[table] column_names = constants.SQL_COLUMNS[table]
(qmarks, bindings) = sqlhelpers.insert_filler(column_names, data) (qmarks, bindings) = sqlhelpers.insert_filler(column_names, data)
@ -1165,7 +1166,7 @@ class PDBSQLMixin:
query = f'INSERT INTO {table} VALUES({qmarks})' query = f'INSERT INTO {table} VALUES({qmarks})'
self.sql_execute(query, bindings) self.sql_execute(query, bindings)
def sql_select(self, query, bindings=None): def sql_select(self, query, bindings=None) -> typing.Iterable:
cur = self.sql_execute(query, bindings) cur = self.sql_execute(query, bindings)
while True: while True:
fetch = cur.fetchone() fetch = cur.fetchone()
@ -1177,7 +1178,7 @@ class PDBSQLMixin:
cur = self.sql_execute(query, bindings) cur = self.sql_execute(query, bindings)
return cur.fetchone() return cur.fetchone()
def sql_update(self, table, pairs, where_key): def sql_update(self, table, pairs, where_key) -> None:
self.assert_table_exists(table) self.assert_table_exists(table)
(qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key) (qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key)
query = f'UPDATE {table} {qmarks}' query = f'UPDATE {table} {qmarks}'
@ -1189,7 +1190,7 @@ class PDBTagMixin:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def assert_no_such_tag(self, name): def assert_no_such_tag(self, name) -> None:
try: try:
existing_tag = self.get_tag_by_name(name) existing_tag = self.get_tag_by_name(name)
except exceptions.NoSuchTag: except exceptions.NoSuchTag:
@ -1203,7 +1204,7 @@ class PDBTagMixin:
names = set(name for (name,) in tag_rows) names = set(name for (name,) in tag_rows)
return names return names
def get_all_tag_names(self): def get_all_tag_names(self) -> set[str]:
''' '''
Return a set containing the names of all tags as strings. Return a set containing the names of all tags as strings.
Useful for when you don't want the overhead of actual Tag objects. Useful for when you don't want the overhead of actual Tag objects.
@ -1216,19 +1217,19 @@ class PDBTagMixin:
synonyms = {syn: tag for (syn, tag) in syn_rows} synonyms = {syn: tag for (syn, tag) in syn_rows}
return synonyms return synonyms
def get_all_synonyms(self): def get_all_synonyms(self) -> dict:
''' '''
Return a dict mapping {synonym: mastertag} as strings. Return a dict mapping {synonym: mastertag} as strings.
''' '''
return self.get_cached_tag_export(self._get_all_synonyms) return self.get_cached_tag_export(self._get_all_synonyms)
def get_root_tags(self): def get_root_tags(self) -> typing.Iterable[objects.Tag]:
''' '''
Yield Tags that have no parent. Yield Tags that have no parent.
''' '''
return self.get_root_things('tag') return self.get_root_things('tag')
def get_tag(self, name=None, id=None): def get_tag(self, name=None, id=None) -> objects.Tag:
''' '''
Redirect to get_tag_by_id or get_tag_by_name. Redirect to get_tag_by_id or get_tag_by_name.
''' '''
@ -1240,10 +1241,10 @@ class PDBTagMixin:
else: else:
return self.get_tag_by_name(name) return self.get_tag_by_name(name)
def get_tag_by_id(self, id): def get_tag_by_id(self, id) -> objects.Tag:
return self.get_thing_by_id('tag', thing_id=id) return self.get_thing_by_id('tag', thing_id=id)
def get_tag_by_name(self, tagname): def get_tag_by_name(self, tagname) -> objects.Tag:
if isinstance(tagname, objects.Tag): if isinstance(tagname, objects.Tag):
if tagname.photodb == self: if tagname.photodb == self:
return tagname return tagname
@ -1277,24 +1278,24 @@ class PDBTagMixin:
tag = self.get_cached_instance('tag', tag_row) tag = self.get_cached_instance('tag', tag_row)
return tag return tag
def get_tag_count(self): def get_tag_count(self) -> int:
return self.sql_select_one('SELECT COUNT(id) FROM tags')[0] return self.sql_select_one('SELECT COUNT(id) FROM tags')[0]
def get_tags(self): def get_tags(self) -> typing.Iterable[objects.Tag]:
''' '''
Yield all Tags in the database. Yield all Tags in the database.
''' '''
return self.get_things(thing_type='tag') return self.get_things(thing_type='tag')
def get_tags_by_id(self, ids): def get_tags_by_id(self, ids) -> typing.Iterable[objects.Tag]:
return self.get_things_by_id('tag', ids) return self.get_things_by_id('tag', ids)
def get_tags_by_sql(self, query, bindings=None): def get_tags_by_sql(self, query, bindings=None) -> typing.Iterable[objects.Tag]:
return self.get_things_by_sql('tag', query, bindings) return self.get_things_by_sql('tag', query, bindings)
@decorators.required_feature('tag.new') @decorators.required_feature('tag.new')
@decorators.transaction @decorators.transaction
def new_tag(self, tagname, description=None, *, author=None): def new_tag(self, tagname, description=None, *, author=None) -> objects.Tag:
''' '''
Register a new tag and return the Tag object. Register a new tag and return the Tag object.
''' '''
@ -1324,7 +1325,7 @@ class PDBTagMixin:
return tag return tag
def normalize_tagname(self, tagname): def normalize_tagname(self, tagname) -> str:
tagname = objects.Tag.normalize_name( tagname = objects.Tag.normalize_name(
tagname, tagname,
# valid_chars=self.config['tag']['valid_chars'], # valid_chars=self.config['tag']['valid_chars'],
@ -1339,7 +1340,7 @@ class PDBUserMixin:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def assert_no_such_user(self, username): def assert_no_such_user(self, username) -> None:
try: try:
existing_user = self.get_user(username=username) existing_user = self.get_user(username=username)
except exceptions.NoSuchUser: except exceptions.NoSuchUser:
@ -1347,14 +1348,14 @@ class PDBUserMixin:
else: else:
raise exceptions.UserExists(existing_user) raise exceptions.UserExists(existing_user)
def assert_valid_password(self, password): def assert_valid_password(self, password) -> None:
if not isinstance(password, bytes): if not isinstance(password, bytes):
raise TypeError(f'Password must be {bytes}, not {type(password)}.') raise TypeError(f'Password must be {bytes}, not {type(password)}.')
if len(password) < self.config['user']['min_password_length']: if len(password) < self.config['user']['min_password_length']:
raise exceptions.PasswordTooShort(min_length=self.config['user']['min_password_length']) raise exceptions.PasswordTooShort(min_length=self.config['user']['min_password_length'])
def assert_valid_username(self, username): def assert_valid_username(self, username) -> None:
if not isinstance(username, str): if not isinstance(username, str):
raise TypeError(f'Username must be {str}, not {type(username)}.') raise TypeError(f'Username must be {str}, not {type(username)}.')
@ -1374,7 +1375,7 @@ class PDBUserMixin:
if badchars: if badchars:
raise exceptions.InvalidUsernameChars(username=username, badchars=badchars) raise exceptions.InvalidUsernameChars(username=username, badchars=badchars)
def generate_user_id(self): def generate_user_id(self) -> str:
''' '''
User IDs are randomized instead of integers like the other objects, User IDs are randomized instead of integers like the other objects,
so they get their own method. so they get their own method.
@ -1391,7 +1392,7 @@ class PDBUserMixin:
return user_id return user_id
def get_user(self, username=None, id=None): def get_user(self, username=None, id=None) -> objects.User:
''' '''
Redirect to get_user_by_id or get_user_by_username. Redirect to get_user_by_id or get_user_by_username.
''' '''
@ -1403,10 +1404,10 @@ class PDBUserMixin:
else: else:
return self.get_user_by_username(username) return self.get_user_by_username(username)
def get_user_by_id(self, id): def get_user_by_id(self, id) -> objects.User:
return self.get_thing_by_id('user', id) return self.get_thing_by_id('user', id)
def get_user_by_username(self, username): def get_user_by_username(self, username) -> objects.User:
user_row = self.sql_select_one('SELECT * FROM users WHERE username == ?', [username]) user_row = self.sql_select_one('SELECT * FROM users WHERE username == ?', [username])
if user_row is None: if user_row is None:
@ -1414,10 +1415,10 @@ class PDBUserMixin:
return self.get_cached_instance('user', user_row) return self.get_cached_instance('user', user_row)
def get_user_count(self): def get_user_count(self) -> int:
return self.sql_select_one('SELECT COUNT(id) FROM users')[0] return self.sql_select_one('SELECT COUNT(id) FROM users')[0]
def get_user_id_or_none(self, user_obj_or_id): def get_user_id_or_none(self, user_obj_or_id) -> typing.Optional[str]:
''' '''
For methods that create photos, albums, etc., we sometimes associate For methods that create photos, albums, etc., we sometimes associate
them with an author but sometimes not. The callers of those methods them with an author but sometimes not. The callers of those methods
@ -1446,17 +1447,17 @@ class PDBUserMixin:
return author_id return author_id
def get_users(self): def get_users(self) -> typing.Iterable[objects.User]:
return self.get_things('user') return self.get_things('user')
def get_users_by_id(self, ids): def get_users_by_id(self, ids) -> typing.Iterable[objects.User]:
return self.get_things_by_id('user', ids) return self.get_things_by_id('user', ids)
def get_users_by_sql(self, query, bindings=None): def get_users_by_sql(self, query, bindings=None) -> typing.Iterable[objects.User]:
return self.get_things_by_sql('user', query, bindings) return self.get_things_by_sql('user', query, bindings)
@decorators.required_feature('user.login') @decorators.required_feature('user.login')
def login(self, username=None, id=None, *, password): def login(self, username=None, id=None, *, password) -> objects.User:
''' '''
Return the User object for the user if the credentials are correct. Return the User object for the user if the credentials are correct.
''' '''
@ -1476,7 +1477,7 @@ class PDBUserMixin:
@decorators.required_feature('user.new') @decorators.required_feature('user.new')
@decorators.transaction @decorators.transaction
def new_user(self, username, password, *, display_name=None): def new_user(self, username, password, *, display_name=None) -> objects.User:
# These might raise exceptions. # These might raise exceptions.
self.assert_valid_username(username) self.assert_valid_username(username)
self.assert_no_such_user(username=username) self.assert_no_such_user(username=username)
@ -1953,6 +1954,7 @@ class PhotoDB(
self.sql_executescript(constants.DB_PRAGMAS) self.sql_executescript(constants.DB_PRAGMAS)
self.sql.commit() self.sql.commit()
# Will add -> PhotoDB when forward references are supported
@classmethod @classmethod
def closest_photodb(cls, path, *args, **kwargs): def closest_photodb(cls, path, *args, **kwargs):
''' '''
@ -1989,7 +1991,7 @@ class PhotoDB(
else: else:
return f'PhotoDB(data_directory={self.data_directory})' return f'PhotoDB(data_directory={self.data_directory})'
def close(self): def close(self) -> None:
# Wrapped in hasattr because if the object fails __init__, Python will # Wrapped in hasattr because if the object fails __init__, Python will
# still call __del__ and thus close(), even though the attributes # still call __del__ and thus close(), even though the attributes
# we're trying to clean up never got set. # we're trying to clean up never got set.
@ -1999,7 +2001,7 @@ class PhotoDB(
if getattr(self, 'ephemeral', False): if getattr(self, 'ephemeral', False):
self.ephemeral_directory.cleanup() self.ephemeral_directory.cleanup()
def generate_id(self, table): def generate_id(self, table) -> str:
''' '''
Create a new ID number that is unique to the given table. Create a new ID number that is unique to the given table.
Note that while this method may INSERT / UPDATE, it does not commit. Note that while this method may INSERT / UPDATE, it does not commit.
@ -2032,7 +2034,7 @@ class PhotoDB(
self.sql_update(table='id_numbers', pairs=pairs, where_key='tab') self.sql_update(table='id_numbers', pairs=pairs, where_key='tab')
return new_id return new_id
def load_config(self): def load_config(self) -> None:
(config, needs_rewrite) = configlayers.load_file( (config, needs_rewrite) = configlayers.load_file(
filepath=self.config_filepath, filepath=self.config_filepath,
defaults=constants.DEFAULT_CONFIGURATION, defaults=constants.DEFAULT_CONFIGURATION,
@ -2042,6 +2044,6 @@ class PhotoDB(
if needs_rewrite: if needs_rewrite:
self.save_config() self.save_config()
def save_config(self): def save_config(self) -> None:
with self.config_filepath.open('w', encoding='utf-8') as handle: with self.config_filepath.open('w', encoding='utf-8') as handle:
handle.write(json.dumps(self.config, indent=4, sort_keys=True)) handle.write(json.dumps(self.config, indent=4, sort_keys=True))