diff --git a/etiquette/exceptions.py b/etiquette/exceptions.py index 517298b..6419760 100644 --- a/etiquette/exceptions.py +++ b/etiquette/exceptions.py @@ -157,6 +157,14 @@ class WrongLogin(EtiquetteException): error_message = 'Wrong username-password combination.' +# SQL ERRORS +class BadSQL(EtiquetteException): + pass + +class BadTable(BadSQL): + error_message = 'Table "{}" does not exist.' + + # GENERAL ERRORS class BadDataDirectory(EtiquetteException): ''' diff --git a/etiquette/photodb.py b/etiquette/photodb.py index a696318..4f113bf 100644 --- a/etiquette/photodb.py +++ b/etiquette/photodb.py @@ -685,6 +685,10 @@ class PDBSQLMixin: self.on_commit_queue = [] self.savepoints = [] + def assert_table_exists(self, table): + if table not in self._cached_sql_tables: + raise exceptions.BadTable(table) + def commit(self, message=None): if message is not None: self.log.debug('Committing - %s.', message) @@ -699,6 +703,12 @@ class PDBSQLMixin: self.savepoints.clear() self.sql.commit() + def get_sql_tables(self): + query = 'SELECT name FROM sqlite_master WHERE type = "table"' + cur = self.sql_execute(query) + tables = set(row[0] for row in cur.fetchall()) + return tables + def rollback(self, savepoint=None): ''' Given a savepoint, roll the database back to the moment before that @@ -740,6 +750,7 @@ class PDBSQLMixin: return savepoint_id def sql_delete(self, table, pairs): + self.assert_table_exists(table) (qmarks, bindings) = sqlhelpers.delete_filler(pairs) query = f'DELETE FROM {table} {qmarks}' self.sql_execute(query, bindings) @@ -752,6 +763,7 @@ class PDBSQLMixin: return cur def sql_insert(self, table, data): + self.assert_table_exists(table) column_names = constants.SQL_COLUMNS[table] (qmarks, bindings) = sqlhelpers.insert_filler(column_names, data) @@ -771,6 +783,7 @@ class PDBSQLMixin: return cur.fetchone() def sql_update(self, table, pairs, where_key): + self.assert_table_exists(table) (qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key) query = f'UPDATE {table} {qmarks}' self.sql_execute(query, bindings) @@ -1332,6 +1345,8 @@ class PhotoDB( self.log.setLevel(self.config['log_level']) # OTHER + self._cached_sql_tables = self.get_sql_tables() + self._cached_frozen_children = None self.caches = {