Assert table exists for any sql op involving argument tables.

At the moment, all of these functions are safe because they're
called with hardcoded tables determined by other code, not user input.
But while I was working in this area I felt it would be good to add
a safety check just in case.
This commit is contained in:
voussoir 2020-02-04 18:15:14 -08:00
parent 64f9eb5f2b
commit d6d7521bce
2 changed files with 23 additions and 0 deletions

View file

@ -157,6 +157,14 @@ class WrongLogin(EtiquetteException):
error_message = 'Wrong username-password combination.' error_message = 'Wrong username-password combination.'
# SQL ERRORS
class BadSQL(EtiquetteException):
pass
class BadTable(BadSQL):
error_message = 'Table "{}" does not exist.'
# GENERAL ERRORS # GENERAL ERRORS
class BadDataDirectory(EtiquetteException): class BadDataDirectory(EtiquetteException):
''' '''

View file

@ -685,6 +685,10 @@ class PDBSQLMixin:
self.on_commit_queue = [] self.on_commit_queue = []
self.savepoints = [] self.savepoints = []
def assert_table_exists(self, table):
if table not in self._cached_sql_tables:
raise exceptions.BadTable(table)
def commit(self, message=None): def commit(self, message=None):
if message is not None: if message is not None:
self.log.debug('Committing - %s.', message) self.log.debug('Committing - %s.', message)
@ -699,6 +703,12 @@ class PDBSQLMixin:
self.savepoints.clear() self.savepoints.clear()
self.sql.commit() self.sql.commit()
def get_sql_tables(self):
query = 'SELECT name FROM sqlite_master WHERE type = "table"'
cur = self.sql_execute(query)
tables = set(row[0] for row in cur.fetchall())
return tables
def rollback(self, savepoint=None): def rollback(self, savepoint=None):
''' '''
Given a savepoint, roll the database back to the moment before that Given a savepoint, roll the database back to the moment before that
@ -740,6 +750,7 @@ class PDBSQLMixin:
return savepoint_id return savepoint_id
def sql_delete(self, table, pairs): def sql_delete(self, table, pairs):
self.assert_table_exists(table)
(qmarks, bindings) = sqlhelpers.delete_filler(pairs) (qmarks, bindings) = sqlhelpers.delete_filler(pairs)
query = f'DELETE FROM {table} {qmarks}' query = f'DELETE FROM {table} {qmarks}'
self.sql_execute(query, bindings) self.sql_execute(query, bindings)
@ -752,6 +763,7 @@ class PDBSQLMixin:
return cur return cur
def sql_insert(self, table, data): def sql_insert(self, table, data):
self.assert_table_exists(table)
column_names = constants.SQL_COLUMNS[table] column_names = constants.SQL_COLUMNS[table]
(qmarks, bindings) = sqlhelpers.insert_filler(column_names, data) (qmarks, bindings) = sqlhelpers.insert_filler(column_names, data)
@ -771,6 +783,7 @@ class PDBSQLMixin:
return cur.fetchone() return cur.fetchone()
def sql_update(self, table, pairs, where_key): def sql_update(self, table, pairs, where_key):
self.assert_table_exists(table)
(qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key) (qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key)
query = f'UPDATE {table} {qmarks}' query = f'UPDATE {table} {qmarks}'
self.sql_execute(query, bindings) self.sql_execute(query, bindings)
@ -1332,6 +1345,8 @@ class PhotoDB(
self.log.setLevel(self.config['log_level']) self.log.setLevel(self.config['log_level'])
# OTHER # OTHER
self._cached_sql_tables = self.get_sql_tables()
self._cached_frozen_children = None self._cached_frozen_children = None
self.caches = { self.caches = {