2021-09-18 01:24:03 +00:00
|
|
|
'''
|
|
|
|
Worms is an SQL ORM with the strength and resilience of the humble earthworm.
|
|
|
|
'''
|
2021-10-09 19:12:18 +00:00
|
|
|
import abc
|
2021-09-18 01:24:03 +00:00
|
|
|
import functools
|
2022-03-10 19:26:51 +00:00
|
|
|
import random
|
2021-09-18 01:24:03 +00:00
|
|
|
import re
|
2022-03-20 03:03:50 +00:00
|
|
|
import sqlite3
|
2021-09-18 01:24:03 +00:00
|
|
|
import typing
|
|
|
|
|
|
|
|
from voussoirkit import sqlhelpers
|
|
|
|
from voussoirkit import vlogging
|
|
|
|
|
|
|
|
log = vlogging.getLogger(__name__, 'worms')
|
|
|
|
|
2022-03-10 19:26:51 +00:00
|
|
|
RNG = random.SystemRandom()
|
|
|
|
|
2021-09-18 01:24:03 +00:00
|
|
|
class WormException(Exception):
    '''
    Base class for every exception raised by the worms library.
    '''
    pass
|
|
|
|
|
|
|
|
class BadTable(WormException):
    '''
    Raised when a query refers to a table that does not exist.
    '''
    pass
|
|
|
|
|
2021-10-02 19:45:17 +00:00
|
|
|
class DeletedObject(WormException):
    '''
    Raised when an operation is attempted on an object whose
    deleted attribute is True.
    '''
    pass
|
|
|
|
|
2022-03-14 22:32:11 +00:00
|
|
|
# snake-cased because I want the ergonomics of a function from the caller's end.
|
|
|
|
class raise_without_rollback:
    '''
    Return an instance of this class from a @transaction method to have the
    wrapped exception re-raised without rolling back the savepoint.
    '''
    def __init__(self, exc):
        # The exception that @transaction will re-raise instead of rolling back.
        self.exc = exc
|
|
|
|
|
2021-09-18 01:24:03 +00:00
|
|
|
def slice_before(li, item):
    '''
    Return the prefix of the list up to, but not including, the given item.

    Raises ValueError if the item is not present.
    '''
    return li[:li.index(item)]
|
|
|
|
|
|
|
|
def transaction(method):
    '''
    This decorator can be added to functions that modify your worms database.
    A savepoint is opened, then your function is run, then we roll back to the
    savepoint if an exception is raised.

    This decorator adds the keyword argument 'commit' to your function, so that
    callers can commit it immediately.

    This decorator adds the attribute 'is_worms_transaction = True' to your
    function. You can use this to distinguish readonly vs writing methods during
    runtime.

    If you want to raise an exception without rolling back, you can return
    worms.raise_without_rollback(exc). This could be useful if you want to
    preserve some kind of attempted action in the database while still raising
    the action's failure.
    '''
    @functools.wraps(method)
    def wrapped_transaction(self, *args, commit=False, **kwargs):
        # A deleted object must not touch the database any further.
        if isinstance(self, Object):
            self.assert_not_deleted()

        database = self._worms_database

        # True when this call opened the outermost savepoint of the
        # transaction; sub-call savepoints get released below.
        is_root = len(database.savepoints) == 0
        savepoint_id = database.savepoint(message=method.__qualname__)

        try:
            result = method(self, *args, **kwargs)
        # BaseException so that KeyboardInterrupt etc. also roll back.
        except BaseException as exc:
            log.debug(f'{method} raised {repr(exc)}.')
            database.rollback(savepoint=savepoint_id)
            raise

        # The method asked us to re-raise without undoing its work.
        if isinstance(result, raise_without_rollback):
            raise result.exc from result.exc

        if commit:
            database.commit(message=method.__qualname__)
        elif not is_root:
            # In order to prevent a huge pile-up of savepoints when a
            # @transaction calls another @transaction many times, the sub-call
            # savepoints are removed from the stack. When an exception occurs,
            # we're going to rollback from the rootmost savepoint anyway, we'll
            # never rollback one sub-transaction.
            database.release_savepoint(savepoint=savepoint_id)
        return result

    wrapped_transaction.is_worms_transaction = True
    return wrapped_transaction
|
|
|
|
|
2021-10-09 19:12:18 +00:00
|
|
|
class Database(metaclass=abc.ABCMeta):
|
2021-09-18 01:24:03 +00:00
|
|
|
'''
|
|
|
|
When your class subclasses this class, you need to ensure the following:
|
|
|
|
- self.COLUMNS is a dictionary of {table: [columns]} like what comes out of
|
|
|
|
sqlhelpers.extract_table_column_map.
|
|
|
|
- self.COLUMN_INDEX is a dictionary of {table: {column: index}} like what
|
|
|
|
comes out of sqlhelpers.reverse_table_column_map.
|
|
|
|
'''
|
|
|
|
    def __init__(self):
        '''
        Initialize the bookkeeping used by the @transaction machinery.
        '''
        super().__init__()
        # Used for transaction
        self._worms_database = self
        # Task dicts ({'action': ..., 'args': ..., 'kwargs': ...}) performed on
        # the next commit / rollback. Savepoint ids are interleaved into these
        # queues as plain ints to mark transaction boundaries.
        self.on_commit_queue = []
        self.on_rollback_queue = []
        # Stack of currently-open savepoint ids, rootmost first.
        self.savepoints = []
        # If your IDs are integers, you could change this to int. This way, when
        # taking user input as strings, they will automatically be converted to
        # int when querying and caching, and you don't have to do the conversion
        # on the application side.
        self.id_type = str
        # Regenerated to a fresh random value by commit(); presumably lets
        # callers detect that new data has been committed — TODO confirm.
        self.last_commit_id = None
|
2021-09-18 01:24:03 +00:00
|
|
|
|
2021-10-09 19:12:18 +00:00
|
|
|
    @abc.abstractmethod
    def _init_column_index(self):
        '''
        Your subclass needs to set self.COLUMNS and self.COLUMN_INDEX, where
        COLUMNS is a dictionary of {'table': ['column1', 'column2', ...]} and
        COLUMN_INDEX is a dict of {'table': {'column1': 0, 'column2': 1}}.

        These outputs can come from sqlhelpers.extract_table_column_map and
        reverse_table_column_map.
        '''
        # Abstract: calling the base implementation is always an error.
        raise NotImplementedError
|
|
|
|
|
|
|
|
    @abc.abstractmethod
    def _init_sql(self):
        '''
        Your subclass needs to set self.sql, which is a database connection.

        It is recommended to set self.sql.row_factory = sqlite3.Row so that you
        get dictionary-style named access to row members in your objects' init.
        '''
        # Abstract: calling the base implementation is always an error.
        raise NotImplementedError
|
|
|
|
|
2021-09-18 01:24:03 +00:00
|
|
|
def assert_table_exists(self, table) -> None:
|
2022-03-20 03:03:26 +00:00
|
|
|
if table not in self.COLUMN_INDEX:
|
2021-09-18 01:24:03 +00:00
|
|
|
raise BadTable(f'Table {table} does not exist.')
|
|
|
|
|
|
|
|
def close(self):
|
|
|
|
# Wrapped in hasattr because if the object fails __init__, Python will
|
|
|
|
# still call __del__ and thus close(), even though the attributes
|
|
|
|
# we're trying to clean up never got set.
|
|
|
|
if hasattr(self, 'sql'):
|
|
|
|
self.sql.close()
|
|
|
|
|
|
|
|
    def commit(self, message=None) -> None:
        '''
        Run all queued on_commit tasks, then commit the SQL connection and
        clear the savepoint stack.

        If any queued task raises, the transaction is rolled back and the
        exception re-raised.
        '''
        if message is None:
            log.debug('Committing.')
        else:
            log.debug('Committing - %s.', message)

        # Tasks are performed newest-first.
        while len(self.on_commit_queue) > 0:
            task = self.on_commit_queue.pop(-1)
            if isinstance(task, int):
                # Integer entries are savepoint ids, not runnable tasks.
                continue
            args = task.get('args', [])
            kwargs = task.get('kwargs', {})
            action = task['action']
            try:
                action(*args, **kwargs)
            except Exception as exc:
                log.debug(f'{action} raised {repr(exc)}.')
                self.rollback()
                raise

        self.savepoints.clear()
        self.sql.commit()
        # Fresh random token marking this commit.
        self.last_commit_id = RNG.getrandbits(32)
|
2021-09-18 01:24:03 +00:00
|
|
|
|
2022-03-23 05:16:26 +00:00
|
|
|
def delete(self, table, pairs) -> sqlite3.Cursor:
|
2021-09-18 01:24:03 +00:00
|
|
|
if isinstance(table, type) and issubclass(table, Object):
|
|
|
|
table = table.table
|
|
|
|
self.assert_table_exists(table)
|
|
|
|
(qmarks, bindings) = sqlhelpers.delete_filler(pairs)
|
|
|
|
query = f'DELETE FROM {table} {qmarks}'
|
2022-03-14 22:35:51 +00:00
|
|
|
return self.execute(query, bindings)
|
2021-09-18 01:24:03 +00:00
|
|
|
|
|
|
|
def execute(self, query, bindings=[]):
|
|
|
|
if bindings is None:
|
|
|
|
bindings = []
|
|
|
|
cur = self.sql.cursor()
|
|
|
|
log.loud('%s %s', query, bindings)
|
|
|
|
cur.execute(query, bindings)
|
|
|
|
return cur
|
|
|
|
|
|
|
|
def executescript(self, script) -> None:
|
|
|
|
'''
|
|
|
|
The problem with Python's default executescript is that it executes a
|
|
|
|
COMMIT before running your script. If I wanted a commit I'd write one!
|
|
|
|
'''
|
|
|
|
lines = re.split(r';(:?\n|$)', script)
|
|
|
|
lines = (line.strip() for line in lines)
|
|
|
|
lines = (line for line in lines if line)
|
|
|
|
cur = self.sql.cursor()
|
|
|
|
for line in lines:
|
|
|
|
log.loud(line)
|
|
|
|
cur.execute(line)
|
|
|
|
|
|
|
|
def get_object_by_id(self, object_class, object_id):
|
|
|
|
'''
|
|
|
|
Select an object by its ID.
|
|
|
|
'''
|
|
|
|
if isinstance(object_id, object_class):
|
|
|
|
object_id = object_id.id
|
|
|
|
|
2022-03-22 02:41:59 +00:00
|
|
|
object_id = self.normalize_object_id(object_class, object_id)
|
2021-09-18 01:24:03 +00:00
|
|
|
query = f'SELECT * FROM {object_class.table} WHERE id == ?'
|
|
|
|
bindings = [object_id]
|
|
|
|
object_row = self.select_one(query, bindings)
|
|
|
|
if object_row is None:
|
|
|
|
raise object_class.no_such_exception(object_id)
|
|
|
|
|
|
|
|
instance = object_class(self, object_row)
|
|
|
|
|
|
|
|
return instance
|
|
|
|
|
|
|
|
def get_objects(self, object_class):
|
|
|
|
'''
|
|
|
|
Yield objects, unfiltered, in whatever order they appear in the database.
|
|
|
|
'''
|
|
|
|
table = object_class.table
|
|
|
|
query = f'SELECT * FROM {table}'
|
|
|
|
|
|
|
|
objects = self.select(query)
|
|
|
|
for object_row in objects:
|
|
|
|
instance = object_class(self, object_row)
|
|
|
|
yield instance
|
|
|
|
|
2021-11-11 07:01:55 +00:00
|
|
|
    def get_objects_by_id(self, object_class, object_ids, *, raise_for_missing=False):
        '''
        Select many objects by their IDs.

        This is better than calling get_object_by_id in a loop because we can
        use a single SQL select to get batches of up to 999 items.

        Note: The order of the output will most likely not match the order of
        the input. Consider using get_objects_by_sql if that is a necessity.

        raise_for_missing:
            If any of the requested object ids are not found in the database,
            we can raise that class's no_such_exception with the set of missing
            IDs.
        '''
        # IDs that fail type normalization are certainly missing.
        (object_ids, missing) = self.normalize_object_ids(object_ids)
        ids_needed = list(object_ids)
        ids_found = set()

        while ids_needed:
            # SQLite3 has a limit of 999 ? in a query, so we must batch them.
            id_batch = ids_needed[:999]
            ids_needed = ids_needed[999:]

            qmarks = ','.join('?' * len(id_batch))
            qmarks = f'({qmarks})'
            query = f'SELECT * FROM {object_class.table} WHERE id IN {qmarks}'
            for object_row in self.select(query, id_batch):
                instance = object_class(self, db_row=object_row)
                ids_found.add(instance.id)
                yield instance

        if raise_for_missing:
            missing.update(object_ids.difference(ids_found))
            if missing:
                raise object_class.no_such_exception(missing)
|
|
|
|
|
2021-09-18 01:24:03 +00:00
|
|
|
def get_objects_by_sql(self, object_class, query, bindings=None):
|
|
|
|
'''
|
|
|
|
Use an arbitrary SQL query to select objects from the database.
|
|
|
|
Your query should select * from the object's table.
|
|
|
|
'''
|
|
|
|
object_rows = self.select(query, bindings)
|
|
|
|
for object_row in object_rows:
|
|
|
|
yield object_class(self, object_row)
|
|
|
|
|
2022-03-15 20:37:28 +00:00
|
|
|
def get_tables(self) -> set[str]:
|
|
|
|
'''
|
|
|
|
Return the set of all table names in the database.
|
|
|
|
'''
|
|
|
|
query = 'SELECT name FROM sqlite_master WHERE type = "table"'
|
2022-03-20 02:51:25 +00:00
|
|
|
tables = set(self.select_column(query))
|
2022-03-15 20:37:28 +00:00
|
|
|
return tables
|
|
|
|
|
2022-03-23 05:16:26 +00:00
|
|
|
def insert(self, table, data) -> sqlite3.Cursor:
|
2021-09-18 01:24:03 +00:00
|
|
|
if isinstance(table, type) and issubclass(table, Object):
|
|
|
|
table = table.table
|
|
|
|
self.assert_table_exists(table)
|
|
|
|
column_names = self.COLUMNS[table]
|
|
|
|
(qmarks, bindings) = sqlhelpers.insert_filler(column_names, data)
|
|
|
|
query = f'INSERT INTO {table} VALUES({qmarks})'
|
2022-03-14 22:35:51 +00:00
|
|
|
return self.execute(query, bindings)
|
2021-09-18 01:24:03 +00:00
|
|
|
|
2022-03-22 02:41:59 +00:00
|
|
|
def normalize_object_id(self, object_class, object_id):
|
|
|
|
'''
|
|
|
|
Given an object ID as input by the user, try to convert it using
|
|
|
|
self.id_type. If that raises a ValueError, then we raise
|
|
|
|
that class's no_such_exception.
|
|
|
|
|
|
|
|
Just because an ID passes the type conversion does not mean that ID
|
|
|
|
actually exists. We can raise the no_such_exception because an invalid
|
|
|
|
ID certainly doesn't exist, but a valid one still might not exist.
|
|
|
|
'''
|
|
|
|
try:
|
|
|
|
return self.id_type(object_id)
|
|
|
|
except ValueError:
|
|
|
|
raise object_class.no_such_exception(object_id)
|
|
|
|
|
|
|
|
def normalize_object_ids(self, object_ids):
|
|
|
|
'''
|
|
|
|
Given a list of object ids, return two sets: the first set contains all
|
|
|
|
the IDs that were able to be normalized using self.id_type; the second
|
|
|
|
contains all the IDs that raised ValueError. This method does not raise
|
|
|
|
the no_such_exception. as you may prefer to process the good instead of
|
|
|
|
losing it all with an exception.
|
|
|
|
|
|
|
|
Just because an ID passes the type conversion does not mean that ID
|
|
|
|
actually exists.
|
|
|
|
'''
|
|
|
|
good = set()
|
|
|
|
bad = set()
|
|
|
|
for object_id in object_ids:
|
|
|
|
try:
|
|
|
|
good.add(self.id_type(object_id))
|
|
|
|
except ValueError:
|
|
|
|
bad.add(object_id)
|
|
|
|
|
|
|
|
return (good, bad)
|
|
|
|
|
2021-09-18 01:24:03 +00:00
|
|
|
def release_savepoint(self, savepoint, allow_commit=False) -> None:
|
|
|
|
'''
|
|
|
|
Releasing a savepoint removes that savepoint from the timeline, so that
|
|
|
|
you can no longer roll back to it. Then your choices are to commit
|
|
|
|
everything, or roll back to a previous point. If you release the
|
|
|
|
earliest savepoint, the database will commit.
|
|
|
|
'''
|
|
|
|
if savepoint not in self.savepoints:
|
|
|
|
log.warn('Tried to release nonexistent savepoint %s.', savepoint)
|
|
|
|
return
|
|
|
|
|
|
|
|
is_commit = savepoint == self.savepoints[0]
|
|
|
|
if is_commit and not allow_commit:
|
|
|
|
log.debug('Not committing %s without allow_commit=True.', savepoint)
|
|
|
|
return
|
|
|
|
|
|
|
|
if is_commit:
|
|
|
|
# We want to perform the on_commit_queue so let's use our commit
|
|
|
|
# method instead of allowing sql's release to commit.
|
|
|
|
self.commit()
|
|
|
|
else:
|
|
|
|
self.execute(f'RELEASE "{savepoint}"')
|
|
|
|
self.savepoints = slice_before(self.savepoints, savepoint)
|
|
|
|
|
|
|
|
def rollback(self, savepoint=None) -> None:
|
|
|
|
'''
|
|
|
|
Given a savepoint, roll the database back to the moment before that
|
|
|
|
savepoint was created. Keep in mind that a @transaction savepoint is
|
|
|
|
always created *before* the method actually does anything.
|
|
|
|
|
|
|
|
If no savepoint is provided then rollback the entire transaction.
|
|
|
|
'''
|
|
|
|
if savepoint is not None and savepoint not in self.savepoints:
|
|
|
|
log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
|
|
|
|
return
|
|
|
|
|
|
|
|
if len(self.savepoints) == 0:
|
|
|
|
log.debug('Nothing to roll back.')
|
|
|
|
return
|
|
|
|
|
|
|
|
while len(self.on_rollback_queue) > 0:
|
|
|
|
task = self.on_rollback_queue.pop(-1)
|
|
|
|
if task == savepoint:
|
|
|
|
break
|
2022-03-10 19:26:51 +00:00
|
|
|
if isinstance(task, int):
|
2021-09-18 01:24:03 +00:00
|
|
|
# Intermediate savepoints.
|
|
|
|
continue
|
|
|
|
args = task.get('args', [])
|
|
|
|
kwargs = task.get('kwargs', {})
|
|
|
|
task['action'](*args, **kwargs)
|
|
|
|
|
|
|
|
if savepoint is not None:
|
|
|
|
log.debug('Rolling back to %s.', savepoint)
|
|
|
|
self.execute(f'ROLLBACK TO "{savepoint}"')
|
|
|
|
self.savepoints = slice_before(self.savepoints, savepoint)
|
|
|
|
self.on_commit_queue = slice_before(self.on_commit_queue, savepoint)
|
|
|
|
|
|
|
|
else:
|
|
|
|
log.debug('Rolling back.')
|
|
|
|
self.execute('ROLLBACK')
|
|
|
|
self.savepoints.clear()
|
|
|
|
self.on_commit_queue.clear()
|
|
|
|
|
2022-03-23 05:16:42 +00:00
|
|
|
def savepoint(self, message=None) -> int:
|
2022-03-10 19:26:51 +00:00
|
|
|
savepoint_id = RNG.getrandbits(32)
|
2021-09-18 01:24:03 +00:00
|
|
|
if message:
|
|
|
|
log.log(5, 'Savepoint %s for %s.', savepoint_id, message)
|
|
|
|
else:
|
|
|
|
log.log(5, 'Savepoint %s.', savepoint_id)
|
|
|
|
query = f'SAVEPOINT "{savepoint_id}"'
|
|
|
|
self.execute(query)
|
|
|
|
self.savepoints.append(savepoint_id)
|
|
|
|
self.on_commit_queue.append(savepoint_id)
|
|
|
|
self.on_rollback_queue.append(savepoint_id)
|
|
|
|
return savepoint_id
|
|
|
|
|
|
|
|
def select(self, query, bindings=None) -> typing.Iterable:
|
|
|
|
cur = self.execute(query, bindings)
|
|
|
|
while True:
|
|
|
|
fetch = cur.fetchone()
|
|
|
|
if fetch is None:
|
|
|
|
break
|
|
|
|
yield fetch
|
|
|
|
|
|
|
|
def select_column(self, query, bindings=None) -> typing.Iterable:
|
|
|
|
'''
|
|
|
|
If your SELECT query only selects a single column, you can use this
|
|
|
|
function to get a generator of the individual values instead
|
|
|
|
of one-tuples.
|
|
|
|
'''
|
|
|
|
for row in self.select(query, bindings):
|
|
|
|
yield row[0]
|
|
|
|
|
|
|
|
def select_one(self, query, bindings=None):
|
2022-03-14 22:34:41 +00:00
|
|
|
'''
|
|
|
|
Select a single row, or None if no rows match your query.
|
|
|
|
'''
|
2021-09-18 01:24:03 +00:00
|
|
|
cur = self.execute(query, bindings)
|
|
|
|
return cur.fetchone()
|
|
|
|
|
2022-06-04 02:42:20 +00:00
|
|
|
def select_one_value(self, query, bindings=None, fallback=None):
|
2022-03-14 22:34:41 +00:00
|
|
|
'''
|
2022-06-04 02:42:20 +00:00
|
|
|
Select a single column out of a single row, or fallback if no rows match
|
|
|
|
your query. The fallback can help you distinguish between rows that
|
|
|
|
don't exist and a null value.
|
2022-03-14 22:34:41 +00:00
|
|
|
'''
|
|
|
|
cur = self.execute(query, bindings)
|
|
|
|
row = cur.fetchone()
|
|
|
|
if row:
|
|
|
|
return row[0]
|
|
|
|
else:
|
2022-06-04 02:42:20 +00:00
|
|
|
return fallback
|
2022-03-14 22:34:41 +00:00
|
|
|
|
2022-03-23 05:16:26 +00:00
|
|
|
def update(self, table, pairs, where_key) -> sqlite3.Cursor:
|
2021-09-18 01:24:03 +00:00
|
|
|
if isinstance(table, type) and issubclass(table, Object):
|
|
|
|
table = table.table
|
|
|
|
self.assert_table_exists(table)
|
|
|
|
(qmarks, bindings) = sqlhelpers.update_filler(pairs, where_key=where_key)
|
|
|
|
query = f'UPDATE {table} {qmarks}'
|
2022-03-14 22:35:51 +00:00
|
|
|
return self.execute(query, bindings)
|
2021-09-18 01:24:03 +00:00
|
|
|
|
2021-10-09 19:12:18 +00:00
|
|
|
class DatabaseWithCaching(Database, metaclass=abc.ABCMeta):
|
2021-09-18 01:24:03 +00:00
|
|
|
    def __init__(self):
        super().__init__()
        # {object_class: cache}; populated by the subclass (see _init_caches).
        self.caches = {}
