diff --git a/voussoirkit/worms.py b/voussoirkit/worms.py index 9e17013..ed633cc 100644 --- a/voussoirkit/worms.py +++ b/voussoirkit/worms.py @@ -98,6 +98,11 @@ class Database(metaclass=abc.ABCMeta): self.on_commit_queue = [] self.on_rollback_queue = [] self.savepoints = [] + # If your IDs are integers, you could change this to int. This way, when + # taking user input as strings, they will automatically be converted to + # int when querying and caching, and you don't have to do the conversion + # on the application side. + self.id_type = str self.last_commit_id = None @abc.abstractmethod @@ -194,6 +199,7 @@ class Database(metaclass=abc.ABCMeta): if isinstance(object_id, object_class): object_id = object_id.id + object_id = self.normalize_object_id(object_class, object_id) query = f'SELECT * FROM {object_class.table} WHERE id == ?' bindings = [object_id] object_row = self.select_one(query, bindings) @@ -231,7 +237,7 @@ class Database(metaclass=abc.ABCMeta): we can raise that class's no_such_exception with the set of missing IDs. ''' - object_ids = set(object_ids) + (object_ids, missing) = self.normalize_object_ids(object_ids) ids_needed = list(object_ids) ids_found = set() @@ -249,7 +255,7 @@ class Database(metaclass=abc.ABCMeta): yield instance if raise_for_missing: - missing = object_ids.difference(ids_found) + missing.update(object_ids.difference(ids_found)) if missing: raise object_class.no_such_exception(missing) @@ -280,6 +286,42 @@ class Database(metaclass=abc.ABCMeta): query = f'INSERT INTO {table} VALUES({qmarks})' return self.execute(query, bindings) + def normalize_object_id(self, object_class, object_id): + ''' + Given an object ID as input by the user, try to convert it using + self.id_type. If that raises a ValueError, then we raise + that class's no_such_exception. + + Just because an ID passes the type conversion does not mean that ID + actually exists. We can raise the no_such_exception because an invalid + ID certainly doesn't exist, but a valid one still might not exist. + ''' + try: + return self.id_type(object_id) + except ValueError: + raise object_class.no_such_exception(object_id) + + def normalize_object_ids(self, object_ids): + ''' + Given a list of object ids, return two sets: the first set contains all + the IDs that were able to be normalized using self.id_type; the second + contains all the IDs that raised ValueError. This method does not raise + the no_such_exception. as you may prefer to process the good instead of + losing it all with an exception. + + Just because an ID passes the type conversion does not mean that ID + actually exists. + ''' + good = set() + bad = set() + for object_id in object_ids: + try: + good.add(self.id_type(object_id)) + except ValueError: + bad.add(object_id) + + return (good, bad) + def release_savepoint(self, savepoint, allow_commit=False) -> None: ''' Releasing a savepoint removes that savepoint from the timeline, so that @@ -465,6 +507,7 @@ class DatabaseWithCaching(Database, metaclass=abc.ABCMeta): # Probably an uncommon need but... no harm I think. object_id = object_id.id + object_id = self.normalize_object_id(object_class, object_id) object_cache = self.caches.get(object_class, None) if object_cache is not None: @@ -520,7 +563,7 @@ class DatabaseWithCaching(Database, metaclass=abc.ABCMeta): ''' object_cache = self.caches.get(object_class, None) - object_ids = set(object_ids) + (object_ids, missing) = self.normalize_object_ids(object_ids) ids_needed = set() ids_found = set() @@ -563,7 +606,7 @@ class DatabaseWithCaching(Database, metaclass=abc.ABCMeta): yield instance if raise_for_missing: - missing = object_ids.difference(ids_found) + missing.update(object_ids.difference(ids_found)) if missing: raise object_class.no_such_exception(missing)