Add method get_things_by_id for mass lookups.

This commit is contained in:
voussoir 2018-04-15 02:14:06 -07:00
parent 68d6e4faf4
commit 518a45ccd8
2 changed files with 52 additions and 17 deletions

View file

@ -68,6 +68,7 @@ class ObjectBase:
class GroupableMixin: class GroupableMixin:
group_getter = None group_getter = None
group_getter_many = None
group_sql_index = None group_sql_index = None
group_table = None group_table = None
@ -160,12 +161,12 @@ class GroupableMixin:
[self.id] [self.id]
) )
child_ids = [row[0] for row in child_rows] child_ids = [row[0] for row in child_rows]
children = [self.group_getter(id=child_id) for child_id in child_ids] children = self.group_getter_many(child_ids)
if isinstance(self, Tag): if isinstance(self, Tag):
children.sort(key=lambda x: x.name) children = sorted(children, key=lambda x: x.name)
else: else:
children.sort(key=lambda x: x.id) children = sorted(children, key=lambda x: x.id)
return children return children
def get_parent(self): def get_parent(self):
@ -235,14 +236,12 @@ class Album(ObjectBase, GroupableMixin):
self.name = 'Album %s' % self.id self.name = 'Album %s' % self.id
self.group_getter = self.photodb.get_album self.group_getter = self.photodb.get_album
self.group_getter_many = self.photodb.get_albums_by_id
self._sum_bytes_local = None self._sum_bytes_local = None
self._sum_bytes_recursive = None self._sum_bytes_recursive = None
self._sum_photos_recursive = None self._sum_photos_recursive = None
def __hash__(self):
return hash(self.id)
def __repr__(self): def __repr__(self):
return f'Album:{self.id}' return f'Album:{self.id}'
@ -447,8 +446,9 @@ class Album(ObjectBase, GroupableMixin):
'SELECT photoid FROM album_photo_rel WHERE albumid == ?', 'SELECT photoid FROM album_photo_rel WHERE albumid == ?',
[self.id] [self.id]
) )
photos = [self.photodb.get_photo(id=fetch[0]) for fetch in generator] photo_ids = [row[0] for row in generator]
photos.sort(key=lambda x: x.basename.lower()) photos = self.photodb.get_photos_by_id(photo_ids)
photos = sorted(photos, key=lambda x: x.basename.lower())
return photos return photos
def has_photo(self, photo): def has_photo(self, photo):
@ -1161,6 +1161,7 @@ class Tag(ObjectBase, GroupableMixin):
self.author_id = self.normalize_author_id(db_row['author_id']) self.author_id = self.normalize_author_id(db_row['author_id'])
self.group_getter = self.photodb.get_tag self.group_getter = self.photodb.get_tag
self.group_getter_many = self.photodb.get_tags_by_id
self._cached_qualified_name = None self._cached_qualified_name = None
def __eq__(self, other): def __eq__(self, other):

View file

@ -66,6 +66,9 @@ class PDBAlbumMixin:
def get_albums(self): def get_albums(self):
yield from self.get_things(thing_type='album') yield from self.get_things(thing_type='album')
def get_albums_by_id(self, ids):
return self.get_things_by_id('album', ids)
def get_root_albums(self): def get_root_albums(self):
for album in self.get_albums(): for album in self.get_albums():
if album.get_parent() is None: if album.get_parent() is None:
@ -129,6 +132,9 @@ class PDBBookmarkMixin:
def get_bookmarks(self): def get_bookmarks(self):
yield from self.get_things(thing_type='bookmark') yield from self.get_things(thing_type='bookmark')
def get_bookmarks_by_id(self, ids):
return self.get_things_by_id('bookmark', ids)
@decorators.required_feature('bookmark.new') @decorators.required_feature('bookmark.new')
@decorators.transaction @decorators.transaction
def new_bookmark(self, url, title=None, *, author=None, commit=True): def new_bookmark(self, url, title=None, *, author=None, commit=True):
@ -179,6 +185,9 @@ class PDBPhotoMixin:
photo = objects.Photo(self, photo_row) photo = objects.Photo(self, photo_row)
return photo return photo
def get_photos_by_id(self, ids):
return self.get_things_by_id('photo', ids)
def get_photos_by_recent(self, count=None): def get_photos_by_recent(self, count=None):
''' '''
Yield photo objects in order of creation time. Yield photo objects in order of creation time.
@ -788,7 +797,7 @@ class PDBTagMixin:
def get_tag(self, name=None, id=None): def get_tag(self, name=None, id=None):
''' '''
Redirect to get_tag_by_id or get_tag_by_name after xor-checking the parameters. Redirect to get_tag_by_id or get_tag_by_name.
''' '''
if not helpers.is_xor(id, name): if not helpers.is_xor(id, name):
raise exceptions.NotExclusive(['id', 'name']) raise exceptions.NotExclusive(['id', 'name'])
@ -842,6 +851,9 @@ class PDBTagMixin:
''' '''
yield from self.get_things(thing_type='tag') yield from self.get_things(thing_type='tag')
def get_tags_by_id(self, ids):
return self.get_things_by_id('tag', ids)
def get_root_tags(self): def get_root_tags(self):
''' '''
Yield all Tags that have no parent. Yield all Tags that have no parent.
@ -1429,9 +1441,7 @@ class PhotoDB(
thing_cache = self.caches[thing_type] thing_cache = self.caches[thing_type]
try: try:
#self.log.debug('Cache hit for %s %s', thing_type, thing_id) return thing_cache[thing_id]
val = thing_cache[thing_id]
return val
except KeyError: except KeyError:
pass pass
@ -1444,19 +1454,43 @@ class PhotoDB(
thing_cache[thing_id] = thing thing_cache[thing_id] = thing
return thing return thing
def get_things(self, thing_type, orderby=None): def get_things(self, thing_type):
thing_map = _THING_CLASSES[thing_type] thing_map = _THING_CLASSES[thing_type]
if orderby: query = 'SELECT * FROM %s' % thing_map['table']
query = 'SELECT * FROM %s ORDER BY %s' % (thing_map['table'], orderby)
else:
query = 'SELECT * FROM %s' % thing_map['table']
things = self.sql_select(query) things = self.sql_select(query)
for thing_row in things: for thing_row in things:
thing = thing_map['class'](self, db_row=thing_row) thing = thing_map['class'](self, db_row=thing_row)
yield thing yield thing
def get_things_by_id(self, thing_type, thing_ids):
thing_map = _THING_CLASSES[thing_type]
thing_class = thing_map['class']
thing_cache = self.caches[thing_type]
ids_needed = set(thing_ids)
things = set()
for thing_id in ids_needed:
try:
thing = thing_cache[thing_id]
except KeyError:
pass
else:
things.add(thing)
ids_needed.remove(thing.id)
yield from things
if ids_needed:
qmarks = '(%s)' % ','.join('?' * len(ids_needed))
query = 'SELECT * FROM %s WHERE id IN %s' % (thing_map['table'], qmarks)
bindings = list(ids_needed)
more_things = self.sql_select(query, bindings)
for thing_row in more_things:
thing = thing_map['class'](self, db_row=thing_row)
yield thing
def load_config(self): def load_config(self):
config = copy.deepcopy(constants.DEFAULT_CONFIGURATION) config = copy.deepcopy(constants.DEFAULT_CONFIGURATION)
user_config_exists = self.config_filepath.is_file user_config_exists = self.config_filepath.is_file