Add method get_things_by_id for mass lookups.

This commit is contained in:
voussoir 2018-04-15 02:14:06 -07:00
parent 68d6e4faf4
commit 518a45ccd8
2 changed files with 52 additions and 17 deletions

View file

@ -68,6 +68,7 @@ class ObjectBase:
class GroupableMixin:
group_getter = None
group_getter_many = None
group_sql_index = None
group_table = None
@ -160,12 +161,12 @@ class GroupableMixin:
[self.id]
)
child_ids = [row[0] for row in child_rows]
children = [self.group_getter(id=child_id) for child_id in child_ids]
children = self.group_getter_many(child_ids)
if isinstance(self, Tag):
children.sort(key=lambda x: x.name)
children = sorted(children, key=lambda x: x.name)
else:
children.sort(key=lambda x: x.id)
children = sorted(children, key=lambda x: x.id)
return children
def get_parent(self):
@ -235,14 +236,12 @@ class Album(ObjectBase, GroupableMixin):
self.name = 'Album %s' % self.id
self.group_getter = self.photodb.get_album
self.group_getter_many = self.photodb.get_albums_by_id
self._sum_bytes_local = None
self._sum_bytes_recursive = None
self._sum_photos_recursive = None
def __hash__(self):
return hash(self.id)
def __repr__(self):
return f'Album:{self.id}'
@ -447,8 +446,9 @@ class Album(ObjectBase, GroupableMixin):
'SELECT photoid FROM album_photo_rel WHERE albumid == ?',
[self.id]
)
photos = [self.photodb.get_photo(id=fetch[0]) for fetch in generator]
photos.sort(key=lambda x: x.basename.lower())
photo_ids = [row[0] for row in generator]
photos = self.photodb.get_photos_by_id(photo_ids)
photos = sorted(photos, key=lambda x: x.basename.lower())
return photos
def has_photo(self, photo):
@ -1161,6 +1161,7 @@ class Tag(ObjectBase, GroupableMixin):
self.author_id = self.normalize_author_id(db_row['author_id'])
self.group_getter = self.photodb.get_tag
self.group_getter_many = self.photodb.get_tags_by_id
self._cached_qualified_name = None
def __eq__(self, other):

View file

@ -66,6 +66,9 @@ class PDBAlbumMixin:
def get_albums(self):
yield from self.get_things(thing_type='album')
def get_albums_by_id(self, ids):
return self.get_things_by_id('album', ids)
def get_root_albums(self):
for album in self.get_albums():
if album.get_parent() is None:
@ -129,6 +132,9 @@ class PDBBookmarkMixin:
def get_bookmarks(self):
yield from self.get_things(thing_type='bookmark')
def get_bookmarks_by_id(self, ids):
return self.get_things_by_id('bookmark', ids)
@decorators.required_feature('bookmark.new')
@decorators.transaction
def new_bookmark(self, url, title=None, *, author=None, commit=True):
@ -179,6 +185,9 @@ class PDBPhotoMixin:
photo = objects.Photo(self, photo_row)
return photo
def get_photos_by_id(self, ids):
return self.get_things_by_id('photo', ids)
def get_photos_by_recent(self, count=None):
'''
Yield photo objects in order of creation time.
@ -788,7 +797,7 @@ class PDBTagMixin:
def get_tag(self, name=None, id=None):
'''
Redirect to get_tag_by_id or get_tag_by_name after xor-checking the parameters.
Redirect to get_tag_by_id or get_tag_by_name.
'''
if not helpers.is_xor(id, name):
raise exceptions.NotExclusive(['id', 'name'])
@ -842,6 +851,9 @@ class PDBTagMixin:
'''
yield from self.get_things(thing_type='tag')
def get_tags_by_id(self, ids):
return self.get_things_by_id('tag', ids)
def get_root_tags(self):
'''
Yield all Tags that have no parent.
@ -1429,9 +1441,7 @@ class PhotoDB(
thing_cache = self.caches[thing_type]
try:
#self.log.debug('Cache hit for %s %s', thing_type, thing_id)
val = thing_cache[thing_id]
return val
return thing_cache[thing_id]
except KeyError:
pass
@ -1444,12 +1454,9 @@ class PhotoDB(
thing_cache[thing_id] = thing
return thing
def get_things(self, thing_type, orderby=None):
def get_things(self, thing_type):
thing_map = _THING_CLASSES[thing_type]
if orderby:
query = 'SELECT * FROM %s ORDER BY %s' % (thing_map['table'], orderby)
else:
query = 'SELECT * FROM %s' % thing_map['table']
things = self.sql_select(query)
@ -1457,6 +1464,33 @@ class PhotoDB(
thing = thing_map['class'](self, db_row=thing_row)
yield thing
def get_things_by_id(self, thing_type, thing_ids):
thing_map = _THING_CLASSES[thing_type]
thing_class = thing_map['class']
thing_cache = self.caches[thing_type]
ids_needed = set(thing_ids)
things = set()
for thing_id in ids_needed:
try:
thing = thing_cache[thing_id]
except KeyError:
pass
else:
things.add(thing)
ids_needed.remove(thing.id)
yield from things
if ids_needed:
qmarks = '(%s)' % ','.join('?' * len(ids_needed))
query = 'SELECT * FROM %s WHERE id IN %s' % (thing_map['table'], qmarks)
bindings = list(ids_needed)
more_things = self.sql_select(query, bindings)
for thing_row in more_things:
thing = thing_map['class'](self, db_row=thing_row)
yield thing
def load_config(self):
config = copy.deepcopy(constants.DEFAULT_CONFIGURATION)
user_config_exists = self.config_filepath.is_file