Replace usage of strings with bytes. Don't attempt to decode.

This commit is contained in:
Ethan Dalool 2017-11-19 11:00:51 -08:00
parent 731e1247c1
commit b611be6f37

View file

@ -14,7 +14,6 @@ def bencode(data):
encoders = { encoders = {
bytes: encode_bytes, bytes: encode_bytes,
str: encode_string,
float: encode_float, float: encode_float,
int: encode_int, int: encode_int,
dict: encode_dict, dict: encode_dict,
@ -29,7 +28,7 @@ def bencode(data):
return encoder(data) return encoder(data)
def encode_bytes(data): def encode_bytes(data):
return '%d:%s' % (len(data), data) return b'%d:%s' % (len(data), data)
def encode_dict(data): def encode_dict(data):
result = [] result = []
@ -38,24 +37,21 @@ def encode_dict(data):
for key in keys: for key in keys:
result.append(bencode(key)) result.append(bencode(key))
result.append(bencode(data[key])) result.append(bencode(data[key]))
result = ''.join(result) result = b''.join(result)
return 'd%se' % result return b'd%se' % result
def encode_float(data): def encode_float(data):
return encode_string(str(data)) return encode_bytes(str(data).encode())
def encode_int(data): def encode_int(data):
return 'i%de' % data return b'i%de' % data
def encode_iterator(data): def encode_iterator(data):
result = [] result = []
for item in data: for item in data:
result.append(bencode(item)) result.append(bencode(item))
result = ''.join(result) result = b''.join(result)
return 'l%se' % result return b'l%se' % result
def encode_string(data):
return encode_bytes(data)
# ============================================================================= # =============================================================================
@ -74,70 +70,28 @@ def bdecode(data):
if data is None: if data is None:
return None return None
data = data.strip() #data = data.strip()
if isinstance(data, bytes): if isinstance(data, str):
data = data.decode('utf-8') data = data.encode('utf-8')
if data[0] == 'i': identifier = data[0:1]
if identifier == b'i':
return decode_int(data) return decode_int(data)
if data[0].isdigit(): if identifier.isdigit():
return decode_string(data) return decode_bytes(data)
if data[0] == 'l': if identifier == b'l':
return decode_list(data) return decode_list(data)
if data[0] == 'd': if identifier == b'd':
return decode_dict(data) return decode_dict(data)
raise ValueError('Invalid initial delimiter "%s"' % data[0]) raise ValueError('Invalid initial delimiter "%s"' % identifier)
def decode_dict(data): def decode_bytes(data):
result = {} #print('Decoding bytes from', data[:100])
start = data.find(b':') + 1
# slice leading d
remainder = data[1:]
while remainder[0] != 'e':
temp = bdecode(remainder)
key = temp['result']
remainder = temp['remainder']
temp = bdecode(remainder)
value = temp['result']
remainder = temp['remainder']
result[key] = value
# slice ending 3
remainder = remainder[1:]
return {'result': result, 'remainder': remainder}
def decode_int(data):
end = data.find('e')
if end == -1:
raise ValueError('Missing end delimiter "e"')
# slice leading i and closing e
result = data[1:end]
remainder = data[end+1:]
return {'result': result, 'remainder': remainder}
def decode_list(data):
result = []
# slice leading l
remainder = data[1:]
while remainder[0] != 'e':
item = bdecode(data)
result.append(item['result'])
reaminder = item['remainder']
# slice ending e
remainder = remainder[1:]
return {'result': result, 'remainder': remainder}
def decode_string(data):
start = data.find(':') + 1
size = int(data[:start-1]) size = int(data[:start-1])
end = start + size end = start + size
text = data[start:end] text = data[start:end]
@ -145,3 +99,49 @@ def decode_string(data):
raise ValueError('Actual length %d is less than declared length %d' % len(text), size) raise ValueError('Actual length %d is less than declared length %d' % len(text), size)
remainder = data[end:] remainder = data[end:]
return {'result': text, 'remainder': remainder} return {'result': text, 'remainder': remainder}
def decode_dict(data):
#print('Decoding dict from', data[:100])
result = {}
# slice leading d
remainder = data[1:]
while remainder[0:1] != b'e':
temp = bdecode(remainder)
key = temp['result']
remainder = temp['remainder']
temp = bdecode(remainder)
value = temp['result']
remainder = temp['remainder']
result[key] = value
# slice ending e
remainder = remainder[1:]
return {'result': result, 'remainder': remainder}
def decode_int(data):
#print('Decoding int from', data[:100])
end = data.find(b'e')
if end == -1:
raise ValueError('Missing end delimiter "e"')
# slice leading i and closing e
result = int(data[1:end])
remainder = data[end+1:]
return {'result': result, 'remainder': remainder}
def decode_list(data):
#print('Decoding list from', data[:100])
result = []
# slice leading l
remainder = data[1:]
while remainder[0:1] != b'e':
item = bdecode(remainder)
result.append(item['result'])
remainder = item['remainder']
# slice ending e
remainder = remainder[1:]
return {'result': result, 'remainder': remainder}