Add new transaction locking and cleaner atomicity in worms.

This change has really made it easier to reason about database
transactions in my projects. When you use 'with db.transaction'
you know for sure that the db will either commit or roll back at
the end, and you won't leave it in a dirty state. It will also lock
out all other writers so nothing gets messed up.

Previously I was conflating atomicity of each function with
committing of the entire transaction, and that was causing me
grief. I think this is closer to correct.
This commit is contained in:
voussoir 2022-07-15 21:52:43 -07:00
parent abfaf27cee
commit cbef38ba7f
No known key found for this signature in database
GPG key ID: 5F7554F8C26DACCB

View file

@ -6,8 +6,10 @@ import functools
import random
import re
import sqlite3
import threading
import typing
from voussoirkit import pathclass
from voussoirkit import sqlhelpers
from voussoirkit import vlogging
@ -21,6 +23,12 @@ class WormException(Exception):
class BadTable(WormException):
    '''
    Raised by Database.assert_table_exists when the given table name is not
    present in the database's COLUMN_INDEX.
    '''
    pass
class NoTransaction(WormException):
    '''
    Raised by Database.assert_transaction_active when the calling thread is
    not the current owner of the database's transaction.
    '''
    pass
class TransactionActive(WormException):
    '''
    Raised by Database.assert_no_transaction when the calling thread already
    owns the database's transaction.
    '''
    pass
class DeletedObject(WormException):
'''
For when thing.deleted == True.
@ -36,16 +44,13 @@ def slice_before(li, item):
index = li.index(item)
return li[:index]
def transaction(method):
def atomic(method):
'''
This decorator can be added to functions that modify your worms database.
A savepoint is opened, then your function is run, then we roll back to the
savepoint if an exception is raised.
A savepoint is opened, then your function is run. If an exception is raised,
we roll back to the savepoint.
This decorator adds the keyword argument 'commit' to your function, so that
callers can commit it immediately.
This decorator adds the attribute 'is_worms_transaction = True' to your
This decorator adds the attribute 'is_worms_atomic = True' to your
function. You can use this to distinguish readonly vs writing methods during
runtime.
@ -55,7 +60,7 @@ def transaction(method):
the action's failure.
'''
@functools.wraps(method)
def wrapped_transaction(self, *args, commit=False, **kwargs):
def wrapped_atomic(self, *args, **kwargs):
if isinstance(self, Object):
self.assert_not_deleted()
@ -63,6 +68,7 @@ def transaction(method):
is_root = len(database.savepoints) == 0
savepoint_id = database.savepoint(message=method.__qualname__)
log.loud(f'{method.__qualname__} got savepoint {savepoint_id}.')
try:
result = method(self, *args, **kwargs)
@ -74,19 +80,36 @@ def transaction(method):
if isinstance(result, raise_without_rollback):
raise result.exc from result.exc
if commit:
database.commit(message=method.__qualname__)
elif not is_root:
if not is_root:
# In order to prevent a huge pile-up of savepoints when a
# @transaction calls another @transaction many times, the sub-call
# savepoints are removed from the stack. When an exception occurs,
# we're going to rollback from the rootmost savepoint anyway, we'll
# never rollback one sub-transaction.
database.release_savepoint(savepoint=savepoint_id)
return result
wrapped_transaction.is_worms_transaction = True
return wrapped_transaction
wrapped_atomic.is_worms_atomic = True
return wrapped_atomic
class TransactionContextManager:
    '''
    Context manager that wraps a block of work in one database transaction.

    Entering calls database.begin(). Leaving calls database.commit() if the
    block finished normally, or database.rollback() if the block raised. The
    exception, if any, always propagates to the caller.

    with db.transaction:
        db.do_something()
    '''
    def __init__(self, database):
        # The object providing begin / commit / rollback.
        self.database = database

    def __enter__(self):
        log.loud('Entering transaction.')
        self.database.begin()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        log.loud('Exiting transaction.')
        if exc_type is not None:
            log.loud(f'Transaction raised {exc_type}.')
            self.database.rollback()
            # __exit__ should not re-raise the exception it was given.
            # Returning a falsy value makes the with statement propagate the
            # original exception, with its traceback untouched.
            return False

        self.database.commit()
        return False
class Database(metaclass=abc.ABCMeta):
'''
@ -98,11 +121,17 @@ class Database(metaclass=abc.ABCMeta):
'''
def __init__(self):
super().__init__()
# Used for transaction
# Used for @atomic decorator
self._worms_database = self
self.on_commit_queue = []
self.on_rollback_queue = []
self.savepoints = []
# To prevent two transactions from running at the same time in different
# threads, and committing the database in an odd state, we lock out and
# run one transaction at a time.
self._worms_transaction_lock = threading.Lock()
self._worms_transaction_owner = None
self.transaction = TransactionContextManager(database=self)
# If your IDs are integers, you could change this to int. This way, when
# taking user input as strings, they will automatically be converted to
# int when querying and caching, and you don't have to do the conversion
@ -125,23 +154,105 @@ class Database(metaclass=abc.ABCMeta):
@abc.abstractmethod
def _init_sql(self):
    '''
    Your subclass needs to prepare self.sql_read and self.sql_write, which
    are both connection objects. They can be the same object if you want, or
    they can be separate connections so that the readers can not get blocked
    by the writers.

    It is recommended to set row_factory = sqlite3.Row on the connections so
    that you get dictionary-style named access to row members in your
    objects' init.

    You can make the connections yourself or use the provided
    _make_sqlite_read_connection and _make_sqlite_write_connection to get
    the basic handles going. Then use the rest of this method to do any
    other setup your application needs.
    '''
    raise NotImplementedError
def _make_sqlite_read_connection(self, path):
    '''
    Provided for convenience of _init_sql.

    Open a read-only sqlite3 connection, opened by URI, with
    row_factory = sqlite3.Row. path may be a pathclass.Path, a string, or
    the special string ':memory:'.
    '''
    if isinstance(path, pathclass.Path):
        path = path.absolute_path

    if path == ':memory:':
        # NOTE(review): this URI carries two mode= parameters (memory and
        # ro). Confirm that SQLite resolves them the way we intend.
        uri = 'file:memdb1?mode=memory&cache=shared&mode=ro'
    else:
        log.debug('Connecting to sqlite file "%s".', path)
        uri = f'file:{path}?mode=ro'

    connection = sqlite3.connect(uri, uri=True)
    connection.row_factory = sqlite3.Row
    return connection
def _make_sqlite_write_connection(self, path):
    '''
    Provided for convenience of _init_sql.

    Open a sqlite3 connection for writing, with row_factory = sqlite3.Row.
    path may be a pathclass.Path, a string, or the special string
    ':memory:', which connects to the shared-cache in-memory database named
    memdb1.
    '''
    if isinstance(path, pathclass.Path):
        path = path.absolute_path

    in_memory = (path == ':memory:')
    if in_memory:
        connection = sqlite3.connect('file:memdb1?mode=memory&cache=shared', uri=True)
    else:
        log.debug('Connecting to sqlite file "%s".', path)
        connection = sqlite3.connect(path)

    connection.row_factory = sqlite3.Row
    return connection
def assert_no_transaction(self) -> None:
    '''
    Raise TransactionActive if the calling thread currently owns the
    transaction. Otherwise do nothing.
    '''
    owner = self._worms_transaction_owner
    if owner == threading.current_thread().ident:
        raise TransactionActive()
def assert_transaction_active(self) -> None:
    '''
    Raise NoTransaction unless the calling thread currently owns the
    transaction.
    '''
    owner = self._worms_transaction_owner
    if owner != threading.current_thread().ident:
        raise NoTransaction()
def acquire_transaction_lock(self):
    '''
    Take the transaction lock on behalf of the calling thread.

    Returns True if this call acquired the lock and recorded the calling
    thread as the transaction owner.

    Returns False, without touching the lock, if the calling thread is
    already the owner. In that case the caller is a descendant of the call
    that began the transaction and may go ahead with its work.

    If another thread owns the transaction, this call blocks until that
    transaction finishes and the lock becomes available.
    '''
    # No extra synchronization is needed around the ownership check. If the
    # owner is this thread, only this thread can change that, and it cannot
    # be releasing the lock while it is busy evaluating this comparison.
    caller = threading.current_thread().ident
    already_owner = (self._worms_transaction_owner == caller)
    if not already_owner:
        log.loud(f'{caller} wants the transaction lock.')
        self._worms_transaction_lock.acquire()
        log.loud(f'{caller} has the transaction lock.')
        self._worms_transaction_owner = caller
    return not already_owner
def assert_table_exists(self, table) -> None:
    '''
    Raise BadTable if the table name is not a key of self.COLUMN_INDEX.
    '''
    if table in self.COLUMN_INDEX:
        return
    raise BadTable(f'Table {table} does not exist.')
def begin(self):
    '''
    Acquire the transaction lock for the calling thread and issue BEGIN.

    If BEGIN fails and this call was the one that acquired the lock, the
    lock is released before the exception propagates. Otherwise nobody would
    ever release it, and every other thread would block forever waiting for
    a transaction that never started.
    '''
    acquired = self.acquire_transaction_lock()
    try:
        self.execute('BEGIN')
    except BaseException:
        # Only undo our own acquisition. If this thread already owned the
        # transaction before this call, the outer owner remains responsible
        # for committing or rolling back.
        if acquired:
            self.release_transaction_lock()
        raise
def close(self):
    '''
    Roll back the transaction if one is still owned, then close the
    database connections and remove the connection attributes.
    '''
    # Wrapped in hasattr because if the object fails __init__, Python will
    # still call __del__ and thus close(), even though the attributes
    # we're trying to clean up never got set.
    # NOTE(review): self.sql looks like the single-connection handle from
    # before the sql_read / sql_write split. Confirm whether it is still set
    # anywhere.
    if hasattr(self, 'sql'):
        self.sql.close()

    if not hasattr(self, 'sql_read'):
        return

    # A truthy owner means a transaction was begun and has not yet been
    # committed or rolled back.
    if self._worms_transaction_owner:
        self.rollback()

    # Deleting the attributes makes a later close() return early at the
    # hasattr check above.
    self.sql_read.close()
    del self.sql_read
    self.sql_write.close()
    del self.sql_write
def commit(self, message=None) -> None:
if message is None:
@ -157,6 +268,7 @@ class Database(metaclass=abc.ABCMeta):
args = task.get('args', [])
kwargs = task.get('kwargs', {})
action = task['action']
log.loud(f'{action} {args} {kwargs}')
try:
action(*args, **kwargs)
except Exception as exc:
@ -165,8 +277,9 @@ class Database(metaclass=abc.ABCMeta):
raise
self.savepoints.clear()
self.sql.commit()
self.sql_write.commit()
self.last_commit_id = RNG.getrandbits(32)
self.release_transaction_lock()
def delete(self, table, pairs) -> sqlite3.Cursor:
if isinstance(table, type) and issubclass(table, Object):
@ -176,10 +289,26 @@ class Database(metaclass=abc.ABCMeta):
query = f'DELETE FROM {table} {qmarks}'
return self.execute(query, bindings)
def execute(self, query, bindings=[]):
def execute_read(self, query, bindings=[]):
if bindings is None:
bindings = []
cur = self.sql.cursor()
thread_id = threading.current_thread().ident
if self._worms_transaction_owner == thread_id:
sql = self.sql_write
else:
sql = self.sql_read
cur = sql.cursor()
log.loud('%s %s', query, bindings)
cur.execute(query, bindings)
return cur
def execute(self, query, bindings=None):
    '''
    Run a single statement on the write connection and return its cursor.

    query: The SQL text to execute.
    bindings: The values for the query's placeholders. None means no
        bindings.

    Raises NoTransaction (via assert_transaction_active) if the calling
    thread does not own the transaction.
    '''
    self.assert_transaction_active()
    # None is the default instead of a shared mutable list, so that no call
    # can ever see bindings left behind by another.
    if bindings is None:
        bindings = []
    cur = self.sql_write.cursor()
    log.loud('%s %s', query, bindings)
    cur.execute(query, bindings)
    return cur
@ -189,10 +318,11 @@ class Database(metaclass=abc.ABCMeta):
The problem with Python's default executescript is that it executes a
COMMIT before running your script. If I wanted a commit I'd write one!
'''
self.assert_transaction_active()
lines = re.split(r';(:?\n|$)', script)
lines = (line.strip() for line in lines)
lines = (line for line in lines if line)
cur = self.sql.cursor()
cur = self.sql_write.cursor()
for line in lines:
log.loud(line)
cur.execute(line)
@ -335,6 +465,17 @@ class Database(metaclass=abc.ABCMeta):
return (good, bad)
def pragma_read(self, key):
    '''
    Return the first column of the first row produced by 'PRAGMA {key}',
    or None if the pragma produces no rows.
    '''
    row = self.execute_read(f'PRAGMA {key}').fetchone()
    return None if row is None else row[0]
