Simplify and improve rollback logic.
This commit is contained in:
		
							parent
							
								
									30d96139c2
								
							
						
					
					
						commit
						30e3aa9c6f
					
				
					 2 changed files with 27 additions and 23 deletions
				
			
		|  | @ -408,6 +408,10 @@ def seconds_to_hms(seconds): | |||
|     hms = ':'.join(f'{part:02d}' for part in parts) | ||||
|     return hms | ||||
| 
 | ||||
| def slice_before(li, item): | ||||
|     index = li.index(item) | ||||
|     return li[:index] | ||||
| 
 | ||||
| def split_easybake_string(ebstring): | ||||
|     ''' | ||||
|     Given an easybake string, return (tagname, synonym, rename_to), where | ||||
|  |  | |||
|  | @ -693,32 +693,32 @@ class PDBSQLMixin: | |||
|         self.sql.commit() | ||||
| 
 | ||||
|     def rollback(self, savepoint=None): | ||||
|         if savepoint is not None: | ||||
|             valid_savepoint = savepoint in self.savepoints | ||||
|         ''' | ||||
|         Given a savepoint, roll the database back to the moment before that | ||||
|         savepoint was created. Keep in mind that a @transaction savepoint is | ||||
|         always created *before* the method actually does anything. | ||||
| 
 | ||||
|         If no savepoint is provided then just roll back the most recent save. | ||||
|         ''' | ||||
|         if savepoint is None: | ||||
|             try: | ||||
|                 savepoint = self.savepoints.pop(-1) | ||||
|             except IndexError: | ||||
|                 self.log.debug('Nothing to roll back.') | ||||
|                 return | ||||
|         else: | ||||
|             valid_savepoint = None | ||||
|             try: | ||||
|                 # Will reassign after everything goes well. | ||||
|                 _savepoints = helpers.slice_before(self.savepoints, savepoint) | ||||
|             except ValueError: | ||||
|                 self.log.warn('Tried to restore nonexistent savepoint %s.', savepoint) | ||||
|                 return | ||||
| 
 | ||||
|         if valid_savepoint is False: | ||||
|             self.log.warn('Tried to restore to a nonexistent savepoint. Did you commit too early?') | ||||
| 
 | ||||
|         if len(self.savepoints) == 0: | ||||
|             self.log.debug('Nothing to rollback.') | ||||
|             return | ||||
| 
 | ||||
|         if valid_savepoint: | ||||
|             restore_to = savepoint | ||||
|             while self.savepoints.pop(-1) != restore_to: | ||||
|                 pass | ||||
|         else: | ||||
|             restore_to = self.savepoints.pop(-1) | ||||
| 
 | ||||
|         self.log.debug('Rolling back to %s', restore_to) | ||||
|         query = 'ROLLBACK TO "%s"' % restore_to | ||||
|         self.log.debug('Rolling back to %s', savepoint) | ||||
|         query = 'ROLLBACK TO "%s"' % savepoint | ||||
|         self.sql_execute(query) | ||||
|         while len(self.on_commit_queue) > 0: | ||||
|             item = self.on_commit_queue.pop(-1) | ||||
|             if item == restore_to: | ||||
|                 break | ||||
|         self.savepoints = _savepoints | ||||
|         self.on_commit_queue = helpers.slice_before(self.on_commit_queue, savepoint) | ||||
| 
 | ||||
|     def savepoint(self, message=None): | ||||
|         savepoint_id = helpers.random_hex(length=16) | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue