Add pathclass.Extension class to enrich extension operations.

I often found myself writing code like if a.extension == 'png' and
trying to remember if I'm supposed to compare against 'png' or '.png',
and then it would trip up on files like A.PNG because I forgot to
lower() it.

So this class handles all that for you. You can == against it and it
will use os.path.normcase to give you OS-appropriate case sens,
and == works whether you include the dot or not. Then you can use
ext.with_dot or ext.no_dot to get reliably dotted strings.
master
Ethan Dalool 2020-01-27 21:33:09 -08:00
parent b13f93c006
commit b88450b567
1 changed files with 44 additions and 9 deletions

View File

@ -15,6 +15,44 @@ class NotFile(PathclassException):
pass
class Extension:
def __init__(self, ext):
if isinstance(ext, Extension):
ext = ext.ext
ext = self.prep(ext)
self.ext = ext
@staticmethod
def prep(ext):
return os.path.normcase(ext).lstrip('.')
def __bool__(self):
return bool(self.ext)
def __eq__(self, other):
other = self.prep(other)
return self.ext == other
def __hash__(self):
return hash(self.ext)
def __repr__(self):
return f'Extension({repr(self.ext)})'
def __str__(self):
return self.ext
@property
def no_dot(self):
return self.ext
@property
def with_dot(self):
if self.ext == '':
return ''
return '.' + self.ext
class Path:
'''
I started to use pathlib.Path, but it was too much of a pain.
@ -65,10 +103,10 @@ class Path:
raise NotDirectory(self)
def add_extension(self, extension):
extension = extension.strip('.')
extension = Extension(extension)
if extension == '':
return self
return self.parent.with_child(self.basename + '.' + extension)
return self.parent.with_child(self.basename + extension.with_dot)
@property
def basename(self):
@ -85,10 +123,7 @@ class Path:
@property
def dot_extension(self):
extension = self.extension
if extension:
return '.' + extension
return ''
return self.extension.with_dot
@property
def exists(self):
@ -96,7 +131,7 @@ class Path:
@property
def extension(self):
return os.path.splitext(self.absolute_path)[1].lstrip('.')
return Extension(os.path.splitext(self.absolute_path)[1])
@property
def is_dir(self):
@ -170,13 +205,13 @@ class Path:
return relative_path
def replace_extension(self, extension):
extension = extension.rsplit('.', 1)[-1]
extension = Extension(extension)
base = os.path.splitext(self.basename)[0]
if extension == '':
return self.parent.with_child(base)
return self.parent.with_child(base + '.' + extension)
return self.parent.with_child(base + extension.with_dot)
@property
def size(self):