diff --git a/etiquette/objects.py b/etiquette/objects.py index 7d4f610..c0ce0ee 100644 --- a/etiquette/objects.py +++ b/etiquette/objects.py @@ -68,6 +68,7 @@ class ObjectBase: class GroupableMixin: group_getter = None + group_getter_many = None group_sql_index = None group_table = None @@ -160,12 +161,12 @@ class GroupableMixin: [self.id] ) child_ids = [row[0] for row in child_rows] - children = [self.group_getter(id=child_id) for child_id in child_ids] + children = self.group_getter_many(child_ids) if isinstance(self, Tag): - children.sort(key=lambda x: x.name) + children = sorted(children, key=lambda x: x.name) else: - children.sort(key=lambda x: x.id) + children = sorted(children, key=lambda x: x.id) return children def get_parent(self): @@ -235,14 +236,12 @@ class Album(ObjectBase, GroupableMixin): self.name = 'Album %s' % self.id self.group_getter = self.photodb.get_album + self.group_getter_many = self.photodb.get_albums_by_id self._sum_bytes_local = None self._sum_bytes_recursive = None self._sum_photos_recursive = None - def __hash__(self): - return hash(self.id) - def __repr__(self): return f'Album:{self.id}' @@ -447,8 +446,9 @@ class Album(ObjectBase, GroupableMixin): 'SELECT photoid FROM album_photo_rel WHERE albumid == ?', [self.id] ) - photos = [self.photodb.get_photo(id=fetch[0]) for fetch in generator] - photos.sort(key=lambda x: x.basename.lower()) + photo_ids = [row[0] for row in generator] + photos = self.photodb.get_photos_by_id(photo_ids) + photos = sorted(photos, key=lambda x: x.basename.lower()) return photos def has_photo(self, photo): @@ -1161,6 +1161,7 @@ class Tag(ObjectBase, GroupableMixin): self.author_id = self.normalize_author_id(db_row['author_id']) self.group_getter = self.photodb.get_tag + self.group_getter_many = self.photodb.get_tags_by_id self._cached_qualified_name = None def __eq__(self, other): diff --git a/etiquette/photodb.py b/etiquette/photodb.py index c459f8b..3c95819 100644 --- a/etiquette/photodb.py +++ b/etiquette/photodb.py @@ -66,6 +66,9 @@ class PDBAlbumMixin: def get_albums(self): yield from self.get_things(thing_type='album') + def get_albums_by_id(self, ids): + return self.get_things_by_id('album', ids) + def get_root_albums(self): for album in self.get_albums(): if album.get_parent() is None: @@ -129,6 +132,9 @@ class PDBBookmarkMixin: def get_bookmarks(self): yield from self.get_things(thing_type='bookmark') + def get_bookmarks_by_id(self, ids): + return self.get_things_by_id('bookmark', ids) + @decorators.required_feature('bookmark.new') @decorators.transaction def new_bookmark(self, url, title=None, *, author=None, commit=True): @@ -179,6 +185,9 @@ class PDBPhotoMixin: photo = objects.Photo(self, photo_row) return photo + def get_photos_by_id(self, ids): + return self.get_things_by_id('photo', ids) + def get_photos_by_recent(self, count=None): ''' Yield photo objects in order of creation time. @@ -788,7 +797,7 @@ class PDBTagMixin: def get_tag(self, name=None, id=None): ''' - Redirect to get_tag_by_id or get_tag_by_name after xor-checking the parameters. + Redirect to get_tag_by_id or get_tag_by_name. ''' if not helpers.is_xor(id, name): raise exceptions.NotExclusive(['id', 'name']) @@ -842,6 +851,9 @@ class PDBTagMixin: ''' yield from self.get_things(thing_type='tag') + def get_tags_by_id(self, ids): + return self.get_things_by_id('tag', ids) + def get_root_tags(self): ''' Yield all Tags that have no parent. @@ -1429,9 +1441,7 @@ class PhotoDB( thing_cache = self.caches[thing_type] try: - #self.log.debug('Cache hit for %s %s', thing_type, thing_id) - val = thing_cache[thing_id] - return val + return thing_cache[thing_id] except KeyError: pass @@ -1444,19 +1454,43 @@ class PhotoDB( thing_cache[thing_id] = thing return thing - def get_things(self, thing_type, orderby=None): + def get_things(self, thing_type): thing_map = _THING_CLASSES[thing_type] - if orderby: - query = 'SELECT * FROM %s ORDER BY %s' % (thing_map['table'], orderby) - else: - query = 'SELECT * FROM %s' % thing_map['table'] + query = 'SELECT * FROM %s' % thing_map['table'] things = self.sql_select(query) for thing_row in things: thing = thing_map['class'](self, db_row=thing_row) yield thing + def get_things_by_id(self, thing_type, thing_ids): + thing_map = _THING_CLASSES[thing_type] + thing_class = thing_map['class'] + thing_cache = self.caches[thing_type] + + ids_needed = set(thing_ids) + things = set() + for thing_id in ids_needed: + try: + thing = thing_cache[thing_id] + except KeyError: + pass + else: + things.add(thing) + ids_needed.remove(thing.id) + + yield from things + + if ids_needed: + qmarks = '(%s)' % ','.join('?' * len(ids_needed)) + query = 'SELECT * FROM %s WHERE id IN %s' % (thing_map['table'], qmarks) + bindings = list(ids_needed) + more_things = self.sql_select(query, bindings) + for thing_row in more_things: + thing = thing_map['class'](self, db_row=thing_row) + yield thing + def load_config(self): config = copy.deepcopy(constants.DEFAULT_CONFIGURATION) user_config_exists = self.config_filepath.is_file