From 629594b6ae92474f2679b15c12458acdbf32581a Mon Sep 17 00:00:00 2001 From: Ethan Dalool Date: Sun, 20 Mar 2022 13:01:26 -0700 Subject: [PATCH] Add SSE functions. --- voussoirkit/flasktools.py | 56 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/voussoirkit/flasktools.py b/voussoirkit/flasktools.py index 3373e01..2ef7426 100644 --- a/voussoirkit/flasktools.py +++ b/voussoirkit/flasktools.py @@ -3,7 +3,9 @@ import functools import gzip import io import json +import queue import random +import threading import time import werkzeug.wrappers @@ -26,6 +28,9 @@ RESPONSE_TYPES = (flask.Response, werkzeug.wrappers.Response) NOT_CACHED = sentinel.Sentinel('not cached', truthyness=False) +SSE_LISTENERS = set() +SSE_LISTENERS_LOCK = threading.Lock() + def cached_endpoint(max_age, etag_function=None, max_urls=1000): ''' The cached_endpoint decorator can be used on slow endpoints that don't need @@ -330,3 +335,54 @@ def required_fields(fields, forbid_whitespace=False): return function(*args, **kwargs) return wrapped return wrapper + +def send_sse(*, event, data): + # This is not required by spec, but it is required for my sanity. + # I think every message should be describable by some event name. + if event is None: + raise TypeError(event) + + event = event.strip() + if not event: + raise ValueError(event) + + message = [f'event: {event}'] + + if data is None or data == '': + message.append('data: ') + else: + data = str(data) + data = '\n'.join(f'data: {line.strip()}' for line in data.splitlines()) + message.append(data) + + message = '\n'.join(message) + '\n\n' + message = message.encode('utf-8') + + with SSE_LISTENERS_LOCK: + for queue in SSE_LISTENERS: + queue.put(message) + +def send_sse_comment(comment): + message = f': {comment}\n\n' + message = message.encode('utf-8') + with SSE_LISTENERS_LOCK: + for queue in SSE_LISTENERS: + queue.put(message) + +def sse_generator(): + this_queue = queue.Queue() + with SSE_LISTENERS_LOCK: + SSE_LISTENERS.add(this_queue) + try: + log.debug('SSE listener has connected.') + yield ': welcome\n\n'.encode('utf-8') + while True: + try: + message = this_queue.get(timeout=60) + yield message + except queue.Empty: + pass + except GeneratorExit: + log.debug('SSE listener has disconnected.') + with SSE_LISTENERS_LOCK: + SSE_LISTENERS.remove(this_queue)