Let SessionManager.get require Request object; Check IP addr.
So far there is no use case in which get needs to be called with something other than a Request, and I don't think there will be. So let's make that part of the design and we can also take the opportunity to check IP.
This commit is contained in:
		
							parent
							
								
									c6a396c658
								
							
						
					
					
						commit
						6f4530c88c
					
				
					 1 changed files with 18 additions and 6 deletions
				
			
		|  | @ -7,6 +7,7 @@ import werkzeug.wrappers | |||
| import etiquette | ||||
| 
 | ||||
| SESSION_MAX_AGE = 86400 | ||||
| REQUEST_TYPES = (flask.Request, werkzeug.wrappers.Request, werkzeug.local.LocalProxy) | ||||
| 
 | ||||
| def _generate_token(length=32): | ||||
|     randbytes = os.urandom(math.ceil(length / 2)) | ||||
|  | @ -15,10 +16,13 @@ def _generate_token(length=32): | |||
|     return token | ||||
| 
 | ||||
| def _normalize_token(token): | ||||
|     if isinstance(token, (flask.Request, werkzeug.wrappers.Request, werkzeug.local.LocalProxy)): | ||||
|     if isinstance(token, REQUEST_TYPES): | ||||
|         request = token | ||||
|         token = request.cookies.get('etiquette_session', None) | ||||
|         if token is None: | ||||
|             # During normal usage, this does not occur because give_token is | ||||
|             # applied *before* the request handler even sees the request. | ||||
|             # Just a precaution. | ||||
|             message = 'Cannot normalize token for request with no etiquette_session header.' | ||||
|             raise TypeError(message, request) | ||||
|     elif isinstance(token, str): | ||||
|  | @ -35,10 +39,15 @@ class SessionManager: | |||
|     def add(self, session): | ||||
|         self.sessions[session.token] = session | ||||
| 
 | ||||
|     def get(self, token): | ||||
|         token = _normalize_token(token) | ||||
|     def get(self, request): | ||||
|         token = _normalize_token(request) | ||||
|         session = self.sessions[token] | ||||
|         if session.expired(): | ||||
|         invalid = ( | ||||
|             request.remote_addr != session.ip_address or | ||||
|             session.expired() | ||||
|         ) | ||||
|         if invalid: | ||||
|             self.remove(token) | ||||
|             raise KeyError(token) | ||||
|         return session | ||||
| 
 | ||||
|  | @ -59,7 +68,7 @@ class SessionManager: | |||
|                 request.cookies['etiquette_session'] = token | ||||
| 
 | ||||
|             try: | ||||
|                 session = self.get(token) | ||||
|                 session = self.get(request) | ||||
|             except KeyError: | ||||
|                 session = Session(request, user=None) | ||||
|                 self.add(session) | ||||
|  | @ -85,8 +94,11 @@ class SessionManager: | |||
| 
 | ||||
|     def remove(self, token): | ||||
|         token = _normalize_token(token) | ||||
|         if token in self.sessions: | ||||
|         try: | ||||
|             self.sessions.pop(token) | ||||
|         except KeyError: | ||||
|             pass | ||||
| 
 | ||||
| 
 | ||||
| class Session: | ||||
|     def __init__(self, request, user): | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue