Add worms.py.
This commit is contained in:
parent a3ddeb8e09
commit aa1e2d5756
1 changed file with 519 additions and 0 deletions
voussoirkit/worms.py (new file, 519 lines)
@@ -0,0 +1,519 @@
'''
Worms is an SQL ORM with the strength and resilience of the humble earthworm.
'''
import functools
import re
import typing

from voussoirkit import passwordy
from voussoirkit import sqlhelpers
from voussoirkit import vlogging

log = vlogging.getLogger(__name__, 'worms')

class WormException(Exception):
    pass

class BadTable(WormException):
    pass

def slice_before(li, item):
    index = li.index(item)
    return li[:index]

def transaction(method):
    '''
    Open a savepoint before running the method.
    If the method fails, roll back to that savepoint.
    '''
    @functools.wraps(method)
    def wrapped_transaction(self, *args, commit=False, **kwargs):
        if isinstance(self, Object):
            self.assert_not_deleted()

        database = self._worms_database

        is_root = len(database.savepoints) == 0
        savepoint_id = database.savepoint(message=method.__qualname__)

        try:
            result = method(self, *args, **kwargs)
        except BaseException as exc:
            log.debug(f'{method} raised {repr(exc)}.')
            database.rollback(savepoint=savepoint_id)
            raise

        if commit:
            database.commit(message=method.__qualname__)
        elif not is_root:
            database.release_savepoint(savepoint=savepoint_id)
        return result

    return wrapped_transaction

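# A minimal usage sketch for @transaction. The PhotoDB / new_photo / photos
# names are hypothetical, not part of this module. Each call opens a
# savepoint and rolls back to it if the method raises; pass commit=True at
# the call site to commit when the method succeeds.
#
#     class PhotoDB(Database):
#         @transaction
#         def new_photo(self, filepath):
#             photo_id = passwordy.random_hex(length=12)
#             self.execute('INSERT INTO photos VALUES(?, ?)', [photo_id, filepath])
#             return photo_id
#
#     photodb = PhotoDB()
#     photodb.new_photo('example.jpg', commit=True)
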
class Database:
    '''
    When your class subclasses this class, you need to ensure the following:
    - self.COLUMNS is a dictionary of {table: [columns]} like what comes out of
      sqlhelpers.extract_table_column_map.
    - self.COLUMN_INDEX is a dictionary of {table: {column: index}} like what
      comes out of sqlhelpers.reverse_table_column_map.
    '''
    def __init__(self):
        super().__init__()
        # Used for transaction
        self._worms_database = self
        self.on_commit_queue = []
        self.on_rollback_queue = []
        self.savepoints = []

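    # A sketch of the subclass contract described above. The PhotoDB name, the
    # DB_INIT script, and the sqlite3 connection are assumptions for
    # illustration; this class only relies on COLUMNS, COLUMN_INDEX, and a
    # self.sql connection attribute being present.
    #
    #     class PhotoDB(Database):
    #         COLUMNS = sqlhelpers.extract_table_column_map(DB_INIT)
    #         COLUMN_INDEX = sqlhelpers.reverse_table_column_map(COLUMNS)
    #
    #         def __init__(self, path):
    #             super().__init__()
    #             self.sql = sqlite3.connect(path)
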
    def assert_table_exists(self, table) -> None:
        if table not in self.get_tables():
            raise BadTable(f'Table {table} does not exist.')

    def close(self):
        # Wrapped in hasattr because if the object fails __init__, Python will
        # still call __del__ and thus close(), even though the attributes
        # we're trying to clean up never got set.
        if hasattr(self, 'sql'):
            self.sql.close()

    def commit(self, message=None) -> None:
        if message is not None:
            log.debug('Committing - %s.', message)

        while len(self.on_commit_queue) > 0:
            task = self.on_commit_queue.pop(-1)
            if isinstance(task, str):
                # savepoints.
                continue
            args = task.get('args', [])
            kwargs = task.get('kwargs', {})
            action = task['action']
            try:
                action(*args, **kwargs)
            except Exception as exc:
                log.debug(f'{action} raised {repr(exc)}.')
                self.rollback()
                raise

        self.savepoints.clear()
        self.sql.commit()

    def get_tables(self) -> set[str]:
        '''
        Return the set of all table names in the database.
        '''
        query = 'SELECT name FROM sqlite_master WHERE type = "table"'
        table_rows = self.select(query)
        tables = set(name for (name,) in table_rows)
        return tables

    def delete(self, table, pairs) -> None:
        if isinstance(table, type) and issubclass(table, Object):
            table = table.table
        self.assert_table_exists(table)
        (qmarks, bindings) = sqlhelpers.delete_filler(pairs)
        query = f'DELETE FROM {table} {qmarks}'
        self.execute(query, bindings)

    def execute(self, query, bindings=[]):
        if bindings is None:
            bindings = []
        cur = self.sql.cursor()
        log.loud('%s %s', query, bindings)
        cur.execute(query, bindings)
        return cur

    def executescript(self, script) -> None:
        '''
        The problem with Python's default executescript is that it executes a
        COMMIT before running your script. If I wanted a commit I'd write one!
        '''
        lines = re.split(r';(?:\n|$)', script)
        lines = (line.strip() for line in lines)
        lines = (line for line in lines if line)
        cur = self.sql.cursor()
        for line in lines:
            log.loud(line)
            cur.execute(line)

    def get_object_by_id(self, object_class, object_id):
        '''
        Select an object by its ID.
        '''
        if isinstance(object_id, object_class):
            object_id = object_id.id

        query = f'SELECT * FROM {object_class.table} WHERE id == ?'
        bindings = [object_id]
        object_row = self.select_one(query, bindings)
        if object_row is None:
            raise object_class.no_such_exception(object_id)

        instance = object_class(self, object_row)

        return instance

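    # Usage sketch (Photo is a hypothetical Object subclass with table='photos'):
    #
    #     photo = photodb.get_object_by_id(Photo, 'abcdef1234567890')
    #
    # If no row has that ID, Photo.no_such_exception is raised instead.
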
    def get_objects(self, object_class):
        '''
        Yield objects, unfiltered, in whatever order they appear in the database.
        '''
        table = object_class.table
        query = f'SELECT * FROM {table}'

        objects = self.select(query)
        for object_row in objects:
            instance = object_class(self, object_row)
            yield instance

    def get_objects_by_id(self, object_class, object_ids):
        '''
        Select many objects by their IDs.
        This is better than calling get_object_by_id in a loop because we can
        use a single SQL select to get batches of up to 999 items.

        Note: The order of the output is not guaranteed to match the order of
        the input. Consider using get_objects_by_sql if that is a necessity.
        '''
        ids_needed = list(set(object_ids))

        while ids_needed:
            # SQLite3 has a limit of 999 ? in a query, so we must batch them.
            id_batch = ids_needed[:999]
            ids_needed = ids_needed[999:]

            qmarks = ','.join('?' * len(id_batch))
            qmarks = f'({qmarks})'
            query = f'SELECT * FROM {object_class.table} WHERE id IN {qmarks}'
            for object_row in self.select(query, id_batch):
                instance = object_class(self, db_row=object_row)
                yield instance

    def get_objects_by_sql(self, object_class, query, bindings=None):
        '''
        Use an arbitrary SQL query to select objects from the database.
        Your query should select * from the object's table.
        '''
        object_rows = self.select(query, bindings)
        for object_row in object_rows:
            yield object_class(self, object_row)

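    # Usage sketch for the three getters above. Photo, its photos table, and
    # the created column are hypothetical:
    #
    #     everything = list(photodb.get_objects(Photo))
    #     some = list(photodb.get_objects_by_id(Photo, ['id1', 'id2']))
    #     recent = photodb.get_objects_by_sql(
    #         Photo,
    #         'SELECT * FROM photos WHERE created > ? ORDER BY created DESC',
    #         [cutoff],
    #     )
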
    def insert(self, table, data) -> None:
        if isinstance(table, type) and issubclass(table, Object):
            table = table.table
        self.assert_table_exists(table)
        column_names = self.COLUMNS[table]
        (qmarks, bindings) = sqlhelpers.insert_filler(column_names, data)

        query = f'INSERT INTO {table} VALUES({qmarks})'
        self.execute(query, bindings)

    def normalize_db_row(self, db_row, table) -> dict:
        '''
        Raises KeyError if table is not one of the recognized tables.

        Raises TypeError if db_row is not the right type.
        '''
        if isinstance(db_row, dict):
            return db_row

        if isinstance(db_row, (list, tuple)):
            return dict(zip(self.COLUMNS[table], db_row))

        raise TypeError(f'db_row should be {dict}, {list}, or {tuple}, not {type(db_row)}.')

    def release_savepoint(self, savepoint, allow_commit=False) -> None:
        '''
        Releasing a savepoint removes that savepoint from the timeline, so that
        you can no longer roll back to it. Then your choices are to commit
        everything, or roll back to a previous point. If you release the
        earliest savepoint, the database will commit.
        '''
        if savepoint not in self.savepoints:
            log.warn('Tried to release nonexistent savepoint %s.', savepoint)
            return

        is_commit = savepoint == self.savepoints[0]
        if is_commit and not allow_commit:
            log.debug('Not committing %s without allow_commit=True.', savepoint)
            return

        if is_commit:
            # We want to perform the on_commit_queue so let's use our commit
            # method instead of allowing sql's release to commit.
            self.commit()
        else:
            self.execute(f'RELEASE "{savepoint}"')
            self.savepoints = slice_before(self.savepoints, savepoint)

    def rollback(self, savepoint=None) -> None:
        '''
        Given a savepoint, roll the database back to the moment before that
        savepoint was created. Keep in mind that a @transaction savepoint is
        always created *before* the method actually does anything.

        If no savepoint is provided then roll back the entire transaction.
        '''
        if savepoint is not None and savepoint not in self.savepoints:
            log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
            return

        if len(self.savepoints) == 0:
            log.debug('Nothing to roll back.')
            return

        while len(self.on_rollback_queue) > 0:
            task = self.on_rollback_queue.pop(-1)
            if task == savepoint:
                break
            if isinstance(task, str):
                # Intermediate savepoints.
                continue
            args = task.get('args', [])
            kwargs = task.get('kwargs', {})
            task['action'](*args, **kwargs)

        if savepoint is not None:
            log.debug('Rolling back to %s.', savepoint)
            self.execute(f'ROLLBACK TO "{savepoint}"')
            self.savepoints = slice_before(self.savepoints, savepoint)
            self.on_commit_queue = slice_before(self.on_commit_queue, savepoint)

        else:
            log.debug('Rolling back.')
            self.execute('ROLLBACK')
            self.savepoints.clear()
            self.on_commit_queue.clear()

    def savepoint(self, message=None) -> str:
        savepoint_id = passwordy.random_hex(length=16)
        if message:
            log.log(5, 'Savepoint %s for %s.', savepoint_id, message)
        else:
            log.log(5, 'Savepoint %s.', savepoint_id)
        query = f'SAVEPOINT "{savepoint_id}"'
        self.execute(query)
        self.savepoints.append(savepoint_id)
        self.on_commit_queue.append(savepoint_id)
        self.on_rollback_queue.append(savepoint_id)
        return savepoint_id

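    # A sketch of how savepoint / rollback / release_savepoint / commit fit
    # together when used directly (most callers go through @transaction
    # instead):
    #
    #     sp = database.savepoint(message='bulk edit')
    #     try:
    #         ...  # make changes
    #     except Exception:
    #         database.rollback(savepoint=sp)
    #         raise
    #     else:
    #         database.release_savepoint(savepoint=sp)
    #     database.commit()
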
    def select(self, query, bindings=None) -> typing.Iterable:
        cur = self.execute(query, bindings)
        while True:
            fetch = cur.fetchone()
            if fetch is None:
                break
            yield fetch

    def select_column(self, query, bindings=None) -> typing.Iterable:
        '''
        If your SELECT query only selects a single column, you can use this
        function to get a generator of the individual values instead
        of one-tuples.
        '''
        for row in self.select(query, bindings):
            yield row[0]

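    # For example (photos is a hypothetical table):
    #
    #     ids = set(database.select_column('SELECT id FROM photos'))
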
    def select_one(self, query, bindings=None):
        cur = self.execute(query, bindings)
        return cur.fetchone()

    def update(self, table, pairs, where_key) -> None:
        if isinstance(table, type) and issubclass(table, Object):
            table = table.table
        self.assert_table_exists(table)
        (qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key)
        query = f'UPDATE {table} {qmarks}'
        self.execute(query, bindings)

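# Usage sketch for the write helpers above. The photos table and its columns
# are hypothetical, and the assumption here is that the pairs / data arguments
# are dicts of {column: value} which sqlhelpers' delete_filler, insert_filler,
# and update_filler turn into qmark placeholders plus bindings.
#
#     database.insert('photos', {'id': photo_id, 'filepath': 'example.jpg'})
#     database.update('photos', {'id': photo_id, 'filepath': 'renamed.jpg'}, where_key='id')
#     database.delete('photos', {'id': photo_id})
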
class DatabaseWithCaching(Database):
    def __init__(self):
        super().__init__()
        self.caches = {}

    def clear_all_caches(self) -> None:
        for cache in self.caches.values():
            cache.clear()

    def get_cached_instance(self, object_class, db_row):
        '''
        Check if there is already an instance in the cache and return that.
        Otherwise, a new instance is created, cached, and returned.

        Note that in order to call this method you have to already have a
        db_row which means performing some select. If you only have the ID,
        use get_object_by_id, as there may already be a cached instance to save
        you the select.
        '''
        object_table = object_class.table
        object_cache = self.caches.get(object_class, None)

        if isinstance(db_row, dict):
            object_id = db_row['id']
        else:
            object_index = self.COLUMN_INDEX[object_table]
            object_id = db_row[object_index['id']]

        if object_cache is None:
            return object_class(self, db_row)

        try:
            instance = object_cache[object_id]
        except KeyError:
            log.loud('Cache miss %s %s.', object_class, object_id)
            instance = object_class(self, db_row)
            object_cache[object_id] = instance
        return instance

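    # The cache for a class is any mutable mapping of {id: instance}; a plain
    # dict works. Subclasses opt in per class, e.g. (Photo is hypothetical):
    #
    #     self.caches[Photo] = {}
    #
    # Classes without an entry in self.caches simply bypass the cache.
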
    def get_object_by_id(self, object_class, object_id):
        '''
        This method will first check the cache to see if there is already an
        instance with that ID, in which case we don't need to perform any SQL
        select. If it is not in the cache, then a new instance is created,
        cached, and returned.
        '''
        if isinstance(object_id, object_class):
            # This could be used to check if your old reference to an object is
            # still in the cache, or re-select it from the db to make sure it
            # still exists and re-cache.
            # Probably an uncommon need but... no harm I think.
            object_id = object_id.id

        object_cache = self.caches.get(object_class, None)

        if object_cache is not None:
            try:
                return object_cache[object_id]
            except KeyError:
                pass

        query = f'SELECT * FROM {object_class.table} WHERE id == ?'
        bindings = [object_id]
        object_row = self.select_one(query, bindings)
        if object_row is None:
            raise object_class.no_such_exception(object_id)

        # Normally we would call `get_cached_instance` instead of
        # constructing here. But we already know for a fact that this
        # object is not in the cache.
        instance = object_class(self, object_row)

        if object_cache is not None:
            object_cache[object_id] = instance

        return instance

    def get_objects(self, object_class):
        '''
        Yield objects, unfiltered, in whatever order they appear in the database.
        '''
        table = object_class.table
        query = f'SELECT * FROM {table}'

        objects = self.select(query)
        for object_row in objects:
            instance = self.get_cached_instance(object_class, object_row)
            yield instance

    def get_objects_by_id(self, object_class, object_ids):
        '''
        Given multiple IDs, this method will find which ones are in the cache
        and which ones need to be selected from the db.
        This is better than calling get_object_by_id in a loop because we can
        use a single SQL select to get batches of up to 999 items.

        Note: The order of the output will most likely not match the order of
        the input, because we first pull items from the cache before requesting
        the rest from the database.
        '''
        object_cache = self.caches.get(object_class, None)

        ids_needed = set()

        if object_cache is None:
            ids_needed.update(object_ids)
        else:
            for object_id in object_ids:
                try:
                    instance = object_cache[object_id]
                except KeyError:
                    ids_needed.add(object_id)
                else:
                    yield instance

        if not ids_needed:
            return

        if object_cache is not None:
            log.loud('Cache miss %s %s.', object_class.table, ids_needed)

        ids_needed = list(ids_needed)
        while ids_needed:
            # SQLite3 has a limit of 999 ? in a query, so we must batch them.
            id_batch = ids_needed[:999]
            ids_needed = ids_needed[999:]

            qmarks = ','.join('?' * len(id_batch))
            qmarks = f'({qmarks})'
            query = f'SELECT * FROM {object_class.table} WHERE id IN {qmarks}'
            for object_row in self.select(query, id_batch):
                # Normally we would call `get_cached_instance` instead of
                # constructing here. But we already know for a fact that this
                # object is not in the cache because it made it past the
                # previous loop.
                instance = object_class(self, db_row=object_row)
                if object_cache is not None:
                    object_cache[instance.id] = instance
                yield instance

    def get_objects_by_sql(self, object_class, query, bindings=None):
        '''
        Use an arbitrary SQL query to select objects from the database.
        Your query should select * from the object's table.
        '''
        object_rows = self.select(query, bindings)
        for object_row in object_rows:
            yield self.get_cached_instance(object_class, object_row)

class Object:
    '''
    When your objects subclass this class, you need to ensure the following:

    - self.table should be a string.
    - self.no_such_exception should be an exception class, to be raised when
      the user requests an instance of this class that does not exist.
      Initialized with a single argument, the requested ID.
    '''
    def __init__(self, database):
        # Used for transaction
        self._worms_database = database

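    # A sketch of the subclass contract described above. Photo and
    # NoSuchPhoto are hypothetical names:
    #
    #     class Photo(Object):
    #         table = 'photos'
    #         no_such_exception = NoSuchPhoto
    #
    #         def __init__(self, photodb, db_row):
    #             super().__init__(photodb)
    #             db_row = photodb.normalize_db_row(db_row, self.table)
    #             self.id = db_row['id']
    #             self.filepath = db_row['filepath']
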
    def __reinit__(self):
        '''
        Reload the row from the database and do __init__ with it.
        '''
        query = f'SELECT * FROM {self.table} WHERE id == ?'
        bindings = [self.id]
        row = self._worms_database.select_one(query, bindings)
        if row is None:
            self.deleted = True
        else:
            self.__init__(self._worms_database, row)

    def __eq__(self, other):
        return (
            isinstance(other, type(self)) and
            self._worms_database == other._worms_database and
            self.id == other.id
        )

    def __format__(self, formcode):
        if formcode == 'r':
            return repr(self)
        else:
            return str(self)

    def __hash__(self):
        return hash(f'{self.table}.{self.id}')