Add support for more special filenames like Windows's "con".

This commit is contained in:
Ethan Dalool 2020-09-25 15:43:36 -07:00
parent 677d8a251f
commit 3deb6270ba

View file

@ -6,8 +6,9 @@ import urllib
import warnings import warnings
from voussoirkit import bytestring from voussoirkit import bytestring
from voussoirkit import ratelimiter
from voussoirkit import clipext from voussoirkit import clipext
from voussoirkit import pathclass
from voussoirkit import ratelimiter
from voussoirkit import safeprint from voussoirkit import safeprint
warnings.simplefilter('ignore') warnings.simplefilter('ignore')
@ -25,6 +26,11 @@ TEMP_EXTENSION = '.downloadytemp'
PRINT_LIMITER = ratelimiter.Ratelimiter(allowance=5, mode='reject') PRINT_LIMITER = ratelimiter.Ratelimiter(allowance=5, mode='reject')
SPECIAL_FILENAMES = [os.devnull]
if os.name == 'nt':
SPECIAL_FILENAMES.append('con')
SPECIAL_FILENAMES = [os.path.normcase(x) for x in SPECIAL_FILENAMES]
class NotEnoughBytes(Exception): class NotEnoughBytes(Exception):
pass pass
@ -50,7 +56,8 @@ def download_file(
if os.path.isdir(localname): if os.path.isdir(localname):
localname = os.path.join(localname, basename_from_url(url)) localname = os.path.join(localname, basename_from_url(url))
localname = sanitize_filename(localname) localname = sanitize_filename(localname)
if localname != os.devnull:
if not is_special_file(localname):
localname = os.path.abspath(localname) localname = os.path.abspath(localname)
if verbose: if verbose:
@ -80,7 +87,10 @@ def download_plan(plan):
directory = os.path.split(localname)[0] directory = os.path.split(localname)[0]
if directory != '': if directory != '':
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
if not is_special_file(localname):
touch(localname) touch(localname)
if plan['plan_type'] in ['resume', 'partial']: if plan['plan_type'] in ['resume', 'partial']:
file_handle = open(localname, 'r+b') file_handle = open(localname, 'r+b')
file_handle.seek(plan['seek_to']) file_handle.seek(plan['seek_to'])
@ -124,8 +134,8 @@ def download_plan(plan):
file_handle.close() file_handle.close()
# Don't try to rename /dev/null # Don't try to rename /dev/null or other special names
if os.devnull not in [localname, plan['real_localname']]: if not is_special_file(localname) and not is_special_file(plan['real_localname']):
localsize = os.path.getsize(localname) localsize = os.path.getsize(localname)
undersized = plan['plan_type'] != 'partial' and localsize < plan['remote_total_bytes'] undersized = plan['plan_type'] != 'partial' and localsize < plan['remote_total_bytes']
if plan['raise_for_undersized'] and undersized: if plan['raise_for_undersized'] and undersized:
@ -155,12 +165,16 @@ def prepare_plan(
headers = headers or {} headers = headers or {}
user_provided_range = 'range' in headers user_provided_range = 'range' in headers
real_localname = localname real_localname = localname
if is_special_file(localname):
temp_localname = localname
else:
temp_localname = localname + TEMP_EXTENSION temp_localname = localname + TEMP_EXTENSION
real_exists = os.path.exists(real_localname) real_exists = os.path.exists(real_localname)
if real_exists and overwrite is False and not user_provided_range: if real_exists and overwrite is False and not user_provided_range:
print('File exists and overwrite is off. Nothing to do.') print('File exists and overwrite is off. Nothing to do.')
return None return None
temp_exists = os.path.exists(temp_localname) temp_exists = os.path.exists(temp_localname)
real_localsize = int(real_exists and os.path.getsize(real_localname)) real_localsize = int(real_exists and os.path.getsize(real_localname))
temp_localsize = int(temp_exists and os.path.getsize(temp_localname)) temp_localsize = int(temp_exists and os.path.getsize(temp_localname))
@ -349,6 +363,9 @@ def get_permission(prompt='y/n\n>', affirmative=['y', 'yes']):
permission = input(prompt) permission = input(prompt)
return permission.lower() in affirmative return permission.lower() in affirmative
def is_special_file(filename):
return os.path.normcase(filename) in SPECIAL_FILENAMES
def request(method, url, stream=False, headers=None, timeout=TIMEOUT, verify_ssl=True, **kwargs): def request(method, url, stream=False, headers=None, timeout=TIMEOUT, verify_ssl=True, **kwargs):
if headers is None: if headers is None:
headers = {} headers = {}