Preserve user login sessions in json file across runs.
This commit is contained in:
parent
f1d3319d36
commit
7f930c3bce
2 changed files with 100 additions and 17 deletions
|
|
@ -12,6 +12,7 @@ from voussoirkit import bytestring
|
|||
from voussoirkit import configlayers
|
||||
from voussoirkit import flasktools
|
||||
from voussoirkit import pathclass
|
||||
from voussoirkit import timetools
|
||||
from voussoirkit import vlogging
|
||||
|
||||
import etiquette
|
||||
|
|
@ -44,6 +45,7 @@ TEMPLATE_DIR = root_dir.with_child('templates')
|
|||
STATIC_DIR = root_dir.with_child('static')
|
||||
FAVICON_PATH = STATIC_DIR.with_child('favicon.png')
|
||||
SERVER_CONFIG_FILENAME = 'etiquette_flask_config.json'
|
||||
SESSIONS_STATE_FILENAME = 'etiquette_flask_sessions.json'
|
||||
|
||||
site = flask.Flask(
|
||||
__name__,
|
||||
|
|
@ -61,6 +63,7 @@ site.jinja_env.lstrip_blocks = True
|
|||
jinja_filters.register_all(site)
|
||||
site.localhost_only = False
|
||||
|
||||
# state_file will be set later
|
||||
session_manager = sessions.SessionManager(maxlen=10000)
|
||||
file_etag_manager = client_caching.FileEtagManager(
|
||||
maxlen=10000,
|
||||
|
|
@ -308,6 +311,7 @@ def init_photodb(*args, **kwargs):
|
|||
global P
|
||||
P = etiquette.photodb.PhotoDB.closest_photodb(*args, **kwargs)
|
||||
load_config()
|
||||
load_sessions()
|
||||
|
||||
def load_config() -> None:
|
||||
log.debug('Loading server config file.')
|
||||
|
|
@ -321,6 +325,30 @@ def load_config() -> None:
|
|||
if needs_rewrite:
|
||||
save_config()
|
||||
|
||||
def load_sessions():
|
||||
state_file = P.data_directory.with_child(SESSIONS_STATE_FILENAME)
|
||||
session_manager.state_file = state_file
|
||||
if not state_file.exists:
|
||||
return
|
||||
log.debug('Loading sessions from state file')
|
||||
j = json.loads(state_file.read('r'))
|
||||
for session in j:
|
||||
if session['userid'] is None:
|
||||
user = None
|
||||
else:
|
||||
try:
|
||||
user = P.get_user(id=session['userid'])
|
||||
except etiquette.exceptions.NoSuchUser:
|
||||
continue
|
||||
session = sessions.Session(
|
||||
session_manager=session_manager,
|
||||
user=user,
|
||||
token=session['token'],
|
||||
ip_address=session['ip_address'],
|
||||
user_agent=session['user_agent'],
|
||||
last_activity=timetools.fromtimestamp(session['last_activity']),
|
||||
)
|
||||
|
||||
def save_config() -> None:
|
||||
log.debug('Saving server config file.')
|
||||
config_file = P.data_directory.with_child(SERVER_CONFIG_FILENAME)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,26 @@
|
|||
import datetime
|
||||
import flask; from flask import request
|
||||
import functools
|
||||
import json
|
||||
import random
|
||||
import werkzeug.datastructures
|
||||
|
||||
from voussoirkit import cacheclass
|
||||
from voussoirkit import flasktools
|
||||
from voussoirkit import passwordy
|
||||
from voussoirkit import timetools
|
||||
from voussoirkit import vlogging
|
||||
|
||||
log = vlogging.getLogger(__name__, 'sessions')
|
||||
|
||||
import etiquette
|
||||
|
||||
SESSION_MAX_AGE = 86400
|
||||
RNG = random.SystemRandom()
|
||||
|
||||
def _generate_token(length=32):
|
||||
return passwordy.random_hex(length=length)
|
||||
SESSION_MAX_AGE = 86400
|
||||
SAVE_STATE_INTERVAL = 60
|
||||
|
||||
def _generate_token() -> str:
|
||||
return str(RNG.getrandbits(128))
|
||||
|
||||
def _normalize_token(token):
|
||||
if isinstance(token, flasktools.REQUEST_TYPES):
|
||||
|
|
@ -31,8 +39,10 @@ def _normalize_token(token):
|
|||
return token
|
||||
|
||||
class SessionManager:
|
||||
def __init__(self, maxlen=None):
|
||||
def __init__(self, maxlen=None, state_file=None):
|
||||
self.sessions = cacheclass.Cache(maxlen=maxlen)
|
||||
self.last_activity = timetools.now()
|
||||
self.last_save_state = timetools.now()
|
||||
|
||||
def _before_request(self, request):
|
||||
# Inject new token so the function doesn't know the difference
|
||||
|
|
@ -51,8 +61,7 @@ class SessionManager:
|
|||
try:
|
||||
session = self.get(request)
|
||||
except KeyError:
|
||||
session = Session(request, user=None)
|
||||
self.add(session)
|
||||
session = Session.from_request(session_manager=self, request=request, user=None)
|
||||
else:
|
||||
session.maintain()
|
||||
|
||||
|
|
@ -78,6 +87,7 @@ class SessionManager:
|
|||
return response
|
||||
|
||||
def add(self, session):
|
||||
session.session_manager = self
|
||||
self.sessions[session.token] = session
|
||||
|
||||
def clear(self):
|
||||
|
|
@ -86,11 +96,7 @@ class SessionManager:
|
|||
def get(self, request):
|
||||
token = _normalize_token(request)
|
||||
session = self.sessions[token]
|
||||
invalid = (
|
||||
request.remote_addr != session.ip_address or
|
||||
session.expired()
|
||||
)
|
||||
if invalid:
|
||||
if session.expired():
|
||||
self.remove(token)
|
||||
raise KeyError(token)
|
||||
return session
|
||||
|
|
@ -110,6 +116,13 @@ class SessionManager:
|
|||
|
||||
return wrapped
|
||||
|
||||
def maintain(self):
|
||||
now = timetools.now()
|
||||
self.last_activity = now
|
||||
state_age = now - self.last_save_state
|
||||
if state_age.seconds > SAVE_STATE_INTERVAL:
|
||||
self.save_state()
|
||||
|
||||
def remove(self, token):
|
||||
token = _normalize_token(token)
|
||||
try:
|
||||
|
|
@ -117,13 +130,45 @@ class SessionManager:
|
|||
except KeyError:
|
||||
pass
|
||||
|
||||
def save_state(self):
|
||||
log.debug('Saving sessions state.')
|
||||
j = [session.jsonify() for session in self.sessions.values() if not session.expired()]
|
||||
j = json.dumps(j)
|
||||
self.state_file.write('w', j)
|
||||
self.last_save_state = timetools.now()
|
||||
|
||||
class Session:
|
||||
def __init__(self, request, user):
|
||||
self.token = _normalize_token(request)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_manager,
|
||||
token,
|
||||
ip_address,
|
||||
user_agent,
|
||||
user,
|
||||
last_activity=None,
|
||||
):
|
||||
self.session_manager = session_manager
|
||||
self.token = token
|
||||
self.user = user
|
||||
self.ip_address = request.remote_addr
|
||||
self.user_agent = request.headers.get('User-Agent', '')
|
||||
self.last_activity = timetools.now()
|
||||
self.ip_address = ip_address
|
||||
self.user_agent = user_agent
|
||||
if last_activity is None:
|
||||
self.last_activity = timetools.now()
|
||||
else:
|
||||
self.last_activity = last_activity
|
||||
self.session_manager.add(self)
|
||||
self.session_manager.maintain()
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, *, session_manager, request, user):
|
||||
return cls(
|
||||
session_manager=session_manager,
|
||||
token=_normalize_token(request),
|
||||
user=user,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get('User-Agent', ''),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
if self.user:
|
||||
|
|
@ -136,5 +181,15 @@ class Session:
|
|||
age = now - self.last_activity
|
||||
return age.seconds > SESSION_MAX_AGE
|
||||
|
||||
def jsonify(self):
|
||||
return {
|
||||
'userid': (self.user.id) if self.user else None,
|
||||
'token': self.token,
|
||||
'ip_address': self.ip_address,
|
||||
'user_agent': self.user_agent,
|
||||
'last_activity': self.last_activity.timestamp(),
|
||||
}
|
||||
|
||||
def maintain(self):
|
||||
self.last_activity = timetools.now()
|
||||
self.session_manager.maintain()
|
||||
|
|
|
|||
Loading…
Reference in a new issue