Simplify and improve rollback logic.
This commit is contained in:
		
							parent
							
								
									30d96139c2
								
							
						
					
					
						commit
						30e3aa9c6f
					
				
					 2 changed files with 27 additions and 23 deletions
				
			
		|  | @ -408,6 +408,10 @@ def seconds_to_hms(seconds): | ||||||
|     hms = ':'.join(f'{part:02d}' for part in parts) |     hms = ':'.join(f'{part:02d}' for part in parts) | ||||||
|     return hms |     return hms | ||||||
| 
 | 
 | ||||||
|  | def slice_before(li, item): | ||||||
|  |     index = li.index(item) | ||||||
|  |     return li[:index] | ||||||
|  | 
 | ||||||
| def split_easybake_string(ebstring): | def split_easybake_string(ebstring): | ||||||
|     ''' |     ''' | ||||||
|     Given an easybake string, return (tagname, synonym, rename_to), where |     Given an easybake string, return (tagname, synonym, rename_to), where | ||||||
|  |  | ||||||
|  | @ -693,32 +693,32 @@ class PDBSQLMixin: | ||||||
|         self.sql.commit() |         self.sql.commit() | ||||||
| 
 | 
 | ||||||
|     def rollback(self, savepoint=None): |     def rollback(self, savepoint=None): | ||||||
|         if savepoint is not None: |         ''' | ||||||
|             valid_savepoint = savepoint in self.savepoints |         Given a savepoint, roll the database back to the moment before that | ||||||
|  |         savepoint was created. Keep in mind that a @transaction savepoint is | ||||||
|  |         always created *before* the method actually does anything. | ||||||
|  | 
 | ||||||
|  |         If no savepoint is provided then just roll back the most recent save. | ||||||
|  |         ''' | ||||||
|  |         if savepoint is None: | ||||||
|  |             try: | ||||||
|  |                 savepoint = self.savepoints.pop(-1) | ||||||
|  |             except IndexError: | ||||||
|  |                 self.log.debug('Nothing to roll back.') | ||||||
|  |                 return | ||||||
|         else: |         else: | ||||||
|             valid_savepoint = None |             try: | ||||||
|  |                 # Will reassign after everything goes well. | ||||||
|  |                 _savepoints = helpers.slice_before(self.savepoints, savepoint) | ||||||
|  |             except ValueError: | ||||||
|  |                 self.log.warn('Tried to restore nonexistent savepoint %s.', savepoint) | ||||||
|  |                 return | ||||||
| 
 | 
 | ||||||
|         if valid_savepoint is False: |         self.log.debug('Rolling back to %s', savepoint) | ||||||
|             self.log.warn('Tried to restore to a nonexistent savepoint. Did you commit too early?') |         query = 'ROLLBACK TO "%s"' % savepoint | ||||||
| 
 |  | ||||||
|         if len(self.savepoints) == 0: |  | ||||||
|             self.log.debug('Nothing to rollback.') |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         if valid_savepoint: |  | ||||||
|             restore_to = savepoint |  | ||||||
|             while self.savepoints.pop(-1) != restore_to: |  | ||||||
|                 pass |  | ||||||
|         else: |  | ||||||
|             restore_to = self.savepoints.pop(-1) |  | ||||||
| 
 |  | ||||||
|         self.log.debug('Rolling back to %s', restore_to) |  | ||||||
|         query = 'ROLLBACK TO "%s"' % restore_to |  | ||||||
|         self.sql_execute(query) |         self.sql_execute(query) | ||||||
|         while len(self.on_commit_queue) > 0: |         self.savepoints = _savepoints | ||||||
|             item = self.on_commit_queue.pop(-1) |         self.on_commit_queue = helpers.slice_before(self.on_commit_queue, savepoint) | ||||||
|             if item == restore_to: |  | ||||||
|                 break |  | ||||||
| 
 | 
 | ||||||
|     def savepoint(self, message=None): |     def savepoint(self, message=None): | ||||||
|         savepoint_id = helpers.random_hex(length=16) |         savepoint_id = helpers.random_hex(length=16) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue