diff --git a/etiquette/helpers.py b/etiquette/helpers.py index 696868a..91a576b 100644 --- a/etiquette/helpers.py +++ b/etiquette/helpers.py @@ -408,6 +408,10 @@ def seconds_to_hms(seconds): hms = ':'.join(f'{part:02d}' for part in parts) return hms +def slice_before(li, item): + index = li.index(item) + return li[:index] + def split_easybake_string(ebstring): ''' Given an easybake string, return (tagname, synonym, rename_to), where diff --git a/etiquette/photodb.py b/etiquette/photodb.py index 2b094a7..ce78490 100644 --- a/etiquette/photodb.py +++ b/etiquette/photodb.py @@ -693,32 +693,32 @@ class PDBSQLMixin: self.sql.commit() def rollback(self, savepoint=None): - if savepoint is not None: - valid_savepoint = savepoint in self.savepoints + ''' + Given a savepoint, roll the database back to the moment before that + savepoint was created. Keep in mind that a @transaction savepoint is + always created *before* the method actually does anything. + + If no savepoint is provided then just roll back the most recent save. + ''' + if savepoint is None: + try: + savepoint = self.savepoints.pop(-1) + except IndexError: + self.log.debug('Nothing to roll back.') + return else: - valid_savepoint = None + try: + # Will reassign after everything goes well. + _savepoints = helpers.slice_before(self.savepoints, savepoint) + except ValueError: + self.log.warn('Tried to restore nonexistent savepoint %s.', savepoint) + return - if valid_savepoint is False: - self.log.warn('Tried to restore to a nonexistent savepoint. Did you commit too early?') - - if len(self.savepoints) == 0: - self.log.debug('Nothing to rollback.') - return - - if valid_savepoint: - restore_to = savepoint - while self.savepoints.pop(-1) != restore_to: - pass - else: - restore_to = self.savepoints.pop(-1) - - self.log.debug('Rolling back to %s', restore_to) - query = 'ROLLBACK TO "%s"' % restore_to + self.log.debug('Rolling back to %s', savepoint) + query = 'ROLLBACK TO "%s"' % savepoint self.sql_execute(query) - while len(self.on_commit_queue) > 0: - item = self.on_commit_queue.pop(-1) - if item == restore_to: - break + self.savepoints = _savepoints + self.on_commit_queue = helpers.slice_before(self.on_commit_queue, savepoint) def savepoint(self, message=None): savepoint_id = helpers.random_hex(length=16)