Support usage of filepaths as -r and -u arguments.

Add `tsdb.name_from_path` so that if the user provides a filepath
as the -r or -u option, we can load their database from that
path and also figure out the subreddit / username to fetch the
subreddit / user object properly, without having to add another
cmd arg for specifying a nonstandard path. New constructor argument
fix_name to perform this automatically and return a tuple of
(database, fixed name).
master
Ethan Dalool 2017-12-13 14:42:56 -08:00
parent 278923a000
commit c185ddaf77
6 changed files with 54 additions and 20 deletions

View File

@ -29,12 +29,10 @@ def commentaugment(
subreddit = specific_submission_obj.subreddit.display_name subreddit = specific_submission_obj.subreddit.display_name
if subreddit: if subreddit:
if specific_submission is None: do_create = specific_submission is not None
database = tsdb.TSDB.for_subreddit(subreddit, do_create=False) (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, do_create=do_create, fix_name=True)
else:
database = tsdb.TSDB.for_subreddit(subreddit, do_create=True)
else: else:
database = tsdb.TSDB.for_user(username, do_create=False) (database, username) = tsdb.TSDB.for_user(username, do_create=False, fix_name=True)
cur = database.sql.cursor() cur = database.sql.cursor()
if limit == 0: if limit == 0:

View File

@ -6,11 +6,11 @@ from . import tsdb
def getstyles(subreddit): def getstyles(subreddit):
(database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True)
print('Getting styles for /r/%s' % subreddit) print('Getting styles for /r/%s' % subreddit)
subreddit = common.r.subreddit(subreddit) subreddit = common.r.subreddit(subreddit)
styles = subreddit.stylesheet() styles = subreddit.stylesheet()
database = tsdb.TSDB.for_subreddit(subreddit.display_name)
os.makedirs(database.styles_dir.absolute_path, exist_ok=True) os.makedirs(database.styles_dir.absolute_path, exist_ok=True)

View File

@ -5,9 +5,10 @@ from . import tsdb
def getwiki(subreddit): def getwiki(subreddit):
(database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True)
print('Getting wiki pages for /r/%s' % subreddit) print('Getting wiki pages for /r/%s' % subreddit)
subreddit = common.r.subreddit(subreddit) subreddit = common.r.subreddit(subreddit)
database = tsdb.TSDB.for_subreddit(subreddit)
for wikipage in subreddit.wiki: for wikipage in subreddit.wiki:
if wikipage.name == 'config/stylesheet': if wikipage.name == 'config/stylesheet':

View File

@ -88,13 +88,13 @@ def _livestream_as_a_generator(
if subreddit: if subreddit:
common.log.debug('Getting subreddit %s', subreddit) common.log.debug('Getting subreddit %s', subreddit)
database = tsdb.TSDB.for_subreddit(subreddit) (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True)
subreddit = common.r.subreddit(subreddit) subreddit = common.r.subreddit(subreddit)
submission_function = subreddit.new if do_submissions else None submission_function = subreddit.new if do_submissions else None
comment_function = subreddit.comments if do_comments else None comment_function = subreddit.comments if do_comments else None
else: else:
common.log.debug('Getting redditor %s', username) common.log.debug('Getting redditor %s', username)
database = tsdb.TSDB.for_user(username) (database, username) = tsdb.TSDB.for_user(username, fix_name=True)
user = common.r.redditor(username) user = common.r.redditor(username)
submission_function = user.submissions.new if do_submissions else None submission_function = user.submissions.new if do_submissions else None
comment_function = user.comments.new if do_comments else None comment_function = user.comments.new if do_comments else None
@ -162,5 +162,5 @@ def livestream_argparse(args):
do_submissions=args.submissions, do_submissions=args.submissions,
limit=limit, limit=limit,
only_once=args.once, only_once=args.once,
sleepy=common.int_none(args.sleepy), sleepy=int(args.sleepy),
) )

View File

@ -28,11 +28,11 @@ def timesearch(
common.bot.login(common.r) common.bot.login(common.r)
if subreddit: if subreddit:
database = tsdb.TSDB.for_subreddit(subreddit) (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True)
else: else:
# When searching, we'll take the user's submissions from anywhere. # When searching, we'll take the user's submissions from anywhere.
subreddit = 'all' subreddit = 'all'
database = tsdb.TSDB.for_user(username) (database, username) = tsdb.TSDB.for_user(username, fix_name=True)
cur = database.sql.cursor() cur = database.sql.cursor()
if lower == 'update': if lower == 'update':

View File

@ -177,24 +177,44 @@ class TSDB:
return pathclass.Path(path) return pathclass.Path(path)
@classmethod @classmethod
def for_subreddit(cls, name, do_create=True): def _for_object_helper(cls, name, path_formats, do_create=True, fix_name=False):
if name != os.path.basename(name):
filepath = pathclass.Path(name)
else:
filepath = cls._pick_filepath(formats=path_formats, name=name)
database = cls(filepath=filepath, do_create=do_create)
if fix_name:
return (database, name_from_path(name))
return database
@classmethod
def for_subreddit(cls, name, do_create=True, fix_name=False):
if isinstance(name, common.praw.models.Subreddit): if isinstance(name, common.praw.models.Subreddit):
name = name.display_name name = name.display_name
elif not isinstance(name, str): elif not isinstance(name, str):
raise TypeError(name, 'should be str or Subreddit.') raise TypeError(name, 'should be str or Subreddit.')
return cls._for_object_helper(
filepath = cls._pick_filepath(formats=DB_FORMATS_SUBREDDIT, name=name) name,
return cls(filepath=filepath, do_create=do_create) do_create=do_create,
fix_name=fix_name,
path_formats=DB_FORMATS_SUBREDDIT,
)
@classmethod @classmethod
def for_user(cls, name, do_create=True): def for_user(cls, name, do_create=True, fix_name=False):
if isinstance(name, common.praw.models.Redditor): if isinstance(name, common.praw.models.Redditor):
name = name.name name = name.name
elif not isinstance(name, str): elif not isinstance(name, str):
raise TypeError(name, 'should be str or Redditor.') raise TypeError(name, 'should be str or Redditor.')
filepath = cls._pick_filepath(formats=DB_FORMATS_USER, name=name) return cls._for_object_helper(
return cls(filepath=filepath, do_create=do_create) name,
do_create=do_create,
fix_name=fix_name,
path_formats=DB_FORMATS_USER,
)
def insert(self, objects, commit=True): def insert(self, objects, commit=True):
if not isinstance(objects, (list, tuple, types.GeneratorType)): if not isinstance(objects, (list, tuple, types.GeneratorType)):
@ -368,3 +388,18 @@ def binding_filler(column_names, values, require_all=True):
qmarks = ', '.join(qmarks) qmarks = ', '.join(qmarks)
bindings = [values[column] for column in column_names] bindings = [values[column] for column in column_names]
return (qmarks, bindings) return (qmarks, bindings)
def name_from_path(filepath):
'''
In order to support usage like
> timesearch livestream -r D:\\some\\other\\filepath\\learnpython.db
this function extracts the subreddit name / username based on the given
path, so that we can pass it into `r.subreddit` / `r.redditor` properly.
'''
if isinstance(filepath, pathclass.Path):
filepath = filepath.basename
else:
filepath = os.path.basename(filepath)
name = os.path.splitext(filepath)[0]
name = name.strip('@')
return name