Add new transaction locking and cleaner atomicity in worms.

This change has really made it easier to reason about database
transactions in my projects. When you use 'with db.transaction'
you know for sure that either the db will commit or rollback at
the end and you won't leave in a dirty state. And it will lock
out all other writers so nothing gets messed up.

Previously I was conflating atomicity of each function with
committing of the entire transaction, and that was causing me
grief. I think this is closer to correct.
master
voussoir 2022-07-15 21:52:43 -07:00
parent abfaf27cee
commit cbef38ba7f
No known key found for this signature in database
GPG Key ID: 5F7554F8C26DACCB
1 changed files with 181 additions and 30 deletions

View File

@ -6,8 +6,10 @@ import functools
import random import random
import re import re
import sqlite3 import sqlite3
import threading
import typing import typing
from voussoirkit import pathclass
from voussoirkit import sqlhelpers from voussoirkit import sqlhelpers
from voussoirkit import vlogging from voussoirkit import vlogging
@ -21,6 +23,12 @@ class WormException(Exception):
class BadTable(WormException): class BadTable(WormException):
pass pass
class NoTransaction(WormException):
pass
class TransactionActive(WormException):
pass
class DeletedObject(WormException): class DeletedObject(WormException):
''' '''
For when thing.deleted == True. For when thing.deleted == True.
@ -36,16 +44,13 @@ def slice_before(li, item):
index = li.index(item) index = li.index(item)
return li[:index] return li[:index]
def transaction(method): def atomic(method):
''' '''
This decorator can be added to functions that modify your worms database. This decorator can be added to functions that modify your worms database.
A savepoint is opened, then your function is run, then we roll back to the A savepoint is opened, then your function is run. If an exception is raised,
savepoint if an exception is raised. we roll back to the savepoint.
This decorator adds the keyword argument 'commit' to your function, so that This decorator adds the attribute 'is_worms_atomic = True' to your
callers can commit it immediately.
This decorator adds the attribute 'is_worms_transaction = True' to your
function. You can use this to distinguish readonly vs writing methods during function. You can use this to distinguish readonly vs writing methods during
runtime. runtime.
@ -55,7 +60,7 @@ def transaction(method):
the action's failure. the action's failure.
''' '''
@functools.wraps(method) @functools.wraps(method)
def wrapped_transaction(self, *args, commit=False, **kwargs): def wrapped_atomic(self, *args, **kwargs):
if isinstance(self, Object): if isinstance(self, Object):
self.assert_not_deleted() self.assert_not_deleted()
@ -63,6 +68,7 @@ def transaction(method):
is_root = len(database.savepoints) == 0 is_root = len(database.savepoints) == 0
savepoint_id = database.savepoint(message=method.__qualname__) savepoint_id = database.savepoint(message=method.__qualname__)
log.loud(f'{method.__qualname__} got savepoint {savepoint_id}.')
try: try:
result = method(self, *args, **kwargs) result = method(self, *args, **kwargs)
@ -74,19 +80,36 @@ def transaction(method):
if isinstance(result, raise_without_rollback): if isinstance(result, raise_without_rollback):
raise result.exc from result.exc raise result.exc from result.exc
if commit: if not is_root:
database.commit(message=method.__qualname__)
elif not is_root:
# In order to prevent a huge pile-up of savepoints when a # In order to prevent a huge pile-up of savepoints when a
# @transaction calls another @transaction many times, the sub-call # @transaction calls another @transaction many times, the sub-call
# savepoints are removed from the stack. When an exception occurs, # savepoints are removed from the stack. When an exception occurs,
# we're going to rollback from the rootmost savepoint anyway, we'll # we're going to rollback from the rootmost savepoint anyway, we'll
# never rollback one sub-transaction. # never rollback one sub-transaction.
database.release_savepoint(savepoint=savepoint_id) database.release_savepoint(savepoint=savepoint_id)
return result return result
wrapped_transaction.is_worms_transaction = True wrapped_atomic.is_worms_atomic = True
return wrapped_transaction return wrapped_atomic
class TransactionContextManager:
def __init__(self, database):
self.database = database
def __enter__(self):
log.loud('Entering transaction.')
self.database.begin()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
log.loud('Exiting transaction.')
if exc_type is not None:
log.loud(f'Transaction raised {exc_type}.')
self.database.rollback()
raise exc_value
self.database.commit()
class Database(metaclass=abc.ABCMeta): class Database(metaclass=abc.ABCMeta):
''' '''
@ -98,11 +121,17 @@ class Database(metaclass=abc.ABCMeta):
''' '''
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# Used for transaction # Used for @atomic decorator
self._worms_database = self self._worms_database = self
self.on_commit_queue = [] self.on_commit_queue = []
self.on_rollback_queue = [] self.on_rollback_queue = []
self.savepoints = [] self.savepoints = []
# To prevent two transactions from running at the same time in different
# threads, and committing the database in an odd state, we lock out and
# run one transaction at a time.
self._worms_transaction_lock = threading.Lock()
self._worms_transaction_owner = None
self.transaction = TransactionContextManager(database=self)
# If your IDs are integers, you could change this to int. This way, when # If your IDs are integers, you could change this to int. This way, when
# taking user input as strings, they will automatically be converted to # taking user input as strings, they will automatically be converted to
# int when querying and caching, and you don't have to do the conversion # int when querying and caching, and you don't have to do the conversion
@ -125,23 +154,105 @@ class Database(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def _init_sql(self): def _init_sql(self):
''' '''
Your subclass needs to set self.sql, which is a database connection. Your subclass needs to prepare self.sql_read and self.sql_write, which
are both connection objects. They can be the same object if you want, or
they can be separate connections so that the readers can not get blocked
by the writers.
It is recommended to set self.sql.row_factory = sqlite3.Row so that you You can do it yourself or use the provided _init_connections to get the
get dictionary-style named access to row members in your objects' init. basic handles going. Then use the rest of this method to do any other
setup your application needs.
''' '''
raise NotImplementedError raise NotImplementedError
def _make_sqlite_read_connection(self, path):
'''
Provided for convenience of _init_sql.
'''
if isinstance(path, pathclass.Path):
path = path.absolute_path
if path == ':memory:':
sql_read = sqlite3.connect('file:memdb1?mode=memory&cache=shared&mode=ro', uri=True)
sql_read.row_factory = sqlite3.Row
else:
log.debug('Connecting to sqlite file "%s".', path)
sql_read = sqlite3.connect(f'file:{path}?mode=ro', uri=True)
sql_read.row_factory = sqlite3.Row
return sql_read
def _make_sqlite_write_connection(self, path):
if isinstance(path, pathclass.Path):
path = path.absolute_path
if path == ':memory:':
sql_write = sqlite3.connect('file:memdb1?mode=memory&cache=shared', uri=True)
sql_write.row_factory = sqlite3.Row
else:
log.debug('Connecting to sqlite file "%s".', path)
sql_write = sqlite3.connect(path)
sql_write.row_factory = sqlite3.Row
return sql_write
def assert_no_transaction(self) -> None:
thread_id = threading.current_thread().ident
if self._worms_transaction_owner == thread_id:
raise TransactionActive()
def assert_transaction_active(self) -> None:
thread_id = threading.current_thread().ident
if self._worms_transaction_owner != thread_id:
raise NoTransaction()
def acquire_transaction_lock(self):
'''
If no transaction is running, the caller gets the lock.
If a transaction is running on the same thread as the caller, the caller
does not get the lock but the function returns so it can do its work,
since it is a descendant of the original transaction call.
If a transaction is running and the caller is on a different thread, it
gets blocked until the previous transaction finishes.
'''
# Don't worry about race conditions, ownership of lock changing while
# the if statement is evaluating, because this individual thread cannot
# be checking its identity and releasing the lock at the same time! If
# transaction_owner is the current thread, we know that will remain
# true until this thread releases it, which can't happen at the same
# time here.
thread_id = threading.current_thread().ident
if self._worms_transaction_owner == thread_id:
return False
log.loud(f'{thread_id} wants the transaction lock.')
self._worms_transaction_lock.acquire()
log.loud(f'{thread_id} has the transaction lock.')
self._worms_transaction_owner = thread_id
return True
def assert_table_exists(self, table) -> None: def assert_table_exists(self, table) -> None:
if table not in self.COLUMN_INDEX: if table not in self.COLUMN_INDEX:
raise BadTable(f'Table {table} does not exist.') raise BadTable(f'Table {table} does not exist.')
def begin(self):
self.acquire_transaction_lock()
self.execute('BEGIN')
def close(self): def close(self):
# Wrapped in hasattr because if the object fails __init__, Python will # Wrapped in hasattr because if the object fails __init__, Python will
# still call __del__ and thus close(), even though the attributes # still call __del__ and thus close(), even though the attributes
# we're trying to clean up never got set. # we're trying to clean up never got set.
if hasattr(self, 'sql'): if not hasattr(self, 'sql_read'):
self.sql.close() return
if self._worms_transaction_owner:
self.rollback()
self.sql_read.close()
del self.sql_read
self.sql_write.close()
del self.sql_write
def commit(self, message=None) -> None: def commit(self, message=None) -> None:
if message is None: if message is None:
@ -157,6 +268,7 @@ class Database(metaclass=abc.ABCMeta):
args = task.get('args', []) args = task.get('args', [])
kwargs = task.get('kwargs', {}) kwargs = task.get('kwargs', {})
action = task['action'] action = task['action']
log.loud(f'{action} {args} {kwargs}')
try: try:
action(*args, **kwargs) action(*args, **kwargs)
except Exception as exc: except Exception as exc:
@ -165,8 +277,9 @@ class Database(metaclass=abc.ABCMeta):
raise raise
self.savepoints.clear() self.savepoints.clear()
self.sql.commit() self.sql_write.commit()
self.last_commit_id = RNG.getrandbits(32) self.last_commit_id = RNG.getrandbits(32)
self.release_transaction_lock()
def delete(self, table, pairs) -> sqlite3.Cursor: def delete(self, table, pairs) -> sqlite3.Cursor:
if isinstance(table, type) and issubclass(table, Object): if isinstance(table, type) and issubclass(table, Object):
@ -176,10 +289,26 @@ class Database(metaclass=abc.ABCMeta):
query = f'DELETE FROM {table} {qmarks}' query = f'DELETE FROM {table} {qmarks}'
return self.execute(query, bindings) return self.execute(query, bindings)
def execute(self, query, bindings=[]): def execute_read(self, query, bindings=[]):
if bindings is None: if bindings is None:
bindings = [] bindings = []
cur = self.sql.cursor()
thread_id = threading.current_thread().ident
if self._worms_transaction_owner == thread_id:
sql = self.sql_write
else:
sql = self.sql_read
cur = sql.cursor()
log.loud('%s %s', query, bindings)
cur.execute(query, bindings)
return cur
def execute(self, query, bindings=[]):
self.assert_transaction_active()
if bindings is None:
bindings = []
cur = self.sql_write.cursor()
log.loud('%s %s', query, bindings) log.loud('%s %s', query, bindings)
cur.execute(query, bindings) cur.execute(query, bindings)
return cur return cur
@ -189,10 +318,11 @@ class Database(metaclass=abc.ABCMeta):
The problem with Python's default executescript is that it executes a The problem with Python's default executescript is that it executes a
COMMIT before running your script. If I wanted a commit I'd write one! COMMIT before running your script. If I wanted a commit I'd write one!
''' '''
self.assert_transaction_active()
lines = re.split(r';(:?\n|$)', script) lines = re.split(r';(:?\n|$)', script)
lines = (line.strip() for line in lines) lines = (line.strip() for line in lines)
lines = (line for line in lines if line) lines = (line for line in lines if line)
cur = self.sql.cursor() cur = self.sql_write.cursor()
for line in lines: for line in lines:
log.loud(line) log.loud(line)
cur.execute(line) cur.execute(line)
@ -335,6 +465,17 @@ class Database(metaclass=abc.ABCMeta):
return (good, bad) return (good, bad)
def pragma_read(self, key):
pragma = self.execute_read(f'PRAGMA {key}').fetchone()
if pragma is not None:
return pragma[0]
return None
def pragma_write(self, key, value) -> None:
# We are bypassing self.execute because some pragmas are not allowed to
# happen during transactions.
return self.sql_write.cursor().execute(f'PRAGMA {key} = {value}')
def release_savepoint(self, savepoint, allow_commit=False) -> None: def release_savepoint(self, savepoint, allow_commit=False) -> None:
''' '''
Releasing a savepoint removes that savepoint from the timeline, so that Releasing a savepoint removes that savepoint from the timeline, so that
@ -359,6 +500,19 @@ class Database(metaclass=abc.ABCMeta):
self.execute(f'RELEASE "{savepoint}"') self.execute(f'RELEASE "{savepoint}"')
self.savepoints = slice_before(self.savepoints, savepoint) self.savepoints = slice_before(self.savepoints, savepoint)
def release_transaction_lock(self):
thread_id = threading.current_thread().ident
if not self._worms_transaction_lock.locked():
return
if self._worms_transaction_owner != thread_id:
log.warning(f'{thread_id} tried to release the transaction lock without holding it.')
return
log.loud(f'{thread_id} releases the transaction lock.')
self._worms_transaction_owner = None
self._worms_transaction_lock.release()
def rollback(self, savepoint=None) -> None: def rollback(self, savepoint=None) -> None:
''' '''
Given a savepoint, roll the database back to the moment before that Given a savepoint, roll the database back to the moment before that
@ -371,10 +525,6 @@ class Database(metaclass=abc.ABCMeta):
log.warn('Tried to restore nonexistent savepoint %s.', savepoint) log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
return return
if len(self.savepoints) == 0:
log.debug('Nothing to roll back.')
return
while len(self.on_rollback_queue) > 0: while len(self.on_rollback_queue) > 0:
task = self.on_rollback_queue.pop(-1) task = self.on_rollback_queue.pop(-1)
if task == savepoint: if task == savepoint:
@ -397,6 +547,7 @@ class Database(metaclass=abc.ABCMeta):
self.execute('ROLLBACK') self.execute('ROLLBACK')
self.savepoints.clear() self.savepoints.clear()
self.on_commit_queue.clear() self.on_commit_queue.clear()
self.release_transaction_lock()
def savepoint(self, message=None) -> int: def savepoint(self, message=None) -> int:
savepoint_id = RNG.getrandbits(32) savepoint_id = RNG.getrandbits(32)
@ -412,7 +563,7 @@ class Database(metaclass=abc.ABCMeta):
return savepoint_id return savepoint_id
def select(self, query, bindings=None) -> typing.Iterable: def select(self, query, bindings=None) -> typing.Iterable:
cur = self.execute(query, bindings) cur = self.execute_read(query, bindings)
while True: while True:
fetch = cur.fetchone() fetch = cur.fetchone()
if fetch is None: if fetch is None:
@ -432,7 +583,7 @@ class Database(metaclass=abc.ABCMeta):
''' '''
Select a single row, or None if no rows match your query. Select a single row, or None if no rows match your query.
''' '''
cur = self.execute(query, bindings) cur = self.execute_read(query, bindings)
return cur.fetchone() return cur.fetchone()
def select_one_value(self, query, bindings=None, fallback=None): def select_one_value(self, query, bindings=None, fallback=None):
@ -441,7 +592,7 @@ class Database(metaclass=abc.ABCMeta):
your query. The fallback can help you distinguish between rows that your query. The fallback can help you distinguish between rows that
don't exist and a null value. don't exist and a null value.
''' '''
cur = self.execute(query, bindings) cur = self.execute_read(query, bindings)
row = cur.fetchone() row = cur.fetchone()
if row: if row:
return row[0] return row[0]