diff --git a/etiquette/photodb.py b/etiquette/photodb.py index 6ce91d5..543e4cc 100644 --- a/etiquette/photodb.py +++ b/etiquette/photodb.py @@ -46,15 +46,12 @@ class PDBAlbumMixin: NOT case-sensitive. ''' filepath = pathclass.Path(filepath).absolute_path - cur = self.sql.cursor() - cur.execute( - 'SELECT albumid FROM album_associated_directories WHERE directory == ?', - [filepath] - ) - fetch = cur.fetchone() - if fetch is None: + query = 'SELECT albumid FROM album_associated_directories WHERE directory == ?' + bindings = [filepath] + album_row = self.sql_select_one(query, binding) + if album_row is None: raise exceptions.NoSuchAlbum(filepath) - album_id = fetch[0] + album_id = album_row[0] return self.get_album(album_id) def get_albums(self): @@ -165,12 +162,12 @@ class PDBPhotoMixin: def get_photo_by_path(self, filepath): filepath = pathclass.Path(filepath) - cur = self.sql.cursor() - cur.execute('SELECT * FROM photos WHERE filepath == ?', [filepath.absolute_path]) - fetch = cur.fetchone() - if fetch is None: + query = 'SELECT * FROM photos WHERE filepath == ?' + bindings = [filepath.absolute_path] + photo_row = self.sql_select_one(query, bindings) + if photo_row is None: raise exceptions.NoSuchPhoto(filepath) - photo = objects.Photo(self, fetch) + photo = objects.Photo(self, photo_row) return photo def get_photos_by_recent(self, count=None): @@ -179,16 +176,11 @@ class PDBPhotoMixin: ''' if count is not None and count <= 0: return - # We're going to use a second cursor because the first one may - # get used for something else, deactivating this query. - cur = self.sql.cursor() - cur.execute('SELECT * FROM photos ORDER BY created DESC') - while True: - fetch = cur.fetchone() - if fetch is None: - break - photo = objects.Photo(self, fetch) + query = 'SELECT * FROM photos ORDER BY created DESC' + photo_rows = self.sql_select(query) + for photo_row in photo_rows: + photo = objects.Photo(self, photo_row) yield photo if count is None: @@ -629,9 +621,8 @@ class PDBPhotoMixin: query = '%s\n%s\n%s' % ('-' * 80, query, '-' * 80) print(query, bindings) - #cur = self.sql.cursor() - #cur.execute('EXPLAIN QUERY PLAN ' + query, bindings) - #print('\n'.join(str(x) for x in cur.fetchall())) + #explain = self.sql_execute('EXPLAIN QUERY PLAN ' + query, bindings) + #print('\n'.join(str(x) for x in explain.fetchall())) generator = self.sql_select(query, bindings) photos_received = 0 for row in generator: @@ -717,10 +708,9 @@ class PDBSQLMixin: else: restore_to = self.savepoints.pop(-1) - cur = self.sql.cursor() self.log.debug('Rolling back to %s', restore_to) query = 'ROLLBACK TO "%s"' % restore_to - cur.execute(query) + self.sql_execute(query) while len(self.on_commit_queue) > 0: item = self.on_commit_queue.pop(-1) if item == restore_to: @@ -821,21 +811,20 @@ class PDBTagMixin: except (exceptions.TagTooShort, exceptions.TagTooLong): raise exceptions.NoSuchTag(tagname) - cur = self.sql.cursor() while True: # Return if it's a toplevel... - cur.execute('SELECT * FROM tags WHERE name == ?', [tagname]) - fetch = cur.fetchone() - if fetch is not None: - return objects.Tag(self, fetch) + tag_row = self.sql_select_one('SELECT * FROM tags WHERE name == ?', [tagname]) + if tag_row is not None: + return objects.Tag(self, tag_row) # ...or resolve the synonym and try again. - cur.execute('SELECT mastername FROM tag_synonyms WHERE name == ?', [tagname]) - fetch = cur.fetchone() - if fetch is None: + query = 'SELECT mastername FROM tag_synonyms WHERE name == ?' + bindings = [tagname] + name_row = self.sql_select_one(query, bindings) + if name_row is None: # was not a master tag or synonym raise exceptions.NoSuchTag(tagname) - tagname = fetch[0] + tagname = name_row[0] def get_tags(self): ''' @@ -944,13 +933,12 @@ class PDBUserMixin: so they get their own method. ''' possible = string.digits + string.ascii_uppercase - cur = self.sql.cursor() for retry in range(20): user_id = [random.choice(possible) for x in range(self.config['id_length'])] user_id = ''.join(user_id) - cur.execute('SELECT * FROM users WHERE id == ?', [user_id]) - if cur.fetchone() is None: + user_exists = self.sql_select_one('SELECT 1 FROM users WHERE id == ?', [user_id]) + if user_exists is None: break else: raise Exception('Failed to create user id after 20 tries.') @@ -961,15 +949,13 @@ class PDBUserMixin: if not helpers.is_xor(id, username): raise exceptions.NotExclusive(['id', 'username']) - cur = self.sql.cursor() if username is not None: - cur.execute('SELECT * FROM users WHERE username == ?', [username]) + user_row = self.sql_select_one('SELECT * FROM users WHERE username == ?', [username]) else: - cur.execute('SELECT * FROM users WHERE id == ?', [id]) + user_row = self.sql_select_one('SELECT * FROM users WHERE id == ?', [id]) - fetch = cur.fetchone() - if fetch is not None: - return objects.User(self, fetch) + if user_row is not None: + return objects.User(self, user_row) else: raise exceptions.NoSuchUser(username or id) @@ -1010,17 +996,15 @@ class PDBUserMixin: ''' Return the User object for the user if the credentials are correct. ''' - cur = self.sql.cursor() - cur.execute('SELECT * FROM users WHERE id == ?', [user_id]) - fetch = cur.fetchone() + user_row = self.sql_select_one('SELECT * FROM users WHERE id == ?', [user_id]) - if fetch is None: + if user_row is None: raise exceptions.WrongLogin() if not isinstance(password, bytes): password = password.encode('utf-8') - user = objects.User(self, fetch) + user = objects.User(self, user_row) success = bcrypt.checkpw(password, user.password_hash) if not success: @@ -1353,10 +1337,7 @@ class PhotoDB( } def _check_version(self): - cur = self.sql.cursor() - - cur.execute('PRAGMA user_version') - existing_version = cur.fetchone()[0] + existing_version = self.sql_execute('PRAGMA user_version').fetchone()[0] if existing_version != constants.DATABASE_VERSION: exc = exceptions.DatabaseOutOfDate( current=existing_version, @@ -1397,23 +1378,26 @@ class PhotoDB( if table not in ['photos', 'tags', 'albums', 'bookmarks']: raise ValueError('Invalid table requested: %s.', table) - cur = self.sql.cursor() - cur.execute('SELECT last_id FROM id_numbers WHERE tab == ?', [table]) - fetch = cur.fetchone() - if fetch is None: + last_id = self.sql_select_one('SELECT last_id FROM id_numbers WHERE tab == ?', [table]) + if last_id is None: # Register new value new_id_int = 1 do_insert = True else: # Use database value - new_id_int = int(fetch[0]) + 1 + new_id_int = int(last_id[0]) + 1 do_insert = False new_id = str(new_id_int).rjust(self.config['id_length'], '0') + + pairs = { + 'tab': table, + 'last_id': new_id, + } if do_insert: - cur.execute('INSERT INTO id_numbers VALUES(?, ?)', [table, new_id]) + self.sql_insert(table='id_numbers', data=pairs) else: - cur.execute('UPDATE id_numbers SET last_id = ? WHERE tab == ?', [new_id, table]) + self.sql_update(table='id_numbers', pairs=pairs, where_key='tab') return new_id def get_cached_frozen_children(self): @@ -1429,39 +1413,38 @@ class PhotoDB( def get_thing_by_id(self, thing_type, thing_id): thing_map = _THING_CLASSES[thing_type] - if isinstance(thing_id, thing_map['class']): + thing_class = thing_map['class'] + if isinstance(thing_id, thing_class): thing_id = thing_id.id - cache = self.caches[thing_type] + thing_cache = self.caches[thing_type] try: #self.log.debug('Cache hit for %s %s', thing_type, thing_id) - val = cache[thing_id] + val = thing_cache[thing_id] return val except KeyError: pass query = 'SELECT * FROM %s WHERE id == ?' % thing_map['table'] - cur = self.sql.cursor() - cur.execute(query, [thing_id]) - thing = cur.fetchone() - if thing is None: + bindings = [thing_id] + thing_row = self.sql_select_one(query, bindings) + if thing_row is None: raise thing_map['exception'](thing_id) - thing = thing_map['class'](self, thing) - cache[thing_id] = thing + thing = thing_class(self, thing_row) + thing_cache[thing_id] = thing return thing def get_things(self, thing_type, orderby=None): thing_map = _THING_CLASSES[thing_type] - cur = self.sql.cursor() if orderby: - cur.execute('SELECT * FROM %s ORDER BY %s' % (thing_map['table'], orderby)) + query = 'SELECT * FROM %s ORDER BY %s' % (thing_map['table'], orderby) else: - cur.execute('SELECT * FROM %s' % thing_map['table']) + query = 'SELECT * FROM %s' % thing_map['table'] - things = cur.fetchall() - for thing in things: - thing = thing_map['class'](self, db_row=thing) + things = self.sql_select(query) + for thing_row in things: + thing = thing_map['class'](self, db_row=thing_row) yield thing def load_config(self):