Assert table exists for any sql op involving argument tables.

At the moment, all of these functions are safe because they're
called with hardcoded tables determined by other code, not user input.
But while I was working in this area I felt it would be good to add
a safety check just in case.
This commit is contained in:
voussoir 2020-02-04 18:15:14 -08:00
parent 64f9eb5f2b
commit d6d7521bce
2 changed files with 23 additions and 0 deletions

View file

@ -157,6 +157,14 @@ class WrongLogin(EtiquetteException):
error_message = 'Wrong username-password combination.'
# SQL ERRORS
class BadSQL(EtiquetteException):
pass
class BadTable(BadSQL):
error_message = 'Table "{}" does not exist.'
# GENERAL ERRORS
class BadDataDirectory(EtiquetteException):
'''

View file

@ -685,6 +685,10 @@ class PDBSQLMixin:
self.on_commit_queue = []
self.savepoints = []
def assert_table_exists(self, table):
if table not in self._cached_sql_tables:
raise exceptions.BadTable(table)
def commit(self, message=None):
if message is not None:
self.log.debug('Committing - %s.', message)
@ -699,6 +703,12 @@ class PDBSQLMixin:
self.savepoints.clear()
self.sql.commit()
def get_sql_tables(self):
query = 'SELECT name FROM sqlite_master WHERE type = "table"'
cur = self.sql_execute(query)
tables = set(row[0] for row in cur.fetchall())
return tables
def rollback(self, savepoint=None):
'''
Given a savepoint, roll the database back to the moment before that
@ -740,6 +750,7 @@ class PDBSQLMixin:
return savepoint_id
def sql_delete(self, table, pairs):
self.assert_table_exists(table)
(qmarks, bindings) = sqlhelpers.delete_filler(pairs)
query = f'DELETE FROM {table} {qmarks}'
self.sql_execute(query, bindings)
@ -752,6 +763,7 @@ class PDBSQLMixin:
return cur
def sql_insert(self, table, data):
self.assert_table_exists(table)
column_names = constants.SQL_COLUMNS[table]
(qmarks, bindings) = sqlhelpers.insert_filler(column_names, data)
@ -771,6 +783,7 @@ class PDBSQLMixin:
return cur.fetchone()
def sql_update(self, table, pairs, where_key):
self.assert_table_exists(table)
(qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key)
query = f'UPDATE {table} {qmarks}'
self.sql_execute(query, bindings)
@ -1332,6 +1345,8 @@ class PhotoDB(
self.log.setLevel(self.config['log_level'])
# OTHER
self._cached_sql_tables = self.get_sql_tables()
self._cached_frozen_children = None
self.caches = {