def pragma_write(self, key, value) -> sqlite3.Cursor:
    '''
    Execute 'PRAGMA {key} = {value}' on the write connection and return the
    cursor that ran it.

    key and value are interpolated directly into the statement, so they
    must come from trusted code, never from user input.
    '''
    # We are bypassing self.execute because some pragmas are not allowed to
    # happen during transactions.
    cur = self.sql_write.cursor()
    return cur.execute(f'PRAGMA {key} = {value}')
def release_savepoint(self, savepoint, allow_commit=False) -> None:
'''
Releasing a savepoint removes that savepoint from the timeline, so that
@ -359,6 +500,19 @@ class Database(metaclass=abc.ABCMeta):
self.execute(f'RELEASE "{savepoint}"')
self.savepoints = slice_before(self.savepoints, savepoint)
def release_transaction_lock(self):
    '''
    Release the transaction lock if the calling thread is the transaction
    owner, and clear the owner.

    Does nothing if the lock is not currently held. Logs a warning and
    leaves the lock alone if it is held but this thread is not the owner.
    '''
    caller = threading.current_thread().ident
    lock = self._worms_transaction_lock

    if not lock.locked():
        return

    if self._worms_transaction_owner == caller:
        log.loud(f'{caller} releases the transaction lock.')
        self._worms_transaction_owner = None
        lock.release()
        return

    log.warning(f'{caller} tried to release the transaction lock without holding it.')
def rollback(self, savepoint=None) -> None:
'''
Given a savepoint, roll the database back to the moment before that
@ -371,10 +525,6 @@ class Database(metaclass=abc.ABCMeta):
log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
return
if len(self.savepoints) == 0:
log.debug('Nothing to roll back.')
return
while len(self.on_rollback_queue) > 0:
task = self.on_rollback_queue.pop(-1)
if task == savepoint:
@ -397,6 +547,7 @@ class Database(metaclass=abc.ABCMeta):
self.execute('ROLLBACK')
self.savepoints.clear()
self.on_commit_queue.clear()
self.release_transaction_lock()
def savepoint(self, message=None) -> int:
savepoint_id = RNG.getrandbits(32)
@ -412,7 +563,7 @@ class Database(metaclass=abc.ABCMeta):
return savepoint_id
def select(self, query, bindings=None) -> typing.Iterable:
cur = self.execute(query, bindings)
cur = self.execute_read(query, bindings)
while True:
fetch = cur.fetchone()
if fetch is None:
@ -432,7 +583,7 @@ class Database(metaclass=abc.ABCMeta):
'''
Select a single row, or None if no rows match your query.
'''
cur = self.execute(query, bindings)
cur = self.execute_read(query, bindings)
return cur.fetchone()
def select_one_value(self, query, bindings=None, fallback=None):
@ -441,7 +592,7 @@ class Database(metaclass=abc.ABCMeta):
your query. The fallback can help you distinguish between rows that
don't exist and a null value.
'''
cur = self.execute(query, bindings)
cur = self.execute_read(query, bindings)
row = cur.fetchone()
if row:
return row[0]