From 9e507fd61baa8013ce2c8a05d9c344893d725c10 Mon Sep 17 00:00:00 2001 From: Ethan Dalool Date: Sun, 7 Nov 2021 19:31:55 -0800 Subject: [PATCH] Significantly speed up bdecode by always indexing into original data. --- voussoirkit/bencode.py | 100 ++++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 47 deletions(-) diff --git a/voussoirkit/bencode.py b/voussoirkit/bencode.py index 0202de2..2c07eb8 100644 --- a/voussoirkit/bencode.py +++ b/voussoirkit/bencode.py @@ -7,6 +7,9 @@ Bencode data. https://en.wikipedia.org/wiki/Bencode ''' +# PUBLIC +#################################################################################################### + def bencode(data) -> bytes: ''' Encode python types to bencode. @@ -23,17 +26,19 @@ def bencode(data) -> bytes: encoder = encoders.get(data_type, None) if encoder is None: - raise ValueError(f'Invalid data type {data_type}.') + raise TypeError(f'Invalid data type {data_type}.') + return encoder(data) def bdecode(data): ''' Decode bencode to python types. ''' - return _decode(data)['result'] + return _decode(data, start_index=0)['result'] # INTERNALS ################################################################################ + def _encode_bytes(data): ''' Binary data is encoded as {length}:{bytes}. @@ -47,8 +52,7 @@ def _encode_dict(data): Keys must be byte strings ''' result = [] - keys = list(data.keys()) - keys.sort() + keys = sorted(data.keys()) for key in keys: result.append(bencode(key)) result.append(bencode(data[key])) @@ -71,86 +75,88 @@ def _encode_list(data): result = b''.join(result) return b'l%se' % result -def _decode(data): +def _decode(data, *, start_index): if not isinstance(data, bytes): raise TypeError(f'bencode data should be bytes, not {type(data)}.') - identifier = data[0:1] + identifier = data[start_index:start_index+1] if identifier == b'i': - ret = _decode_int(data) + ret = _decode_int(data, start_index=start_index) elif identifier.isdigit(): - ret = _decode_bytes(data) + ret = _decode_bytes(data, start_index=start_index) elif identifier == b'l': - ret = _decode_list(data) + ret = _decode_list(data, start_index=start_index) elif identifier == b'd': - ret = _decode_dict(data) + ret = _decode_dict(data, start_index=start_index) else: raise ValueError(f'Invalid initial delimiter "{identifier}".') return ret -def _decode_bytes(data): - colon = data.find(b':') +def _decode_bytes(data, *, start_index): + colon = data.find(b':', start_index) + if colon == -1: + raise ValueError('Missing bytes delimiter ":"') start = colon + 1 - size = int(data[:colon]) - end = start + size + length = int(data[start_index:colon]) + end = start + length text = data[start:end] - remainder = data[end:] - return {'result': text, 'remainder': remainder} + return {'result': text, 'remainder_index': end} -def _decode_dict(data): +def _decode_dict(data, *, start_index): result = {} - # slice leading d - remainder = data[1:] + # +1 to skip the leading d. + start_index += 1 - # Checking [0:1] instead of [0] because [0] returns an int!!!! - # [0:1] returns b'e' which I want. - while remainder[0:1] != b'e': - temp = _decode(remainder) + # We need to check a slice of length 1 because subscripting into bytes + # returns ints. + while data[start_index:start_index+1] != b'e': + temp = _decode(data, start_index=start_index) key = temp['result'] - remainder = temp['remainder'] + start_index = temp['remainder_index'] - temp = _decode(remainder) + temp = _decode(data, start_index=start_index) value = temp['result'] - remainder = temp['remainder'] + start_index = temp['remainder_index'] + result[key] = value - # slice ending e - remainder = remainder[1:] - return {'result': result, 'remainder': remainder} + # +1 to skip the trailing e. + return {'result': result, 'remainder_index': start_index+1} -def _decode_int(data): - # slide leading i - data = data[1:] +def _decode_int(data, *, start_index): + # +1 to skip the leading i. + start_index += 1 - end = data.find(b'e') + end = data.find(b'e', start_index) if end == -1: raise ValueError('Missing end delimiter "e"') - result = int(data[:end]) - # slice ending e - remainder = data[end+1:] - return {'result': result, 'remainder': remainder} + result = int(data[start_index:end]) + + # +1 to skip the trailing e. + return {'result': result, 'remainder_index': end+1} + +def _decode_list(data, *, start_index): + # +1 to skip the leading l. + start_index += 1 -def _decode_list(data): result = [] - # slice leading l - remainder = data[1:] - - while remainder[0:1] != b'e': - item = _decode(remainder) + # We need to check a slice of length 1 because subscripting into bytes + # returns ints. + while data[start_index:start_index+1] != b'e': + item = _decode(data, start_index=start_index) result.append(item['result']) - remainder = item['remainder'] + start_index = item['remainder_index'] - # slice ending e - remainder = remainder[1:] - return {'result': result, 'remainder': remainder} + # +1 to skip the trailing e. + return {'result': result, 'remainder_index': start_index+1}