Significantly speed up bdecode by always indexing into original data.

This commit is contained in:
voussoir 2021-11-07 19:31:55 -08:00
parent 41ec819a35
commit 9e507fd61b
No known key found for this signature in database
GPG key ID: 5F7554F8C26DACCB

View file

@ -7,6 +7,9 @@ Bencode data.
https://en.wikipedia.org/wiki/Bencode https://en.wikipedia.org/wiki/Bencode
''' '''
# PUBLIC
####################################################################################################
def bencode(data) -> bytes: def bencode(data) -> bytes:
''' '''
Encode python types to bencode. Encode python types to bencode.
@ -23,17 +26,19 @@ def bencode(data) -> bytes:
encoder = encoders.get(data_type, None) encoder = encoders.get(data_type, None)
if encoder is None: if encoder is None:
raise ValueError(f'Invalid data type {data_type}.') raise TypeError(f'Invalid data type {data_type}.')
return encoder(data) return encoder(data)
def bdecode(data): def bdecode(data):
''' '''
Decode bencode to python types. Decode bencode to python types.
''' '''
return _decode(data)['result'] return _decode(data, start_index=0)['result']
# INTERNALS # INTERNALS
################################################################################ ################################################################################
def _encode_bytes(data): def _encode_bytes(data):
''' '''
Binary data is encoded as {length}:{bytes}. Binary data is encoded as {length}:{bytes}.
@ -47,8 +52,7 @@ def _encode_dict(data):
Keys must be byte strings Keys must be byte strings
''' '''
result = [] result = []
keys = list(data.keys()) keys = sorted(data.keys())
keys.sort()
for key in keys: for key in keys:
result.append(bencode(key)) result.append(bencode(key))
result.append(bencode(data[key])) result.append(bencode(data[key]))
@ -71,86 +75,88 @@ def _encode_list(data):
result = b''.join(result) result = b''.join(result)
return b'l%se' % result return b'l%se' % result
def _decode(data): def _decode(data, *, start_index):
if not isinstance(data, bytes): if not isinstance(data, bytes):
raise TypeError(f'bencode data should be bytes, not {type(data)}.') raise TypeError(f'bencode data should be bytes, not {type(data)}.')
identifier = data[0:1] identifier = data[start_index:start_index+1]
if identifier == b'i': if identifier == b'i':
ret = _decode_int(data) ret = _decode_int(data, start_index=start_index)
elif identifier.isdigit(): elif identifier.isdigit():
ret = _decode_bytes(data) ret = _decode_bytes(data, start_index=start_index)
elif identifier == b'l': elif identifier == b'l':
ret = _decode_list(data) ret = _decode_list(data, start_index=start_index)
elif identifier == b'd': elif identifier == b'd':
ret = _decode_dict(data) ret = _decode_dict(data, start_index=start_index)
else: else:
raise ValueError(f'Invalid initial delimiter "{identifier}".') raise ValueError(f'Invalid initial delimiter "{identifier}".')
return ret return ret
def _decode_bytes(data): def _decode_bytes(data, *, start_index):
colon = data.find(b':') colon = data.find(b':', start_index)
if colon == -1:
raise ValueError('Missing bytes delimiter ":"')
start = colon + 1 start = colon + 1
size = int(data[:colon]) length = int(data[start_index:colon])
end = start + size end = start + length
text = data[start:end] text = data[start:end]
remainder = data[end:]
return {'result': text, 'remainder': remainder} return {'result': text, 'remainder_index': end}
def _decode_dict(data): def _decode_dict(data, *, start_index):
result = {} result = {}
# slice leading d # +1 to skip the leading d.
remainder = data[1:] start_index += 1
# Checking [0:1] instead of [0] because [0] returns an int!!!! # We need to check a slice of length 1 because subscripting into bytes
# [0:1] returns b'e' which I want. # returns ints.
while remainder[0:1] != b'e': while data[start_index:start_index+1] != b'e':
temp = _decode(remainder) temp = _decode(data, start_index=start_index)
key = temp['result'] key = temp['result']
remainder = temp['remainder'] start_index = temp['remainder_index']
temp = _decode(remainder) temp = _decode(data, start_index=start_index)
value = temp['result'] value = temp['result']
remainder = temp['remainder'] start_index = temp['remainder_index']
result[key] = value result[key] = value
# slice ending e # +1 to skip the trailing e.
remainder = remainder[1:] return {'result': result, 'remainder_index': start_index+1}
return {'result': result, 'remainder': remainder}
def _decode_int(data): def _decode_int(data, *, start_index):
# slide leading i # +1 to skip the leading i.
data = data[1:] start_index += 1
end = data.find(b'e') end = data.find(b'e', start_index)
if end == -1: if end == -1:
raise ValueError('Missing end delimiter "e"') raise ValueError('Missing end delimiter "e"')
result = int(data[:end])
# slice ending e result = int(data[start_index:end])
remainder = data[end+1:]
return {'result': result, 'remainder': remainder} # +1 to skip the trailing e.
return {'result': result, 'remainder_index': end+1}
def _decode_list(data, *, start_index):
# +1 to skip the leading l.
start_index += 1
def _decode_list(data):
result = [] result = []
# slice leading l # We need to check a slice of length 1 because subscripting into bytes
remainder = data[1:] # returns ints.
while data[start_index:start_index+1] != b'e':
while remainder[0:1] != b'e': item = _decode(data, start_index=start_index)
item = _decode(remainder)
result.append(item['result']) result.append(item['result'])
remainder = item['remainder'] start_index = item['remainder_index']
# slice ending e # +1 to skip the trailing e.
remainder = remainder[1:] return {'result': result, 'remainder_index': start_index+1}
return {'result': result, 'remainder': remainder}