diff --git a/README.md b/README.md index 552f09b..ba3f973 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,6 @@ If you are interested in helping, please raise an issue before making any pull r - Currently, the Jinja templates are having a tangling influence on the backend objects, because Jinja cannot import my other modules like bytestring, but it can access the methods of the objects I pass into the template. As a result, the objects have excess helper methods. Consider making them into Jinja filters instead. Which is also kind of ugly but will move that pollution out of the backend at least. - Perhaps instead of actually deleting objects, they should just have a `deleted` flag, to make easy restoration possible. Also consider regrouping the children of restored Groupables if those children haven't already been reassigned somewhere else. - Add a new table to store permanent history of add/remove of tags on photos, so that accidents or trolling can be reversed. -- Improve transaction rollbacking. I'm not satisfied with the @transaction decorator because sometimes I want to use exceptions as control flow without them rolling things back. Context managers are good but it's a matter of how abstracted they should be. - Photo thumbnail paths should be relative to the data_dir, they are currently one level up. Or maybe should remove the paths entirely and just recalculate it by the ID. Can't think of any reason to have a thumbnail point elsewhere. - Fix album size cache when photo reload metadata and generally improve that validation. - Better bookmark url validation. diff --git a/etiquette/decorators.py b/etiquette/decorators.py index e264995..46ca925 100644 --- a/etiquette/decorators.py +++ b/etiquette/decorators.py @@ -5,24 +5,27 @@ import warnings from . import exceptions +def _get_relevant_photodb(instance): + from . import objects + if isinstance(instance, objects.ObjectBase): + photodb = instance.photodb + else: + photodb = instance + return photodb + def required_feature(features): ''' Declare that the photodb or object method requires certain 'enable_*' fields in the config. ''' - from . import objects if isinstance(features, str): features = [features] def wrapper(function): @functools.wraps(function) def wrapped(self, *args, **kwargs): - if isinstance(self, objects.ObjectBase): - config = self.photodb.config - else: - config = self.config - - config = config['enable_feature'] + photodb = _get_relevant_photodb(self) + config = photodb.config['enable_feature'] # Using the received string like "photo.new", try to navigate the # config and wind up at a True. @@ -62,13 +65,19 @@ def time_me(function): return timed_function def transaction(method): + ''' + Open a savepoint before running the method. + If the method fails, roll back to that savepoint. + ''' @functools.wraps(method) def wrapped(self, *args, **kwargs): + photodb = _get_relevant_photodb(self) + photodb.savepoint() try: - ret = method(self, *args, **kwargs) - return ret + result = method(self, *args, **kwargs) except Exception as e: - self.log.debug('Rolling back') - self.sql.rollback() + photodb.rollback() raise + else: + return result return wrapped diff --git a/etiquette/photodb.py b/etiquette/photodb.py index c537f3a..dfd700f 100644 --- a/etiquette/photodb.py +++ b/etiquette/photodb.py @@ -685,6 +685,7 @@ class PDBSQLMixin: def __init__(self): super().__init__() self.on_commit_queue = [] + self.savepoints = [] def close(self): self.sql.close() @@ -698,8 +699,34 @@ class PDBSQLMixin: args = task.get('args', []) kwargs = task.get('kwargs', {}) task['action'](*args, **kwargs) + self.savepoints.clear() self.sql.commit() + def rollback(self): + if len(self.savepoints) == 0: + self.log.debug('Nothing to rollback.') + return + + if len(self.savepoints) == 1: + self.log.debug('Final rollback.') + self.sql.rollback() + self.savepoints.clear() + return + + cur = self.sql.cursor() + restore_to = self.savepoints.pop(-1) + self.log.debug('Rolling back to %s', restore_to) + query = 'ROLLBACK TO "%s"' % restore_to + cur.execute(query) + + def savepoint(self): + savepoint_id = helpers.random_hex(length=16) + self.log.debug('Savepoint %s.', savepoint_id) + query = 'SAVEPOINT "%s"' % savepoint_id + self.sql.execute(query) + self.savepoints.append(savepoint_id) + return savepoint_id + def sql_delete(self, table, pairs, *, commit=False): cur = self.sql.cursor()