diff --git a/voussoirkit/worms.py b/voussoirkit/worms.py index ffb2330..5a4b337 100644 --- a/voussoirkit/worms.py +++ b/voussoirkit/worms.py @@ -6,8 +6,10 @@ import functools import random import re import sqlite3 +import threading import typing +from voussoirkit import pathclass from voussoirkit import sqlhelpers from voussoirkit import vlogging @@ -21,6 +23,12 @@ class WormException(Exception): class BadTable(WormException): pass +class NoTransaction(WormException): + pass + +class TransactionActive(WormException): + pass + class DeletedObject(WormException): ''' For when thing.deleted == True. @@ -36,16 +44,13 @@ def slice_before(li, item): index = li.index(item) return li[:index] -def transaction(method): +def atomic(method): ''' This decorator can be added to functions that modify your worms database. - A savepoint is opened, then your function is run, then we roll back to the - savepoint if an exception is raised. + A savepoint is opened, then your function is run. If an exception is raised, + we roll back to the savepoint. - This decorator adds the keyword argument 'commit' to your function, so that - callers can commit it immediately. - - This decorator adds the attribute 'is_worms_transaction = True' to your + This decorator adds the attribute 'is_worms_atomic = True' to your function. You can use this to distinguish readonly vs writing methods during runtime. @@ -55,7 +60,7 @@ def transaction(method): the action's failure. ''' @functools.wraps(method) - def wrapped_transaction(self, *args, commit=False, **kwargs): + def wrapped_atomic(self, *args, **kwargs): if isinstance(self, Object): self.assert_not_deleted() @@ -63,6 +68,7 @@ def transaction(method): is_root = len(database.savepoints) == 0 savepoint_id = database.savepoint(message=method.__qualname__) + log.loud(f'{method.__qualname__} got savepoint {savepoint_id}.') try: result = method(self, *args, **kwargs) @@ -74,19 +80,36 @@ def transaction(method): if isinstance(result, raise_without_rollback): raise result.exc from result.exc - if commit: - database.commit(message=method.__qualname__) - elif not is_root: + if not is_root: # In order to prevent a huge pile-up of savepoints when a # @transaction calls another @transaction many times, the sub-call # savepoints are removed from the stack. When an exception occurs, # we're going to rollback from the rootmost savepoint anyway, we'll # never rollback one sub-transaction. database.release_savepoint(savepoint=savepoint_id) + return result - wrapped_transaction.is_worms_transaction = True - return wrapped_transaction + wrapped_atomic.is_worms_atomic = True + return wrapped_atomic + +class TransactionContextManager: + def __init__(self, database): + self.database = database + + def __enter__(self): + log.loud('Entering transaction.') + self.database.begin() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + log.loud('Exiting transaction.') + if exc_type is not None: + log.loud(f'Transaction raised {exc_type}.') + self.database.rollback() + raise exc_value + + self.database.commit() class Database(metaclass=abc.ABCMeta): ''' @@ -98,11 +121,17 @@ class Database(metaclass=abc.ABCMeta): ''' def __init__(self): super().__init__() - # Used for transaction + # Used for @atomic decorator self._worms_database = self self.on_commit_queue = [] self.on_rollback_queue = [] self.savepoints = [] + # To prevent two transactions from running at the same time in different + # threads, and committing the database in an odd state, we lock out and + # run one transaction at a time. + self._worms_transaction_lock = threading.Lock() + self._worms_transaction_owner = None + self.transaction = TransactionContextManager(database=self) # If your IDs are integers, you could change this to int. This way, when # taking user input as strings, they will automatically be converted to # int when querying and caching, and you don't have to do the conversion @@ -125,23 +154,105 @@ class Database(metaclass=abc.ABCMeta): @abc.abstractmethod def _init_sql(self): ''' - Your subclass needs to set self.sql, which is a database connection. + Your subclass needs to prepare self.sql_read and self.sql_write, which + are both connection objects. They can be the same object if you want, or + they can be separate connections so that the readers can not get blocked + by the writers. - It is recommended to set self.sql.row_factory = sqlite3.Row so that you - get dictionary-style named access to row members in your objects' init. + You can do it yourself or use the provided _init_connections to get the + basic handles going. Then use the rest of this method to do any other + setup your application needs. ''' raise NotImplementedError + def _make_sqlite_read_connection(self, path): + ''' + Provided for convenience of _init_sql. + ''' + if isinstance(path, pathclass.Path): + path = path.absolute_path + if path == ':memory:': + sql_read = sqlite3.connect('file:memdb1?mode=memory&cache=shared&mode=ro', uri=True) + sql_read.row_factory = sqlite3.Row + else: + log.debug('Connecting to sqlite file "%s".', path) + sql_read = sqlite3.connect(f'file:{path}?mode=ro', uri=True) + sql_read.row_factory = sqlite3.Row + return sql_read + + def _make_sqlite_write_connection(self, path): + if isinstance(path, pathclass.Path): + path = path.absolute_path + + if path == ':memory:': + sql_write = sqlite3.connect('file:memdb1?mode=memory&cache=shared', uri=True) + sql_write.row_factory = sqlite3.Row + else: + log.debug('Connecting to sqlite file "%s".', path) + sql_write = sqlite3.connect(path) + sql_write.row_factory = sqlite3.Row + return sql_write + + def assert_no_transaction(self) -> None: + thread_id = threading.current_thread().ident + if self._worms_transaction_owner == thread_id: + raise TransactionActive() + + def assert_transaction_active(self) -> None: + thread_id = threading.current_thread().ident + if self._worms_transaction_owner != thread_id: + raise NoTransaction() + + def acquire_transaction_lock(self): + ''' + If no transaction is running, the caller gets the lock. + + If a transaction is running on the same thread as the caller, the caller + does not get the lock but the function returns so it can do its work, + since it is a descendant of the original transaction call. + + If a transaction is running and the caller is on a different thread, it + gets blocked until the previous transaction finishes. + ''' + # Don't worry about race conditions, ownership of lock changing while + # the if statement is evaluating, because this individual thread cannot + # be checking its identity and releasing the lock at the same time! If + # transaction_owner is the current thread, we know that will remain + # true until this thread releases it, which can't happen at the same + # time here. + thread_id = threading.current_thread().ident + if self._worms_transaction_owner == thread_id: + return False + + log.loud(f'{thread_id} wants the transaction lock.') + self._worms_transaction_lock.acquire() + log.loud(f'{thread_id} has the transaction lock.') + self._worms_transaction_owner = thread_id + return True + def assert_table_exists(self, table) -> None: if table not in self.COLUMN_INDEX: raise BadTable(f'Table {table} does not exist.') + def begin(self): + self.acquire_transaction_lock() + self.execute('BEGIN') + def close(self): # Wrapped in hasattr because if the object fails __init__, Python will # still call __del__ and thus close(), even though the attributes # we're trying to clean up never got set. - if hasattr(self, 'sql'): - self.sql.close() + if not hasattr(self, 'sql_read'): + return + + if self._worms_transaction_owner: + self.rollback() + + self.sql_read.close() + del self.sql_read + + self.sql_write.close() + del self.sql_write def commit(self, message=None) -> None: if message is None: @@ -157,6 +268,7 @@ class Database(metaclass=abc.ABCMeta): args = task.get('args', []) kwargs = task.get('kwargs', {}) action = task['action'] + log.loud(f'{action} {args} {kwargs}') try: action(*args, **kwargs) except Exception as exc: @@ -165,8 +277,9 @@ class Database(metaclass=abc.ABCMeta): raise self.savepoints.clear() - self.sql.commit() + self.sql_write.commit() self.last_commit_id = RNG.getrandbits(32) + self.release_transaction_lock() def delete(self, table, pairs) -> sqlite3.Cursor: if isinstance(table, type) and issubclass(table, Object): @@ -176,10 +289,26 @@ class Database(metaclass=abc.ABCMeta): query = f'DELETE FROM {table} {qmarks}' return self.execute(query, bindings) - def execute(self, query, bindings=[]): + def execute_read(self, query, bindings=[]): if bindings is None: bindings = [] - cur = self.sql.cursor() + + thread_id = threading.current_thread().ident + if self._worms_transaction_owner == thread_id: + sql = self.sql_write + else: + sql = self.sql_read + + cur = sql.cursor() + log.loud('%s %s', query, bindings) + cur.execute(query, bindings) + return cur + + def execute(self, query, bindings=[]): + self.assert_transaction_active() + if bindings is None: + bindings = [] + cur = self.sql_write.cursor() log.loud('%s %s', query, bindings) cur.execute(query, bindings) return cur @@ -189,10 +318,11 @@ class Database(metaclass=abc.ABCMeta): The problem with Python's default executescript is that it executes a COMMIT before running your script. If I wanted a commit I'd write one! ''' + self.assert_transaction_active() lines = re.split(r';(:?\n|$)', script) lines = (line.strip() for line in lines) lines = (line for line in lines if line) - cur = self.sql.cursor() + cur = self.sql_write.cursor() for line in lines: log.loud(line) cur.execute(line) @@ -335,6 +465,17 @@ class Database(metaclass=abc.ABCMeta): return (good, bad) + def pragma_read(self, key): + pragma = self.execute_read(f'PRAGMA {key}').fetchone() + if pragma is not None: + return pragma[0] + return None + + def pragma_write(self, key, value) -> None: + # We are bypassing self.execute because some pragmas are not allowed to + # happen during transactions. + return self.sql_write.cursor().execute(f'PRAGMA {key} = {value}') + def release_savepoint(self, savepoint, allow_commit=False) -> None: ''' Releasing a savepoint removes that savepoint from the timeline, so that @@ -359,6 +500,19 @@ class Database(metaclass=abc.ABCMeta): self.execute(f'RELEASE "{savepoint}"') self.savepoints = slice_before(self.savepoints, savepoint) + def release_transaction_lock(self): + thread_id = threading.current_thread().ident + if not self._worms_transaction_lock.locked(): + return + + if self._worms_transaction_owner != thread_id: + log.warning(f'{thread_id} tried to release the transaction lock without holding it.') + return + + log.loud(f'{thread_id} releases the transaction lock.') + self._worms_transaction_owner = None + self._worms_transaction_lock.release() + def rollback(self, savepoint=None) -> None: ''' Given a savepoint, roll the database back to the moment before that @@ -371,10 +525,6 @@ class Database(metaclass=abc.ABCMeta): log.warn('Tried to restore nonexistent savepoint %s.', savepoint) return - if len(self.savepoints) == 0: - log.debug('Nothing to roll back.') - return - while len(self.on_rollback_queue) > 0: task = self.on_rollback_queue.pop(-1) if task == savepoint: @@ -397,6 +547,7 @@ class Database(metaclass=abc.ABCMeta): self.execute('ROLLBACK') self.savepoints.clear() self.on_commit_queue.clear() + self.release_transaction_lock() def savepoint(self, message=None) -> int: savepoint_id = RNG.getrandbits(32) @@ -412,7 +563,7 @@ class Database(metaclass=abc.ABCMeta): return savepoint_id def select(self, query, bindings=None) -> typing.Iterable: - cur = self.execute(query, bindings) + cur = self.execute_read(query, bindings) while True: fetch = cur.fetchone() if fetch is None: @@ -432,7 +583,7 @@ class Database(metaclass=abc.ABCMeta): ''' Select a single row, or None if no rows match your query. ''' - cur = self.execute(query, bindings) + cur = self.execute_read(query, bindings) return cur.fetchone() def select_one_value(self, query, bindings=None, fallback=None): @@ -441,7 +592,7 @@ class Database(metaclass=abc.ABCMeta): your query. The fallback can help you distinguish between rows that don't exist and a null value. ''' - cur = self.execute(query, bindings) + cur = self.execute_read(query, bindings) row = cur.fetchone() if row: return row[0]