Simplify and improve rollback logic.

This commit is contained in:
voussoir 2018-07-29 16:28:57 -07:00
parent 30d96139c2
commit 30e3aa9c6f
2 changed files with 27 additions and 23 deletions

View file

@ -408,6 +408,10 @@ def seconds_to_hms(seconds):
hms = ':'.join(f'{part:02d}' for part in parts) hms = ':'.join(f'{part:02d}' for part in parts)
return hms return hms
def slice_before(li, item):
index = li.index(item)
return li[:index]
def split_easybake_string(ebstring): def split_easybake_string(ebstring):
''' '''
Given an easybake string, return (tagname, synonym, rename_to), where Given an easybake string, return (tagname, synonym, rename_to), where

View file

@ -693,32 +693,32 @@ class PDBSQLMixin:
self.sql.commit() self.sql.commit()
def rollback(self, savepoint=None): def rollback(self, savepoint=None):
if savepoint is not None: '''
valid_savepoint = savepoint in self.savepoints Given a savepoint, roll the database back to the moment before that
savepoint was created. Keep in mind that a @transaction savepoint is
always created *before* the method actually does anything.
If no savepoint is provided then just roll back the most recent save.
'''
if savepoint is None:
try:
savepoint = self.savepoints.pop(-1)
except IndexError:
self.log.debug('Nothing to roll back.')
return
else: else:
valid_savepoint = None try:
# Will reassign after everything goes well.
_savepoints = helpers.slice_before(self.savepoints, savepoint)
except ValueError:
self.log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
return
if valid_savepoint is False: self.log.debug('Rolling back to %s', savepoint)
self.log.warn('Tried to restore to a nonexistent savepoint. Did you commit too early?') query = 'ROLLBACK TO "%s"' % savepoint
if len(self.savepoints) == 0:
self.log.debug('Nothing to rollback.')
return
if valid_savepoint:
restore_to = savepoint
while self.savepoints.pop(-1) != restore_to:
pass
else:
restore_to = self.savepoints.pop(-1)
self.log.debug('Rolling back to %s', restore_to)
query = 'ROLLBACK TO "%s"' % restore_to
self.sql_execute(query) self.sql_execute(query)
while len(self.on_commit_queue) > 0: self.savepoints = _savepoints
item = self.on_commit_queue.pop(-1) self.on_commit_queue = helpers.slice_before(self.on_commit_queue, savepoint)
if item == restore_to:
break
def savepoint(self, message=None): def savepoint(self, message=None):
savepoint_id = helpers.random_hex(length=16) savepoint_id = helpers.random_hex(length=16)