From d6d7521bceeb9c067ea6852f00892137460b47f7 Mon Sep 17 00:00:00 2001 From: Ethan Dalool Date: Tue, 4 Feb 2020 18:15:14 -0800 Subject: [PATCH] Assert table exists for any sql op involving argument tables. At the moment, all of these functions are safe because they're called with hardcoded tables determined by other code, not user input. But while I was working in this area I felt it would be good to add a safety check just in case. --- etiquette/exceptions.py | 8 ++++++++ etiquette/photodb.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/etiquette/exceptions.py b/etiquette/exceptions.py index 517298b..6419760 100644 --- a/etiquette/exceptions.py +++ b/etiquette/exceptions.py @@ -157,6 +157,14 @@ class WrongLogin(EtiquetteException): error_message = 'Wrong username-password combination.' +# SQL ERRORS +class BadSQL(EtiquetteException): + pass + +class BadTable(BadSQL): + error_message = 'Table "{}" does not exist.' + + # GENERAL ERRORS class BadDataDirectory(EtiquetteException): ''' diff --git a/etiquette/photodb.py b/etiquette/photodb.py index a696318..4f113bf 100644 --- a/etiquette/photodb.py +++ b/etiquette/photodb.py @@ -685,6 +685,10 @@ class PDBSQLMixin: self.on_commit_queue = [] self.savepoints = [] + def assert_table_exists(self, table): + if table not in self._cached_sql_tables: + raise exceptions.BadTable(table) + def commit(self, message=None): if message is not None: self.log.debug('Committing - %s.', message) @@ -699,6 +703,12 @@ class PDBSQLMixin: self.savepoints.clear() self.sql.commit() + def get_sql_tables(self): + query = 'SELECT name FROM sqlite_master WHERE type = "table"' + cur = self.sql_execute(query) + tables = set(row[0] for row in cur.fetchall()) + return tables + def rollback(self, savepoint=None): ''' Given a savepoint, roll the database back to the moment before that @@ -740,6 +750,7 @@ class PDBSQLMixin: return savepoint_id def sql_delete(self, table, pairs): + self.assert_table_exists(table) (qmarks, bindings) = sqlhelpers.delete_filler(pairs) query = f'DELETE FROM {table} {qmarks}' self.sql_execute(query, bindings) @@ -752,6 +763,7 @@ class PDBSQLMixin: return cur def sql_insert(self, table, data): + self.assert_table_exists(table) column_names = constants.SQL_COLUMNS[table] (qmarks, bindings) = sqlhelpers.insert_filler(column_names, data) @@ -771,6 +783,7 @@ class PDBSQLMixin: return cur.fetchone() def sql_update(self, table, pairs, where_key): + self.assert_table_exists(table) (qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key) query = f'UPDATE {table} {qmarks}' self.sql_execute(query, bindings) @@ -1332,6 +1345,8 @@ class PhotoDB( self.log.setLevel(self.config['log_level']) # OTHER + self._cached_sql_tables = self.get_sql_tables() + self._cached_frozen_children = None self.caches = {