Let SessionManager.get require Request object; Check IP addr.

So far there is no use case in which get needs to be called with
something other than a Request, and I don't think there will be.
So let's make that part of the design and we can also take the
opportunity to check IP.
This commit is contained in:
voussoir 2018-02-03 02:10:07 -08:00
parent c6a396c658
commit 6f4530c88c

View file

@ -7,6 +7,7 @@ import werkzeug.wrappers
import etiquette import etiquette
SESSION_MAX_AGE = 86400 SESSION_MAX_AGE = 86400
REQUEST_TYPES = (flask.Request, werkzeug.wrappers.Request, werkzeug.local.LocalProxy)
def _generate_token(length=32): def _generate_token(length=32):
randbytes = os.urandom(math.ceil(length / 2)) randbytes = os.urandom(math.ceil(length / 2))
@ -15,10 +16,13 @@ def _generate_token(length=32):
return token return token
def _normalize_token(token): def _normalize_token(token):
if isinstance(token, (flask.Request, werkzeug.wrappers.Request, werkzeug.local.LocalProxy)): if isinstance(token, REQUEST_TYPES):
request = token request = token
token = request.cookies.get('etiquette_session', None) token = request.cookies.get('etiquette_session', None)
if token is None: if token is None:
# During normal usage, this does not occur because give_token is
# applied *before* the request handler even sees the request.
# Just a precaution.
message = 'Cannot normalize token for request with no etiquette_session header.' message = 'Cannot normalize token for request with no etiquette_session header.'
raise TypeError(message, request) raise TypeError(message, request)
elif isinstance(token, str): elif isinstance(token, str):
@ -35,10 +39,15 @@ class SessionManager:
def add(self, session): def add(self, session):
self.sessions[session.token] = session self.sessions[session.token] = session
def get(self, token): def get(self, request):
token = _normalize_token(token) token = _normalize_token(request)
session = self.sessions[token] session = self.sessions[token]
if session.expired(): invalid = (
request.remote_addr != session.ip_address or
session.expired()
)
if invalid:
self.remove(token)
raise KeyError(token) raise KeyError(token)
return session return session
@ -59,7 +68,7 @@ class SessionManager:
request.cookies['etiquette_session'] = token request.cookies['etiquette_session'] = token
try: try:
session = self.get(token) session = self.get(request)
except KeyError: except KeyError:
session = Session(request, user=None) session = Session(request, user=None)
self.add(session) self.add(session)
@ -85,8 +94,11 @@ class SessionManager:
def remove(self, token): def remove(self, token):
token = _normalize_token(token) token = _normalize_token(token)
if token in self.sessions: try:
self.sessions.pop(token) self.sessions.pop(token)
except KeyError:
pass
class Session: class Session:
def __init__(self, request, user): def __init__(self, request, user):