Add new transaction locking and cleaner atomicity in worms.
This change has really made it easier to reason about database transactions in my projects. When you use 'with db.transaction' you know for sure that either the db will commit or rollback at the end and you won't leave in a dirty state. And it will lock out all other writers so nothing gets messed up. Previously I was conflating atomicity of each function with committing of the entire transaction, and that was causing me grief. I think this is closer to correct.
This commit is contained in:
parent
abfaf27cee
commit
cbef38ba7f
1 changed files with 181 additions and 30 deletions
|
@ -6,8 +6,10 @@ import functools
|
|||
import random
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import typing
|
||||
|
||||
from voussoirkit import pathclass
|
||||
from voussoirkit import sqlhelpers
|
||||
from voussoirkit import vlogging
|
||||
|
||||
|
@ -21,6 +23,12 @@ class WormException(Exception):
|
|||
class BadTable(WormException):
|
||||
pass
|
||||
|
||||
class NoTransaction(WormException):
|
||||
pass
|
||||
|
||||
class TransactionActive(WormException):
|
||||
pass
|
||||
|
||||
class DeletedObject(WormException):
|
||||
'''
|
||||
For when thing.deleted == True.
|
||||
|
@ -36,16 +44,13 @@ def slice_before(li, item):
|
|||
index = li.index(item)
|
||||
return li[:index]
|
||||
|
||||
def transaction(method):
|
||||
def atomic(method):
|
||||
'''
|
||||
This decorator can be added to functions that modify your worms database.
|
||||
A savepoint is opened, then your function is run, then we roll back to the
|
||||
savepoint if an exception is raised.
|
||||
A savepoint is opened, then your function is run. If an exception is raised,
|
||||
we roll back to the savepoint.
|
||||
|
||||
This decorator adds the keyword argument 'commit' to your function, so that
|
||||
callers can commit it immediately.
|
||||
|
||||
This decorator adds the attribute 'is_worms_transaction = True' to your
|
||||
This decorator adds the attribute 'is_worms_atomic = True' to your
|
||||
function. You can use this to distinguish readonly vs writing methods during
|
||||
runtime.
|
||||
|
||||
|
@ -55,7 +60,7 @@ def transaction(method):
|
|||
the action's failure.
|
||||
'''
|
||||
@functools.wraps(method)
|
||||
def wrapped_transaction(self, *args, commit=False, **kwargs):
|
||||
def wrapped_atomic(self, *args, **kwargs):
|
||||
if isinstance(self, Object):
|
||||
self.assert_not_deleted()
|
||||
|
||||
|
@ -63,6 +68,7 @@ def transaction(method):
|
|||
|
||||
is_root = len(database.savepoints) == 0
|
||||
savepoint_id = database.savepoint(message=method.__qualname__)
|
||||
log.loud(f'{method.__qualname__} got savepoint {savepoint_id}.')
|
||||
|
||||
try:
|
||||
result = method(self, *args, **kwargs)
|
||||
|
@ -74,19 +80,36 @@ def transaction(method):
|
|||
if isinstance(result, raise_without_rollback):
|
||||
raise result.exc from result.exc
|
||||
|
||||
if commit:
|
||||
database.commit(message=method.__qualname__)
|
||||
elif not is_root:
|
||||
if not is_root:
|
||||
# In order to prevent a huge pile-up of savepoints when a
|
||||
# @transaction calls another @transaction many times, the sub-call
|
||||
# savepoints are removed from the stack. When an exception occurs,
|
||||
# we're going to rollback from the rootmost savepoint anyway, we'll
|
||||
# never rollback one sub-transaction.
|
||||
database.release_savepoint(savepoint=savepoint_id)
|
||||
|
||||
return result
|
||||
|
||||
wrapped_transaction.is_worms_transaction = True
|
||||
return wrapped_transaction
|
||||
wrapped_atomic.is_worms_atomic = True
|
||||
return wrapped_atomic
|
||||
|
||||
class TransactionContextManager:
|
||||
def __init__(self, database):
|
||||
self.database = database
|
||||
|
||||
def __enter__(self):
|
||||
log.loud('Entering transaction.')
|
||||
self.database.begin()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
log.loud('Exiting transaction.')
|
||||
if exc_type is not None:
|
||||
log.loud(f'Transaction raised {exc_type}.')
|
||||
self.database.rollback()
|
||||
raise exc_value
|
||||
|
||||
self.database.commit()
|
||||
|
||||
class Database(metaclass=abc.ABCMeta):
|
||||
'''
|
||||
|
@ -98,11 +121,17 @@ class Database(metaclass=abc.ABCMeta):
|
|||
'''
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Used for transaction
|
||||
# Used for @atomic decorator
|
||||
self._worms_database = self
|
||||
self.on_commit_queue = []
|
||||
self.on_rollback_queue = []
|
||||
self.savepoints = []
|
||||
# To prevent two transactions from running at the same time in different
|
||||
# threads, and committing the database in an odd state, we lock out and
|
||||
# run one transaction at a time.
|
||||
self._worms_transaction_lock = threading.Lock()
|
||||
self._worms_transaction_owner = None
|
||||
self.transaction = TransactionContextManager(database=self)
|
||||
# If your IDs are integers, you could change this to int. This way, when
|
||||
# taking user input as strings, they will automatically be converted to
|
||||
# int when querying and caching, and you don't have to do the conversion
|
||||
|
@ -125,23 +154,105 @@ class Database(metaclass=abc.ABCMeta):
|
|||
@abc.abstractmethod
|
||||
def _init_sql(self):
|
||||
'''
|
||||
Your subclass needs to set self.sql, which is a database connection.
|
||||
Your subclass needs to prepare self.sql_read and self.sql_write, which
|
||||
are both connection objects. They can be the same object if you want, or
|
||||
they can be separate connections so that the readers can not get blocked
|
||||
by the writers.
|
||||
|
||||
It is recommended to set self.sql.row_factory = sqlite3.Row so that you
|
||||
get dictionary-style named access to row members in your objects' init.
|
||||
You can do it yourself or use the provided _init_connections to get the
|
||||
basic handles going. Then use the rest of this method to do any other
|
||||
setup your application needs.
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
def _make_sqlite_read_connection(self, path):
|
||||
'''
|
||||
Provided for convenience of _init_sql.
|
||||
'''
|
||||
if isinstance(path, pathclass.Path):
|
||||
path = path.absolute_path
|
||||
if path == ':memory:':
|
||||
sql_read = sqlite3.connect('file:memdb1?mode=memory&cache=shared&mode=ro', uri=True)
|
||||
sql_read.row_factory = sqlite3.Row
|
||||
else:
|
||||
log.debug('Connecting to sqlite file "%s".', path)
|
||||
sql_read = sqlite3.connect(f'file:{path}?mode=ro', uri=True)
|
||||
sql_read.row_factory = sqlite3.Row
|
||||
return sql_read
|
||||
|
||||
def _make_sqlite_write_connection(self, path):
|
||||
if isinstance(path, pathclass.Path):
|
||||
path = path.absolute_path
|
||||
|
||||
if path == ':memory:':
|
||||
sql_write = sqlite3.connect('file:memdb1?mode=memory&cache=shared', uri=True)
|
||||
sql_write.row_factory = sqlite3.Row
|
||||
else:
|
||||
log.debug('Connecting to sqlite file "%s".', path)
|
||||
sql_write = sqlite3.connect(path)
|
||||
sql_write.row_factory = sqlite3.Row
|
||||
return sql_write
|
||||
|
||||
def assert_no_transaction(self) -> None:
|
||||
thread_id = threading.current_thread().ident
|
||||
if self._worms_transaction_owner == thread_id:
|
||||
raise TransactionActive()
|
||||
|
||||
def assert_transaction_active(self) -> None:
|
||||
thread_id = threading.current_thread().ident
|
||||
if self._worms_transaction_owner != thread_id:
|
||||
raise NoTransaction()
|
||||
|
||||
def acquire_transaction_lock(self):
|
||||
'''
|
||||
If no transaction is running, the caller gets the lock.
|
||||
|
||||
If a transaction is running on the same thread as the caller, the caller
|
||||
does not get the lock but the function returns so it can do its work,
|
||||
since it is a descendant of the original transaction call.
|
||||
|
||||
If a transaction is running and the caller is on a different thread, it
|
||||
gets blocked until the previous transaction finishes.
|
||||
'''
|
||||
# Don't worry about race conditions, ownership of lock changing while
|
||||
# the if statement is evaluating, because this individual thread cannot
|
||||
# be checking its identity and releasing the lock at the same time! If
|
||||
# transaction_owner is the current thread, we know that will remain
|
||||
# true until this thread releases it, which can't happen at the same
|
||||
# time here.
|
||||
thread_id = threading.current_thread().ident
|
||||
if self._worms_transaction_owner == thread_id:
|
||||
return False
|
||||
|
||||
log.loud(f'{thread_id} wants the transaction lock.')
|
||||
self._worms_transaction_lock.acquire()
|
||||
log.loud(f'{thread_id} has the transaction lock.')
|
||||
self._worms_transaction_owner = thread_id
|
||||
return True
|
||||
|
||||
def assert_table_exists(self, table) -> None:
|
||||
if table not in self.COLUMN_INDEX:
|
||||
raise BadTable(f'Table {table} does not exist.')
|
||||
|
||||
def begin(self):
|
||||
self.acquire_transaction_lock()
|
||||
self.execute('BEGIN')
|
||||
|
||||
def close(self):
|
||||
# Wrapped in hasattr because if the object fails __init__, Python will
|
||||
# still call __del__ and thus close(), even though the attributes
|
||||
# we're trying to clean up never got set.
|
||||
if hasattr(self, 'sql'):
|
||||
self.sql.close()
|
||||
if not hasattr(self, 'sql_read'):
|
||||
return
|
||||
|
||||
if self._worms_transaction_owner:
|
||||
self.rollback()
|
||||
|
||||
self.sql_read.close()
|
||||
del self.sql_read
|
||||
|
||||
self.sql_write.close()
|
||||
del self.sql_write
|
||||
|
||||
def commit(self, message=None) -> None:
|
||||
if message is None:
|
||||
|
@ -157,6 +268,7 @@ class Database(metaclass=abc.ABCMeta):
|
|||
args = task.get('args', [])
|
||||
kwargs = task.get('kwargs', {})
|
||||
action = task['action']
|
||||
log.loud(f'{action} {args} {kwargs}')
|
||||
try:
|
||||
action(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
|
@ -165,8 +277,9 @@ class Database(metaclass=abc.ABCMeta):
|
|||
raise
|
||||
|
||||
self.savepoints.clear()
|
||||
self.sql.commit()
|
||||
self.sql_write.commit()
|
||||
self.last_commit_id = RNG.getrandbits(32)
|
||||
self.release_transaction_lock()
|
||||
|
||||
def delete(self, table, pairs) -> sqlite3.Cursor:
|
||||
if isinstance(table, type) and issubclass(table, Object):
|
||||
|
@ -176,10 +289,26 @@ class Database(metaclass=abc.ABCMeta):
|
|||
query = f'DELETE FROM {table} {qmarks}'
|
||||
return self.execute(query, bindings)
|
||||
|
||||
def execute(self, query, bindings=[]):
|
||||
def execute_read(self, query, bindings=[]):
|
||||
if bindings is None:
|
||||
bindings = []
|
||||
cur = self.sql.cursor()
|
||||
|
||||
thread_id = threading.current_thread().ident
|
||||
if self._worms_transaction_owner == thread_id:
|
||||
sql = self.sql_write
|
||||
else:
|
||||
sql = self.sql_read
|
||||
|
||||
cur = sql.cursor()
|
||||
log.loud('%s %s', query, bindings)
|
||||
cur.execute(query, bindings)
|
||||
return cur
|
||||
|
||||
def execute(self, query, bindings=[]):
|
||||
self.assert_transaction_active()
|
||||
if bindings is None:
|
||||
bindings = []
|
||||
cur = self.sql_write.cursor()
|
||||
log.loud('%s %s', query, bindings)
|
||||
cur.execute(query, bindings)
|
||||
return cur
|
||||
|
@ -189,10 +318,11 @@ class Database(metaclass=abc.ABCMeta):
|
|||
The problem with Python's default executescript is that it executes a
|
||||
COMMIT before running your script. If I wanted a commit I'd write one!
|
||||
'''
|
||||
self.assert_transaction_active()
|
||||
lines = re.split(r';(:?\n|$)', script)
|
||||
lines = (line.strip() for line in lines)
|
||||
lines = (line for line in lines if line)
|
||||
cur = self.sql.cursor()
|
||||
cur = self.sql_write.cursor()
|
||||
for line in lines:
|
||||
log.loud(line)
|
||||
cur.execute(line)
|
||||
|
@ -335,6 +465,17 @@ class Database(metaclass=abc.ABCMeta):
|
|||
|
||||
return (good, bad)
|
||||
|
||||
def pragma_read(self, key):
|
||||
pragma = self.execute_read(f'PRAGMA {key}').fetchone()
|
||||
if pragma is not None:
|
||||
return pragma[0]
|
||||
return None
|
||||
|
||||
def pragma_write(self, key, value) -> None:
|
||||
# We are bypassing self.execute because some pragmas are not allowed to
|
||||
# happen during transactions.
|
||||
return self.sql_write.cursor().execute(f'PRAGMA {key} = {value}')
|
||||
|
||||
def release_savepoint(self, savepoint, allow_commit=False) -> None:
|
||||
'''
|
||||
Releasing a savepoint removes that savepoint from the timeline, so that
|
||||
|
@ -359,6 +500,19 @@ class Database(metaclass=abc.ABCMeta):
|
|||
self.execute(f'RELEASE "{savepoint}"')
|
||||
self.savepoints = slice_before(self.savepoints, savepoint)
|
||||
|
||||
def release_transaction_lock(self):
|
||||
thread_id = threading.current_thread().ident
|
||||
if not self._worms_transaction_lock.locked():
|
||||
return
|
||||
|
||||
if self._worms_transaction_owner != thread_id:
|
||||
log.warning(f'{thread_id} tried to release the transaction lock without holding it.')
|
||||
return
|
||||
|
||||
log.loud(f'{thread_id} releases the transaction lock.')
|
||||
self._worms_transaction_owner = None
|
||||
self._worms_transaction_lock.release()
|
||||
|
||||
def rollback(self, savepoint=None) -> None:
|
||||
'''
|
||||
Given a savepoint, roll the database back to the moment before that
|
||||
|
@ -371,10 +525,6 @@ class Database(metaclass=abc.ABCMeta):
|
|||
log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
|
||||
return
|
||||
|
||||
if len(self.savepoints) == 0:
|
||||
log.debug('Nothing to roll back.')
|
||||
return
|
||||
|
||||
while len(self.on_rollback_queue) > 0:
|
||||
task = self.on_rollback_queue.pop(-1)
|
||||
if task == savepoint:
|
||||
|
@ -397,6 +547,7 @@ class Database(metaclass=abc.ABCMeta):
|
|||
self.execute('ROLLBACK')
|
||||
self.savepoints.clear()
|
||||
self.on_commit_queue.clear()
|
||||
self.release_transaction_lock()
|
||||
|
||||
def savepoint(self, message=None) -> int:
|
||||
savepoint_id = RNG.getrandbits(32)
|
||||
|
@ -412,7 +563,7 @@ class Database(metaclass=abc.ABCMeta):
|
|||
return savepoint_id
|
||||
|
||||
def select(self, query, bindings=None) -> typing.Iterable:
|
||||
cur = self.execute(query, bindings)
|
||||
cur = self.execute_read(query, bindings)
|
||||
while True:
|
||||
fetch = cur.fetchone()
|
||||
if fetch is None:
|
||||
|
@ -432,7 +583,7 @@ class Database(metaclass=abc.ABCMeta):
|
|||
'''
|
||||
Select a single row, or None if no rows match your query.
|
||||
'''
|
||||
cur = self.execute(query, bindings)
|
||||
cur = self.execute_read(query, bindings)
|
||||
return cur.fetchone()
|
||||
|
||||
def select_one_value(self, query, bindings=None, fallback=None):
|
||||
|
@ -441,7 +592,7 @@ class Database(metaclass=abc.ABCMeta):
|
|||
your query. The fallback can help you distinguish between rows that
|
||||
don't exist and a null value.
|
||||
'''
|
||||
cur = self.execute(query, bindings)
|
||||
cur = self.execute_read(query, bindings)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
|
|
Loading…
Reference in a new issue