Add argument raise_for_missing.

This commit is contained in:
voussoir 2021-11-10 23:01:55 -08:00
parent 4dddb21f74
commit 691b293939
No known key found for this signature in database
GPG key ID: 5F7554F8C26DACCB

View file

@ -193,20 +193,24 @@ class Database(metaclass=abc.ABCMeta):
instance = object_class(self, object_row) instance = object_class(self, object_row)
yield instance yield instance
def get_objects_by_id(self, object_class, object_ids): def get_objects_by_id(self, object_class, object_ids, *, raise_for_missing=False):
''' '''
Select many objects by their IDs. Select many objects by their IDs.
This is better than calling get_object_by_id in a loop because we can This is better than calling get_object_by_id in a loop because we can
use a single SQL select to get batches of up to 999 items. use a single SQL select to get batches of up to 999 items.
This method does not raise exceptions for IDs that are not present in Note: The order of the output will most likely not match the order of
the database. You will need to compare the set of returned objects with
the set of requested IDs.
Note: The order of the output is not guaranteed to match the order of
the input. Consider using get_objects_by_sql if that is a necessity. the input. Consider using get_objects_by_sql if that is a necessity.
raise_for_missing:
If any of the requested object ids are not found in the database,
we can raise that class's no_such_exception with the set of missing
IDs.
''' '''
ids_needed = list(set(object_ids)) object_ids = set(object_ids)
ids_needed = list(object_ids)
ids_found = set()
while ids_needed: while ids_needed:
# SQLite3 has a limit of 999 ? in a query, so we must batch them. # SQLite3 has a limit of 999 ? in a query, so we must batch them.
@ -218,8 +222,14 @@ class Database(metaclass=abc.ABCMeta):
query = f'SELECT * FROM {object_class.table} WHERE id IN {qmarks}' query = f'SELECT * FROM {object_class.table} WHERE id IN {qmarks}'
for object_row in self.select(query, id_batch): for object_row in self.select(query, id_batch):
instance = object_class(self, db_row=object_row) instance = object_class(self, db_row=object_row)
ids_found.add(instance.id)
yield instance yield instance
if raise_for_missing:
missing = object_ids.difference(ids_found)
if missing:
raise object_class.no_such_exception(missing)
def get_objects_by_sql(self, object_class, query, bindings=None): def get_objects_by_sql(self, object_class, query, bindings=None):
''' '''
Use an arbitrary SQL query to select objects from the database. Use an arbitrary SQL query to select objects from the database.
@ -459,20 +469,28 @@ class DatabaseWithCaching(Database, metaclass=abc.ABCMeta):
instance = self.get_cached_instance(object_class, object_row) instance = self.get_cached_instance(object_class, object_row)
yield instance yield instance
def get_objects_by_id(self, object_class, object_ids): def get_objects_by_id(self, object_class, object_ids, *, raise_for_missing=False):
''' '''
Given multiple IDs, this method will find which ones are in the cache Given multiple IDs, this method will find which ones are in the cache
and which ones need to be selected from the db. and which ones need to be selected from the db.
This is better than calling get_object_by_id in a loop because we can This is better than calling get_object_by_id in a loop because we can
use a single SQL select to get batches of up to 999 items. use a single SQL select to get batches of up to 999 items.
Note: The order of the output will most likely not match the order of Note: The order of the output will most likely not match the order of
the input, because we first pull items from the cache before requesting the input, because we first pull items from the cache before requesting
the rest from the database. the rest from the database.
raise_for_missing:
If any of the requested object ids are not found in the database,
we can raise that class's no_such_exception with the set of missing
IDs.
''' '''
object_cache = self.caches.get(object_class, None) object_cache = self.caches.get(object_class, None)
object_ids = set(object_ids)
ids_needed = set() ids_needed = set()
ids_found = set()
if object_cache is None: if object_cache is None:
ids_needed.update(object_ids) ids_needed.update(object_ids)
@ -483,6 +501,7 @@ class DatabaseWithCaching(Database, metaclass=abc.ABCMeta):
except KeyError: except KeyError:
ids_needed.add(object_id) ids_needed.add(object_id)
else: else:
ids_found.add(object_id)
yield instance yield instance
if not ids_needed: if not ids_needed:
@ -508,8 +527,14 @@ class DatabaseWithCaching(Database, metaclass=abc.ABCMeta):
instance = object_class(self, db_row=object_row) instance = object_class(self, db_row=object_row)
if object_cache is not None: if object_cache is not None:
object_cache[instance.id] = instance object_cache[instance.id] = instance
ids_found.add(instance.id)
yield instance yield instance
if raise_for_missing:
missing = object_ids.difference(ids_found)
if missing:
raise object_class.no_such_exception(missing)
def get_objects_by_sql(self, object_class, query, bindings=None): def get_objects_by_sql(self, object_class, query, bindings=None):
''' '''
Use an arbitrary SQL query to select objects from the database. Use an arbitrary SQL query to select objects from the database.