Add new transaction locking and cleaner atomicity in worms.
This change has made it much easier to reason about database transactions in my projects. When you use 'with db.transaction', you know for sure that the database will either commit or roll back at the end, so you won't leave it in a dirty state. It also locks out all other writers so that nothing gets messed up. Previously I was conflating the atomicity of each function with the committing of the entire transaction, and that was causing me grief. I think this is closer to correct.
This commit is contained in:
parent
abfaf27cee
commit
cbef38ba7f
1 changed file with 181 additions and 30 deletions
|
@ -6,8 +6,10 @@ import functools
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import threading
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
from voussoirkit import pathclass
|
||||||
from voussoirkit import sqlhelpers
|
from voussoirkit import sqlhelpers
|
||||||
from voussoirkit import vlogging
|
from voussoirkit import vlogging
|
||||||
|
|
||||||
|
@ -21,6 +23,12 @@ class WormException(Exception):
|
||||||
class BadTable(WormException):
|
class BadTable(WormException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class NoTransaction(WormException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TransactionActive(WormException):
|
||||||
|
pass
|
||||||
|
|
||||||
class DeletedObject(WormException):
|
class DeletedObject(WormException):
|
||||||
'''
|
'''
|
||||||
For when thing.deleted == True.
|
For when thing.deleted == True.
|
||||||
|
@ -36,16 +44,13 @@ def slice_before(li, item):
|
||||||
index = li.index(item)
|
index = li.index(item)
|
||||||
return li[:index]
|
return li[:index]
|
||||||
|
|
||||||
def transaction(method):
|
def atomic(method):
|
||||||
'''
|
'''
|
||||||
This decorator can be added to functions that modify your worms database.
|
This decorator can be added to functions that modify your worms database.
|
||||||
A savepoint is opened, then your function is run, then we roll back to the
|
A savepoint is opened, then your function is run. If an exception is raised,
|
||||||
savepoint if an exception is raised.
|
we roll back to the savepoint.
|
||||||
|
|
||||||
This decorator adds the keyword argument 'commit' to your function, so that
|
This decorator adds the attribute 'is_worms_atomic = True' to your
|
||||||
callers can commit it immediately.
|
|
||||||
|
|
||||||
This decorator adds the attribute 'is_worms_transaction = True' to your
|
|
||||||
function. You can use this to distinguish readonly vs writing methods during
|
function. You can use this to distinguish readonly vs writing methods during
|
||||||
runtime.
|
runtime.
|
||||||
|
|
||||||
|
@ -55,7 +60,7 @@ def transaction(method):
|
||||||
the action's failure.
|
the action's failure.
|
||||||
'''
|
'''
|
||||||
@functools.wraps(method)
|
@functools.wraps(method)
|
||||||
def wrapped_transaction(self, *args, commit=False, **kwargs):
|
def wrapped_atomic(self, *args, **kwargs):
|
||||||
if isinstance(self, Object):
|
if isinstance(self, Object):
|
||||||
self.assert_not_deleted()
|
self.assert_not_deleted()
|
||||||
|
|
||||||
|
@ -63,6 +68,7 @@ def transaction(method):
|
||||||
|
|
||||||
is_root = len(database.savepoints) == 0
|
is_root = len(database.savepoints) == 0
|
||||||
savepoint_id = database.savepoint(message=method.__qualname__)
|
savepoint_id = database.savepoint(message=method.__qualname__)
|
||||||
|
log.loud(f'{method.__qualname__} got savepoint {savepoint_id}.')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = method(self, *args, **kwargs)
|
result = method(self, *args, **kwargs)
|
||||||
|
@ -74,19 +80,36 @@ def transaction(method):
|
||||||
if isinstance(result, raise_without_rollback):
|
if isinstance(result, raise_without_rollback):
|
||||||
raise result.exc from result.exc
|
raise result.exc from result.exc
|
||||||
|
|
||||||
if commit:
|
if not is_root:
|
||||||
database.commit(message=method.__qualname__)
|
|
||||||
elif not is_root:
|
|
||||||
# In order to prevent a huge pile-up of savepoints when a
|
# In order to prevent a huge pile-up of savepoints when a
|
||||||
# @transaction calls another @transaction many times, the sub-call
|
# @transaction calls another @transaction many times, the sub-call
|
||||||
# savepoints are removed from the stack. When an exception occurs,
|
# savepoints are removed from the stack. When an exception occurs,
|
||||||
# we're going to rollback from the rootmost savepoint anyway, we'll
|
# we're going to rollback from the rootmost savepoint anyway, we'll
|
||||||
# never rollback one sub-transaction.
|
# never rollback one sub-transaction.
|
||||||
database.release_savepoint(savepoint=savepoint_id)
|
database.release_savepoint(savepoint=savepoint_id)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
wrapped_transaction.is_worms_transaction = True
|
wrapped_atomic.is_worms_atomic = True
|
||||||
return wrapped_transaction
|
return wrapped_atomic
|
||||||
|
|
||||||
|
class TransactionContextManager:
|
||||||
|
def __init__(self, database):
|
||||||
|
self.database = database
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
log.loud('Entering transaction.')
|
||||||
|
self.database.begin()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||||
|
log.loud('Exiting transaction.')
|
||||||
|
if exc_type is not None:
|
||||||
|
log.loud(f'Transaction raised {exc_type}.')
|
||||||
|
self.database.rollback()
|
||||||
|
raise exc_value
|
||||||
|
|
||||||
|
self.database.commit()
|
||||||
|
|
||||||
class Database(metaclass=abc.ABCMeta):
|
class Database(metaclass=abc.ABCMeta):
|
||||||
'''
|
'''
|
||||||
|
@ -98,11 +121,17 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
'''
|
'''
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Used for transaction
|
# Used for @atomic decorator
|
||||||
self._worms_database = self
|
self._worms_database = self
|
||||||
self.on_commit_queue = []
|
self.on_commit_queue = []
|
||||||
self.on_rollback_queue = []
|
self.on_rollback_queue = []
|
||||||
self.savepoints = []
|
self.savepoints = []
|
||||||
|
# To prevent two transactions from running at the same time in different
|
||||||
|
# threads, and committing the database in an odd state, we lock out and
|
||||||
|
# run one transaction at a time.
|
||||||
|
self._worms_transaction_lock = threading.Lock()
|
||||||
|
self._worms_transaction_owner = None
|
||||||
|
self.transaction = TransactionContextManager(database=self)
|
||||||
# If your IDs are integers, you could change this to int. This way, when
|
# If your IDs are integers, you could change this to int. This way, when
|
||||||
# taking user input as strings, they will automatically be converted to
|
# taking user input as strings, they will automatically be converted to
|
||||||
# int when querying and caching, and you don't have to do the conversion
|
# int when querying and caching, and you don't have to do the conversion
|
||||||
|
@ -125,23 +154,105 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _init_sql(self):
|
def _init_sql(self):
|
||||||
'''
|
'''
|
||||||
Your subclass needs to set self.sql, which is a database connection.
|
Your subclass needs to prepare self.sql_read and self.sql_write, which
|
||||||
|
are both connection objects. They can be the same object if you want, or
|
||||||
|
they can be separate connections so that the readers can not get blocked
|
||||||
|
by the writers.
|
||||||
|
|
||||||
It is recommended to set self.sql.row_factory = sqlite3.Row so that you
|
You can do it yourself or use the provided _init_connections to get the
|
||||||
get dictionary-style named access to row members in your objects' init.
|
basic handles going. Then use the rest of this method to do any other
|
||||||
|
setup your application needs.
|
||||||
'''
|
'''
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _make_sqlite_read_connection(self, path):
|
||||||
|
'''
|
||||||
|
Provided for convenience of _init_sql.
|
||||||
|
'''
|
||||||
|
if isinstance(path, pathclass.Path):
|
||||||
|
path = path.absolute_path
|
||||||
|
if path == ':memory:':
|
||||||
|
sql_read = sqlite3.connect('file:memdb1?mode=memory&cache=shared&mode=ro', uri=True)
|
||||||
|
sql_read.row_factory = sqlite3.Row
|
||||||
|
else:
|
||||||
|
log.debug('Connecting to sqlite file "%s".', path)
|
||||||
|
sql_read = sqlite3.connect(f'file:{path}?mode=ro', uri=True)
|
||||||
|
sql_read.row_factory = sqlite3.Row
|
||||||
|
return sql_read
|
||||||
|
|
||||||
|
def _make_sqlite_write_connection(self, path):
|
||||||
|
if isinstance(path, pathclass.Path):
|
||||||
|
path = path.absolute_path
|
||||||
|
|
||||||
|
if path == ':memory:':
|
||||||
|
sql_write = sqlite3.connect('file:memdb1?mode=memory&cache=shared', uri=True)
|
||||||
|
sql_write.row_factory = sqlite3.Row
|
||||||
|
else:
|
||||||
|
log.debug('Connecting to sqlite file "%s".', path)
|
||||||
|
sql_write = sqlite3.connect(path)
|
||||||
|
sql_write.row_factory = sqlite3.Row
|
||||||
|
return sql_write
|
||||||
|
|
||||||
|
def assert_no_transaction(self) -> None:
|
||||||
|
thread_id = threading.current_thread().ident
|
||||||
|
if self._worms_transaction_owner == thread_id:
|
||||||
|
raise TransactionActive()
|
||||||
|
|
||||||
|
def assert_transaction_active(self) -> None:
|
||||||
|
thread_id = threading.current_thread().ident
|
||||||
|
if self._worms_transaction_owner != thread_id:
|
||||||
|
raise NoTransaction()
|
||||||
|
|
||||||
|
def acquire_transaction_lock(self):
|
||||||
|
'''
|
||||||
|
If no transaction is running, the caller gets the lock.
|
||||||
|
|
||||||
|
If a transaction is running on the same thread as the caller, the caller
|
||||||
|
does not get the lock but the function returns so it can do its work,
|
||||||
|
since it is a descendant of the original transaction call.
|
||||||
|
|
||||||
|
If a transaction is running and the caller is on a different thread, it
|
||||||
|
gets blocked until the previous transaction finishes.
|
||||||
|
'''
|
||||||
|
# Don't worry about race conditions, ownership of lock changing while
|
||||||
|
# the if statement is evaluating, because this individual thread cannot
|
||||||
|
# be checking its identity and releasing the lock at the same time! If
|
||||||
|
# transaction_owner is the current thread, we know that will remain
|
||||||
|
# true until this thread releases it, which can't happen at the same
|
||||||
|
# time here.
|
||||||
|
thread_id = threading.current_thread().ident
|
||||||
|
if self._worms_transaction_owner == thread_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
log.loud(f'{thread_id} wants the transaction lock.')
|
||||||
|
self._worms_transaction_lock.acquire()
|
||||||
|
log.loud(f'{thread_id} has the transaction lock.')
|
||||||
|
self._worms_transaction_owner = thread_id
|
||||||
|
return True
|
||||||
|
|
||||||
def assert_table_exists(self, table) -> None:
|
def assert_table_exists(self, table) -> None:
|
||||||
if table not in self.COLUMN_INDEX:
|
if table not in self.COLUMN_INDEX:
|
||||||
raise BadTable(f'Table {table} does not exist.')
|
raise BadTable(f'Table {table} does not exist.')
|
||||||
|
|
||||||
|
def begin(self):
|
||||||
|
self.acquire_transaction_lock()
|
||||||
|
self.execute('BEGIN')
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
# Wrapped in hasattr because if the object fails __init__, Python will
|
# Wrapped in hasattr because if the object fails __init__, Python will
|
||||||
# still call __del__ and thus close(), even though the attributes
|
# still call __del__ and thus close(), even though the attributes
|
||||||
# we're trying to clean up never got set.
|
# we're trying to clean up never got set.
|
||||||
if hasattr(self, 'sql'):
|
if not hasattr(self, 'sql_read'):
|
||||||
self.sql.close()
|
return
|
||||||
|
|
||||||
|
if self._worms_transaction_owner:
|
||||||
|
self.rollback()
|
||||||
|
|
||||||
|
self.sql_read.close()
|
||||||
|
del self.sql_read
|
||||||
|
|
||||||
|
self.sql_write.close()
|
||||||
|
del self.sql_write
|
||||||
|
|
||||||
def commit(self, message=None) -> None:
|
def commit(self, message=None) -> None:
|
||||||
if message is None:
|
if message is None:
|
||||||
|
@ -157,6 +268,7 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
args = task.get('args', [])
|
args = task.get('args', [])
|
||||||
kwargs = task.get('kwargs', {})
|
kwargs = task.get('kwargs', {})
|
||||||
action = task['action']
|
action = task['action']
|
||||||
|
log.loud(f'{action} {args} {kwargs}')
|
||||||
try:
|
try:
|
||||||
action(*args, **kwargs)
|
action(*args, **kwargs)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
@ -165,8 +277,9 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
self.savepoints.clear()
|
self.savepoints.clear()
|
||||||
self.sql.commit()
|
self.sql_write.commit()
|
||||||
self.last_commit_id = RNG.getrandbits(32)
|
self.last_commit_id = RNG.getrandbits(32)
|
||||||
|
self.release_transaction_lock()
|
||||||
|
|
||||||
def delete(self, table, pairs) -> sqlite3.Cursor:
|
def delete(self, table, pairs) -> sqlite3.Cursor:
|
||||||
if isinstance(table, type) and issubclass(table, Object):
|
if isinstance(table, type) and issubclass(table, Object):
|
||||||
|
@ -176,10 +289,26 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
query = f'DELETE FROM {table} {qmarks}'
|
query = f'DELETE FROM {table} {qmarks}'
|
||||||
return self.execute(query, bindings)
|
return self.execute(query, bindings)
|
||||||
|
|
||||||
def execute(self, query, bindings=[]):
|
def execute_read(self, query, bindings=[]):
|
||||||
if bindings is None:
|
if bindings is None:
|
||||||
bindings = []
|
bindings = []
|
||||||
cur = self.sql.cursor()
|
|
||||||
|
thread_id = threading.current_thread().ident
|
||||||
|
if self._worms_transaction_owner == thread_id:
|
||||||
|
sql = self.sql_write
|
||||||
|
else:
|
||||||
|
sql = self.sql_read
|
||||||
|
|
||||||
|
cur = sql.cursor()
|
||||||
|
log.loud('%s %s', query, bindings)
|
||||||
|
cur.execute(query, bindings)
|
||||||
|
return cur
|
||||||
|
|
||||||
|
def execute(self, query, bindings=[]):
|
||||||
|
self.assert_transaction_active()
|
||||||
|
if bindings is None:
|
||||||
|
bindings = []
|
||||||
|
cur = self.sql_write.cursor()
|
||||||
log.loud('%s %s', query, bindings)
|
log.loud('%s %s', query, bindings)
|
||||||
cur.execute(query, bindings)
|
cur.execute(query, bindings)
|
||||||
return cur
|
return cur
|
||||||
|
@ -189,10 +318,11 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
The problem with Python's default executescript is that it executes a
|
The problem with Python's default executescript is that it executes a
|
||||||
COMMIT before running your script. If I wanted a commit I'd write one!
|
COMMIT before running your script. If I wanted a commit I'd write one!
|
||||||
'''
|
'''
|
||||||
|
self.assert_transaction_active()
|
||||||
lines = re.split(r';(:?\n|$)', script)
|
lines = re.split(r';(:?\n|$)', script)
|
||||||
lines = (line.strip() for line in lines)
|
lines = (line.strip() for line in lines)
|
||||||
lines = (line for line in lines if line)
|
lines = (line for line in lines if line)
|
||||||
cur = self.sql.cursor()
|
cur = self.sql_write.cursor()
|
||||||
for line in lines:
|
for line in lines:
|
||||||
log.loud(line)
|
log.loud(line)
|
||||||
cur.execute(line)
|
cur.execute(line)
|
||||||
|
@ -335,6 +465,17 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
return (good, bad)
|
return (good, bad)
|
||||||
|
|
||||||
|
def pragma_read(self, key):
|
||||||
|
pragma = self.execute_read(f'PRAGMA {key}').fetchone()
|
||||||
|
if pragma is not None:
|
||||||
|
return pragma[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def pragma_write(self, key, value) -> None:
|
||||||
|
# We are bypassing self.execute because some pragmas are not allowed to
|
||||||
|
# happen during transactions.
|
||||||
|
return self.sql_write.cursor().execute(f'PRAGMA {key} = {value}')
|
||||||
|
|
||||||
def release_savepoint(self, savepoint, allow_commit=False) -> None:
|
def release_savepoint(self, savepoint, allow_commit=False) -> None:
|
||||||
'''
|
'''
|
||||||
Releasing a savepoint removes that savepoint from the timeline, so that
|
Releasing a savepoint removes that savepoint from the timeline, so that
|
||||||
|
@ -359,6 +500,19 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
self.execute(f'RELEASE "{savepoint}"')
|
self.execute(f'RELEASE "{savepoint}"')
|
||||||
self.savepoints = slice_before(self.savepoints, savepoint)
|
self.savepoints = slice_before(self.savepoints, savepoint)
|
||||||
|
|
||||||
|
def release_transaction_lock(self):
|
||||||
|
thread_id = threading.current_thread().ident
|
||||||
|
if not self._worms_transaction_lock.locked():
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._worms_transaction_owner != thread_id:
|
||||||
|
log.warning(f'{thread_id} tried to release the transaction lock without holding it.')
|
||||||
|
return
|
||||||
|
|
||||||
|
log.loud(f'{thread_id} releases the transaction lock.')
|
||||||
|
self._worms_transaction_owner = None
|
||||||
|
self._worms_transaction_lock.release()
|
||||||
|
|
||||||
def rollback(self, savepoint=None) -> None:
|
def rollback(self, savepoint=None) -> None:
|
||||||
'''
|
'''
|
||||||
Given a savepoint, roll the database back to the moment before that
|
Given a savepoint, roll the database back to the moment before that
|
||||||
|
@ -371,10 +525,6 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
|
log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(self.savepoints) == 0:
|
|
||||||
log.debug('Nothing to roll back.')
|
|
||||||
return
|
|
||||||
|
|
||||||
while len(self.on_rollback_queue) > 0:
|
while len(self.on_rollback_queue) > 0:
|
||||||
task = self.on_rollback_queue.pop(-1)
|
task = self.on_rollback_queue.pop(-1)
|
||||||
if task == savepoint:
|
if task == savepoint:
|
||||||
|
@ -397,6 +547,7 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
self.execute('ROLLBACK')
|
self.execute('ROLLBACK')
|
||||||
self.savepoints.clear()
|
self.savepoints.clear()
|
||||||
self.on_commit_queue.clear()
|
self.on_commit_queue.clear()
|
||||||
|
self.release_transaction_lock()
|
||||||
|
|
||||||
def savepoint(self, message=None) -> int:
|
def savepoint(self, message=None) -> int:
|
||||||
savepoint_id = RNG.getrandbits(32)
|
savepoint_id = RNG.getrandbits(32)
|
||||||
|
@ -412,7 +563,7 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
return savepoint_id
|
return savepoint_id
|
||||||
|
|
||||||
def select(self, query, bindings=None) -> typing.Iterable:
|
def select(self, query, bindings=None) -> typing.Iterable:
|
||||||
cur = self.execute(query, bindings)
|
cur = self.execute_read(query, bindings)
|
||||||
while True:
|
while True:
|
||||||
fetch = cur.fetchone()
|
fetch = cur.fetchone()
|
||||||
if fetch is None:
|
if fetch is None:
|
||||||
|
@ -432,7 +583,7 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
'''
|
'''
|
||||||
Select a single row, or None if no rows match your query.
|
Select a single row, or None if no rows match your query.
|
||||||
'''
|
'''
|
||||||
cur = self.execute(query, bindings)
|
cur = self.execute_read(query, bindings)
|
||||||
return cur.fetchone()
|
return cur.fetchone()
|
||||||
|
|
||||||
def select_one_value(self, query, bindings=None, fallback=None):
|
def select_one_value(self, query, bindings=None, fallback=None):
|
||||||
|
@ -441,7 +592,7 @@ class Database(metaclass=abc.ABCMeta):
|
||||||
your query. The fallback can help you distinguish between rows that
|
your query. The fallback can help you distinguish between rows that
|
||||||
don't exist and a null value.
|
don't exist and a null value.
|
||||||
'''
|
'''
|
||||||
cur = self.execute(query, bindings)
|
cur = self.execute_read(query, bindings)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if row:
|
if row:
|
||||||
return row[0]
|
return row[0]
|
||||||
|
|
Loading…
Reference in a new issue