Simplify and improve rollback logic.
This commit is contained in:
parent
30d96139c2
commit
30e3aa9c6f
2 changed files with 27 additions and 23 deletions
|
@ -408,6 +408,10 @@ def seconds_to_hms(seconds):
|
||||||
hms = ':'.join(f'{part:02d}' for part in parts)
|
hms = ':'.join(f'{part:02d}' for part in parts)
|
||||||
return hms
|
return hms
|
||||||
|
|
||||||
|
def slice_before(li, item):
|
||||||
|
index = li.index(item)
|
||||||
|
return li[:index]
|
||||||
|
|
||||||
def split_easybake_string(ebstring):
|
def split_easybake_string(ebstring):
|
||||||
'''
|
'''
|
||||||
Given an easybake string, return (tagname, synonym, rename_to), where
|
Given an easybake string, return (tagname, synonym, rename_to), where
|
||||||
|
|
|
@ -693,32 +693,32 @@ class PDBSQLMixin:
|
||||||
self.sql.commit()
|
self.sql.commit()
|
||||||
|
|
||||||
def rollback(self, savepoint=None):
|
def rollback(self, savepoint=None):
|
||||||
if savepoint is not None:
|
'''
|
||||||
valid_savepoint = savepoint in self.savepoints
|
Given a savepoint, roll the database back to the moment before that
|
||||||
|
savepoint was created. Keep in mind that a @transaction savepoint is
|
||||||
|
always created *before* the method actually does anything.
|
||||||
|
|
||||||
|
If no savepoint is provided then just roll back the most recent save.
|
||||||
|
'''
|
||||||
|
if savepoint is None:
|
||||||
|
try:
|
||||||
|
savepoint = self.savepoints.pop(-1)
|
||||||
|
except IndexError:
|
||||||
|
self.log.debug('Nothing to roll back.')
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
valid_savepoint = None
|
try:
|
||||||
|
# Will reassign after everything goes well.
|
||||||
|
_savepoints = helpers.slice_before(self.savepoints, savepoint)
|
||||||
|
except ValueError:
|
||||||
|
self.log.warn('Tried to restore nonexistent savepoint %s.', savepoint)
|
||||||
|
return
|
||||||
|
|
||||||
if valid_savepoint is False:
|
self.log.debug('Rolling back to %s', savepoint)
|
||||||
self.log.warn('Tried to restore to a nonexistent savepoint. Did you commit too early?')
|
query = 'ROLLBACK TO "%s"' % savepoint
|
||||||
|
|
||||||
if len(self.savepoints) == 0:
|
|
||||||
self.log.debug('Nothing to rollback.')
|
|
||||||
return
|
|
||||||
|
|
||||||
if valid_savepoint:
|
|
||||||
restore_to = savepoint
|
|
||||||
while self.savepoints.pop(-1) != restore_to:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
restore_to = self.savepoints.pop(-1)
|
|
||||||
|
|
||||||
self.log.debug('Rolling back to %s', restore_to)
|
|
||||||
query = 'ROLLBACK TO "%s"' % restore_to
|
|
||||||
self.sql_execute(query)
|
self.sql_execute(query)
|
||||||
while len(self.on_commit_queue) > 0:
|
self.savepoints = _savepoints
|
||||||
item = self.on_commit_queue.pop(-1)
|
self.on_commit_queue = helpers.slice_before(self.on_commit_queue, savepoint)
|
||||||
if item == restore_to:
|
|
||||||
break
|
|
||||||
|
|
||||||
def savepoint(self, message=None):
|
def savepoint(self, message=None):
|
||||||
savepoint_id = helpers.random_hex(length=16)
|
savepoint_id = helpers.random_hex(length=16)
|
||||||
|
|
Loading…
Reference in a new issue