Better object-oriented SimpleServer.

master
Ethan Dalool 2020-10-10 08:44:56 -07:00
parent fa86f1b393
commit 6ce69da7a6
1 changed files with 37 additions and 28 deletions

View File

@ -55,10 +55,12 @@ PASSWORD_PROMPT_HTML = '''
ROOT_DIRECTORY = pathclass.Path(os.getcwd()) ROOT_DIRECTORY = pathclass.Path(os.getcwd())
TOKEN_COOKIE_NAME = 'simpleserver_token' TOKEN_COOKIE_NAME = 'simpleserver_token'
# SERVER ###########################################################################################
class RequestHandler(http.server.BaseHTTPRequestHandler): class RequestHandler(http.server.BaseHTTPRequestHandler):
def __init__(self, *args, passw=None, accepted_tokens=None, **kwargs): def __init__(self, *args, password=None, accepted_tokens=None, **kwargs):
self.accepted_tokens = accepted_tokens self.accepted_tokens = accepted_tokens
self.password = passw self.password = password
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def check_password(self, attempt): def check_password(self, attempt):
@ -209,7 +211,7 @@ class RequestHandler(http.server.BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
def do_POST(self): def do_POST(self):
ctype, pdict = cgi.parse_header(self.headers.get('content-type')) (ctype, pdict) = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data': if ctype == 'multipart/form-data':
form = cgi.parse_multipart(self.rfile, pdict) form = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded': elif ctype == 'application/x-www-form-urlencoded':
@ -232,6 +234,10 @@ class RequestHandler(http.server.BaseHTTPRequestHandler):
self.send_header('Location', goto) self.send_header('Location', goto)
else: else:
self.send_response(401) self.send_response(401)
elif not self.check_has_password():
self.send_response(401)
self.end_headers()
return
else: else:
self.send_response(400) self.send_response(400)
self.end_headers() self.end_headers()
@ -242,6 +248,31 @@ class RequestHandler(http.server.BaseHTTPRequestHandler):
return True return True
return False return False
class SimpleServer:
def __init__(self, port, password):
self.port = port
self.password = password
self.accepted_tokens = set()
def make_request_handler(self, *args, **kwargs):
return RequestHandler(
password=self.password,
accepted_tokens=self.accepted_tokens,
*args,
**kwargs,
)
def start(self):
server = http.server.ThreadingHTTPServer(('0.0.0.0', self.port), self.make_request_handler)
print(f'Server starting on {self.port}')
try:
server.serve_forever()
except KeyboardInterrupt:
print('Goodbye.')
server.shutdown()
# HELPERS ##########################################################################################
def allowed(path): def allowed(path):
return path == ROOT_DIRECTORY or path in ROOT_DIRECTORY return path == ROOT_DIRECTORY or path in ROOT_DIRECTORY
@ -402,33 +433,11 @@ def zip_directory(path):
return zipfile return zipfile
def RRR(password=None): # COMMAND LINE ###################################################################################################
accepted_tokens = set()
def R(*args, **kwargs):
handler = RequestHandler(passw=password, accepted_tokens=accepted_tokens, *args, **kwargs)
return handler
return R
def simpleserver(port, password=None):
server = http.server.ThreadingHTTPServer(('', port), RRR(password=password))
print(f'server starting on {port}')
try:
server.serve_forever()
except KeyboardInterrupt:
print('goodbye.')
t = threading.Thread(target=server.shutdown)
t.daemon = True
t.start()
server.shutdown()
print('really goodbye.')
return 0
def simpleserver_argparse(args): def simpleserver_argparse(args):
return simpleserver( server = SimpleServer(port=args.port, password=args.password)
port=args.port, server.start()
password=args.password,
)
def main(argv): def main(argv):
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)