Assert table exists for any sql op involving argument tables.
At the moment, all of these functions are safe because they're called with hardcoded tables determined by other code, not user input. But while I was working in this area I felt it would be good to add a safety check just in case.
This commit is contained in:
parent
64f9eb5f2b
commit
d6d7521bce
2 changed files with 23 additions and 0 deletions
|
@ -157,6 +157,14 @@ class WrongLogin(EtiquetteException):
|
|||
error_message = 'Wrong username-password combination.'
|
||||
|
||||
|
||||
# SQL ERRORS
|
||||
class BadSQL(EtiquetteException):
|
||||
pass
|
||||
|
||||
class BadTable(BadSQL):
|
||||
error_message = 'Table "{}" does not exist.'
|
||||
|
||||
|
||||
# GENERAL ERRORS
|
||||
class BadDataDirectory(EtiquetteException):
|
||||
'''
|
||||
|
|
|
@ -685,6 +685,10 @@ class PDBSQLMixin:
|
|||
self.on_commit_queue = []
|
||||
self.savepoints = []
|
||||
|
||||
def assert_table_exists(self, table):
|
||||
if table not in self._cached_sql_tables:
|
||||
raise exceptions.BadTable(table)
|
||||
|
||||
def commit(self, message=None):
|
||||
if message is not None:
|
||||
self.log.debug('Committing - %s.', message)
|
||||
|
@ -699,6 +703,12 @@ class PDBSQLMixin:
|
|||
self.savepoints.clear()
|
||||
self.sql.commit()
|
||||
|
||||
def get_sql_tables(self):
|
||||
query = 'SELECT name FROM sqlite_master WHERE type = "table"'
|
||||
cur = self.sql_execute(query)
|
||||
tables = set(row[0] for row in cur.fetchall())
|
||||
return tables
|
||||
|
||||
def rollback(self, savepoint=None):
|
||||
'''
|
||||
Given a savepoint, roll the database back to the moment before that
|
||||
|
@ -740,6 +750,7 @@ class PDBSQLMixin:
|
|||
return savepoint_id
|
||||
|
||||
def sql_delete(self, table, pairs):
|
||||
self.assert_table_exists(table)
|
||||
(qmarks, bindings) = sqlhelpers.delete_filler(pairs)
|
||||
query = f'DELETE FROM {table} {qmarks}'
|
||||
self.sql_execute(query, bindings)
|
||||
|
@ -752,6 +763,7 @@ class PDBSQLMixin:
|
|||
return cur
|
||||
|
||||
def sql_insert(self, table, data):
|
||||
self.assert_table_exists(table)
|
||||
column_names = constants.SQL_COLUMNS[table]
|
||||
(qmarks, bindings) = sqlhelpers.insert_filler(column_names, data)
|
||||
|
||||
|
@ -771,6 +783,7 @@ class PDBSQLMixin:
|
|||
return cur.fetchone()
|
||||
|
||||
def sql_update(self, table, pairs, where_key):
|
||||
self.assert_table_exists(table)
|
||||
(qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key)
|
||||
query = f'UPDATE {table} {qmarks}'
|
||||
self.sql_execute(query, bindings)
|
||||
|
@ -1332,6 +1345,8 @@ class PhotoDB(
|
|||
self.log.setLevel(self.config['log_level'])
|
||||
|
||||
# OTHER
|
||||
self._cached_sql_tables = self.get_sql_tables()
|
||||
|
||||
self._cached_frozen_children = None
|
||||
|
||||
self.caches = {
|
||||
|
|
Loading…
Reference in a new issue