|
|
|
|
|
2021-10-09 19:12:18 +00:00
|
|
|
    def _init_caches(self):
        '''
        Your subclass needs to set self.caches, which is a dictionary of
        {object: cache} where object is one of your data object types
        (use the class itself as the key) and cache is a dictionary or
        cacheclass.Cache or anything that supports subscripting.

        If any types are omitted from this dictionary, objects of those
        types will not be cached.
        '''
        # NOTE(review): unlike _init_column_index / _init_sql on Database,
        # this is not decorated with @abc.abstractmethod — confirm whether
        # that is intentional.
        raise NotImplementedError
|
|
|
|
|
2021-09-18 01:24:03 +00:00
|
|
|
def clear_all_caches(self) -> None:
|
|
|
|
for cache in self.caches:
|
|
|
|
cache.clear()
|
|
|
|
|
|
|
|
    def get_cached_instance(self, object_class, db_row):
        '''
        Check if there is already an instance in the cache and return that.
        Otherwise, a new instance is created, cached, and returned.

        Note that in order to call this method you have to already have a
        db_row which means performing some select. If you only have the ID,
        use get_object_by_id, as there may already be a cached instance to save
        you the select.
        '''
        object_table = object_class.table
        object_cache = self.caches.get(object_class, None)

        # The row may be mapping-like or tuple-like depending on the
        # connection's row_factory; find the id either way.
        if isinstance(db_row, (dict, sqlite3.Row)):
            object_id = db_row['id']
        else:
            object_index = self.COLUMN_INDEX[object_table]
            object_id = db_row[object_index['id']]

        if object_cache is None:
            # This class is not cached; always construct fresh.
            return object_class(self, db_row)

        try:
            instance = object_cache[object_id]
        except KeyError:
            log.loud('Cache miss %s %s.', object_class, object_id)
            instance = object_class(self, db_row)
            object_cache[object_id] = instance
        return instance
|
|
|
|
|
|
|
|
    def get_object_by_id(self, object_class, object_id):
        '''
        This method will first check the cache to see if there is already an
        instance with that ID, in which case we don't need to perform any SQL
        select. If it is not in the cache, then a new instance is created,
        cached, and returned.

        Raises the class's no_such_exception if the ID is invalid or absent.
        '''
        if isinstance(object_id, object_class):
            # This could be used to check if your old reference to an object is
            # still in the cache, or re-select it from the db to make sure it
            # still exists and re-cache.
            # Probably an uncommon need but... no harm I think.
            object_id = object_id.id

        object_id = self.normalize_object_id(object_class, object_id)
        object_cache = self.caches.get(object_class, None)

        if object_cache is not None:
            try:
                return object_cache[object_id]
            except KeyError:
                pass

        query = f'SELECT * FROM {object_class.table} WHERE id == ?'
        bindings = [object_id]
        object_row = self.select_one(query, bindings)
        if object_row is None:
            raise object_class.no_such_exception(object_id)

        # Normally we would call `get_cached_instance` instead of
        # constructing here. But we already know for a fact that this
        # object is not in the cache.
        instance = object_class(self, object_row)

        if object_cache is not None:
            object_cache[instance.id] = instance

        return instance
|
|
|
|
|
|
|
|
def get_objects(self, object_class):
|
|
|
|
'''
|
|
|
|
Yield objects, unfiltered, in whatever order they appear in the database.
|
|
|
|
'''
|
|
|
|
table = object_class.table
|
|
|
|
query = f'SELECT * FROM {table}'
|
|
|
|
|
|
|
|
objects = self.select(query)
|
|
|
|
for object_row in objects:
|
|
|
|
instance = self.get_cached_instance(object_class, object_row)
|
|
|
|
yield instance
|
|
|
|
|
2021-11-11 07:01:55 +00:00
|
|
|
    def get_objects_by_id(self, object_class, object_ids, *, raise_for_missing=False):
        '''
        Given multiple IDs, this method will find which ones are in the cache
        and which ones need to be selected from the db.

        This is better than calling get_object_by_id in a loop because we can
        use a single SQL select to get batches of up to 999 items.

        Note: The order of the output will most likely not match the order of
        the input, because we first pull items from the cache before requesting
        the rest from the database.

        raise_for_missing:
            If any of the requested object ids are not found in the database,
            we can raise that class's no_such_exception with the set of missing
            IDs.
        '''
        object_cache = self.caches.get(object_class, None)

        # IDs that fail type normalization are certainly missing.
        (object_ids, missing) = self.normalize_object_ids(object_ids)
        ids_needed = set()
        ids_found = set()

        # First pass: satisfy whatever we can from the cache.
        if object_cache is None:
            ids_needed.update(object_ids)
        else:
            for object_id in object_ids:
                try:
                    instance = object_cache[object_id]
                except KeyError:
                    ids_needed.add(object_id)
                else:
                    ids_found.add(object_id)
                    yield instance

        if not ids_needed:
            return

        if object_cache is not None:
            log.loud('Cache miss %s %s.', object_class.table, ids_needed)

        # Second pass: select the remainder from the database.
        ids_needed = list(ids_needed)
        while ids_needed:
            # SQLite3 has a limit of 999 ? in a query, so we must batch them.
            id_batch = ids_needed[:999]
            ids_needed = ids_needed[999:]

            qmarks = ','.join('?' * len(id_batch))
            qmarks = f'({qmarks})'
            query = f'SELECT * FROM {object_class.table} WHERE id IN {qmarks}'
            for object_row in self.select(query, id_batch):
                # Normally we would call `get_cached_instance` instead of
                # constructing here. But we already know for a fact that this
                # object is not in the cache because it made it past the
                # previous loop.
                instance = object_class(self, db_row=object_row)
                if object_cache is not None:
                    object_cache[instance.id] = instance
                ids_found.add(instance.id)
                yield instance

        if raise_for_missing:
            missing.update(object_ids.difference(ids_found))
            if missing:
                raise object_class.no_such_exception(missing)
|
|
|
|
|
2021-09-18 01:24:03 +00:00
|
|
|
def get_objects_by_sql(self, object_class, query, bindings=None):
|
|
|
|
'''
|
|
|
|
Use an arbitrary SQL query to select objects from the database.
|
|
|
|
Your query should select * from the object's table.
|
|
|
|
'''
|
|
|
|
object_rows = self.select(query, bindings)
|
|
|
|
for object_row in object_rows:
|
|
|
|
yield self.get_cached_instance(object_class, object_row)
|
|
|
|
|
2021-10-09 19:12:18 +00:00
|
|
|
class Object(metaclass=abc.ABCMeta):
|
2021-09-18 01:24:03 +00:00
|
|
|
'''
|
|
|
|
When your objects subclass this class, you need to ensure the following:
|
|
|
|
|
|
|
|
- self.table should be a string.
|
|
|
|
- self.no_such_exception should be an exception class, to be raised when
|
|
|
|
the user requests an instance of this class that does not exist.
|
|
|
|
Initialized with a single argument, the requested ID.
|
|
|
|
'''
|
|
|
|
    def __init__(self, database):
        '''
        Your subclass should call super().__init__(database).
        '''
        # Used for transaction
        self._worms_database = database
        # Flipped to True when the row disappears from the database; checked
        # by assert_not_deleted.
        self.deleted = False
|
2021-09-18 01:24:03 +00:00
|
|
|
|
|
|
|
def __reinit__(self):
|
|
|
|
'''
|
|
|
|
Reload the row from the database and do __init__ with it.
|
|
|
|
'''
|
|
|
|
query = f'SELECT * FROM {self.table} WHERE id == ?'
|
|
|
|
bindings = [self.id]
|
|
|
|
row = self._worms_database.select_one(query, bindings)
|
|
|
|
if row is None:
|
|
|
|
self.deleted = True
|
|
|
|
else:
|
|
|
|
self.__init__(self._worms_database, row)
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
return (
|
|
|
|
isinstance(other, type(self)) and
|
|
|
|
self._worms_database == other._worms_database and
|
|
|
|
self.id == other.id
|
|
|
|
)
|
|
|
|
|
|
|
|
def __format__(self, formcode):
|
|
|
|
if formcode == 'r':
|
|
|
|
return repr(self)
|
|
|
|
else:
|
|
|
|
return str(self)
|
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
return hash(f'{self.table}.{self.id}')
|
2021-10-02 19:45:17 +00:00
|
|
|
|
|
|
|
def assert_not_deleted(self) -> None:
|
|
|
|
'''
|
|
|
|
Raises DeletedObject if this object is deleted.
|
2022-03-15 20:37:28 +00:00
|
|
|
|
|
|
|
You need to set self.deleted during any method that deletes the object
|
|
|
|
from the database.
|
2021-10-02 19:45:17 +00:00
|
|
|
'''
|
|
|
|
if self.deleted:
|
|
|
|
raise DeletedObject(self)
|