diff --git a/SimpleServer/simpleserver.py b/SimpleServer/simpleserver.py index bb57988..ff19712 100644 --- a/SimpleServer/simpleserver.py +++ b/SimpleServer/simpleserver.py @@ -59,8 +59,9 @@ TOKEN_COOKIE_NAME = 'simpleserver_token' # SERVER ########################################################################################### class RequestHandler(http.server.BaseHTTPRequestHandler): - def __init__(self, *args, password=None, accepted_tokens=None, **kwargs): + def __init__(self, *args, password=None, accepted_tokens=None, accepted_ips=None, **kwargs): self.accepted_tokens = accepted_tokens + self.accepted_ips = accepted_ips self.password = password super().__init__(*args, **kwargs) @@ -91,6 +92,10 @@ class RequestHandler(http.server.BaseHTTPRequestHandler): (username, password) = authorization.split(':', 1) return password + @property + def remote_addr(self): + return self.request.getpeername()[0] + def check_password(self, attempt): if self.password is None: return True @@ -110,6 +115,9 @@ class RequestHandler(http.server.BaseHTTPRequestHandler): if self.accepted_tokens is not None and self.auth_cookie in self.accepted_tokens: return True + if self.accepted_ips is not None and self.remote_addr in self.accepted_ips: + return True + return False def write(self, data): @@ -248,13 +256,17 @@ class RequestHandler(http.server.BaseHTTPRequestHandler): attempt = form.get(b'password')[0].decode('utf-8') goto = form.get(b'goto')[0].decode('utf-8') if self.check_password(attempt): - cookie = http.cookies.SimpleCookie() - token = random_hex(32) - cookie[TOKEN_COOKIE_NAME] = token - self.accepted_tokens.add(token) - self.send_response(302) - self.send_header('Set-Cookie', cookie.output(header='', sep='')) + + if self.accepted_tokens is not None: + cookie = http.cookies.SimpleCookie() + token = random_hex(32) + cookie[TOKEN_COOKIE_NAME] = token + self.accepted_tokens.add(token) + self.send_header('Set-Cookie', cookie.output(header='', sep='')) + if self.accepted_ips is not None: + self.accepted_ips.add(self.remote_addr) + self.send_header('Location', goto) else: self.send_response(401) @@ -273,15 +285,22 @@ class RequestHandler(http.server.BaseHTTPRequestHandler): return False class SimpleServer: - def __init__(self, port, password): + def __init__(self, port, password, authorize_by_ip): self.port = port self.password = password - self.accepted_tokens = set() + self.authorize_by_ip = authorize_by_ip + if authorize_by_ip: + self.accepted_ips = set() + self.accepted_tokens = None + else: + self.accepted_tokens = set() + self.accepted_ips = None def make_request_handler(self, *args, **kwargs): return RequestHandler( password=self.password, accepted_tokens=self.accepted_tokens, + accepted_ips=self.accepted_ips, *args, **kwargs, ) @@ -460,7 +479,11 @@ def zip_directory(path): # COMMAND LINE ################################################################################################### def simpleserver_argparse(args): - server = SimpleServer(port=args.port, password=args.password) + server = SimpleServer( + port=args.port, + password=args.password, + authorize_by_ip=args.authorize_by_ip, + ) server.start() def main(argv): @@ -468,6 +491,7 @@ def main(argv): parser.add_argument('port', nargs='?', type=int, default=40000) parser.add_argument('--password', dest='password', default=None) + parser.add_argument('--authorize_by_ip', '--authorize-by-ip', dest='authorize_by_ip', action='store_true') parser.set_defaults(func=simpleserver_argparse) args = parser.parse_args(argv)