diff --git a/timesearch/commentaugment.py b/timesearch/commentaugment.py index 1514938..16ca05a 100644 --- a/timesearch/commentaugment.py +++ b/timesearch/commentaugment.py @@ -29,12 +29,10 @@ def commentaugment( subreddit = specific_submission_obj.subreddit.display_name if subreddit: - if specific_submission is None: - database = tsdb.TSDB.for_subreddit(subreddit, do_create=False) - else: - database = tsdb.TSDB.for_subreddit(subreddit, do_create=True) + do_create = specific_submission is not None + (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, do_create=do_create, fix_name=True) else: - database = tsdb.TSDB.for_user(username, do_create=False) + (database, username) = tsdb.TSDB.for_user(username, do_create=False, fix_name=True) cur = database.sql.cursor() if limit == 0: diff --git a/timesearch/getstyles.py b/timesearch/getstyles.py index 5d8dca9..242f9ea 100644 --- a/timesearch/getstyles.py +++ b/timesearch/getstyles.py @@ -6,11 +6,11 @@ from . import tsdb def getstyles(subreddit): + (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True) + print('Getting styles for /r/%s' % subreddit) subreddit = common.r.subreddit(subreddit) - styles = subreddit.stylesheet() - database = tsdb.TSDB.for_subreddit(subreddit.display_name) os.makedirs(database.styles_dir.absolute_path, exist_ok=True) diff --git a/timesearch/getwiki.py b/timesearch/getwiki.py index ee0d2ca..8352f10 100644 --- a/timesearch/getwiki.py +++ b/timesearch/getwiki.py @@ -5,9 +5,10 @@ from . import tsdb def getwiki(subreddit): + (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True) + print('Getting wiki pages for /r/%s' % subreddit) subreddit = common.r.subreddit(subreddit) - database = tsdb.TSDB.for_subreddit(subreddit) for wikipage in subreddit.wiki: if wikipage.name == 'config/stylesheet': diff --git a/timesearch/livestream.py b/timesearch/livestream.py index 8e8737d..e407d68 100644 --- a/timesearch/livestream.py +++ b/timesearch/livestream.py @@ -88,13 +88,13 @@ def _livestream_as_a_generator( if subreddit: common.log.debug('Getting subreddit %s', subreddit) - database = tsdb.TSDB.for_subreddit(subreddit) + (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True) subreddit = common.r.subreddit(subreddit) submission_function = subreddit.new if do_submissions else None comment_function = subreddit.comments if do_comments else None else: common.log.debug('Getting redditor %s', username) - database = tsdb.TSDB.for_user(username) + (database, username) = tsdb.TSDB.for_user(username, fix_name=True) user = common.r.redditor(username) submission_function = user.submissions.new if do_submissions else None comment_function = user.comments.new if do_comments else None @@ -162,5 +162,5 @@ def livestream_argparse(args): do_submissions=args.submissions, limit=limit, only_once=args.once, - sleepy=common.int_none(args.sleepy), + sleepy=int(args.sleepy), ) diff --git a/timesearch/timesearch.py b/timesearch/timesearch.py index bfbca27..23736bc 100644 --- a/timesearch/timesearch.py +++ b/timesearch/timesearch.py @@ -28,11 +28,11 @@ def timesearch( common.bot.login(common.r) if subreddit: - database = tsdb.TSDB.for_subreddit(subreddit) + (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True) else: # When searching, we'll take the user's submissions from anywhere. subreddit = 'all' - database = tsdb.TSDB.for_user(username) + (database, username) = tsdb.TSDB.for_user(username, fix_name=True) cur = database.sql.cursor() if lower == 'update': diff --git a/timesearch/tsdb.py b/timesearch/tsdb.py index 0fe0227..8ae8018 100644 --- a/timesearch/tsdb.py +++ b/timesearch/tsdb.py @@ -177,24 +177,44 @@ class TSDB: return pathclass.Path(path) @classmethod - def for_subreddit(cls, name, do_create=True): + def _for_object_helper(cls, name, path_formats, do_create=True, fix_name=False): + if name != os.path.basename(name): + filepath = pathclass.Path(name) + + else: + filepath = cls._pick_filepath(formats=path_formats, name=name) + + database = cls(filepath=filepath, do_create=do_create) + if fix_name: + return (database, name_from_path(name)) + return database + + @classmethod + def for_subreddit(cls, name, do_create=True, fix_name=False): if isinstance(name, common.praw.models.Subreddit): name = name.display_name elif not isinstance(name, str): raise TypeError(name, 'should be str or Subreddit.') - - filepath = cls._pick_filepath(formats=DB_FORMATS_SUBREDDIT, name=name) - return cls(filepath=filepath, do_create=do_create) + return cls._for_object_helper( + name, + do_create=do_create, + fix_name=fix_name, + path_formats=DB_FORMATS_SUBREDDIT, + ) @classmethod - def for_user(cls, name, do_create=True): + def for_user(cls, name, do_create=True, fix_name=False): if isinstance(name, common.praw.models.Redditor): name = name.name elif not isinstance(name, str): raise TypeError(name, 'should be str or Redditor.') - filepath = cls._pick_filepath(formats=DB_FORMATS_USER, name=name) - return cls(filepath=filepath, do_create=do_create) + return cls._for_object_helper( + name, + do_create=do_create, + fix_name=fix_name, + path_formats=DB_FORMATS_USER, + ) def insert(self, objects, commit=True): if not isinstance(objects, (list, tuple, types.GeneratorType)): @@ -368,3 +388,18 @@ def binding_filler(column_names, values, require_all=True): qmarks = ', '.join(qmarks) bindings = [values[column] for column in column_names] return (qmarks, bindings) + +def name_from_path(filepath): + ''' + In order to support usage like + > timesearch livestream -r D:\\some\\other\\filepath\\learnpython.db + this function extracts the subreddit name / username based on the given + path, so that we can pass it into `r.subreddit` / `r.redditor` properly. + ''' + if isinstance(filepath, pathclass.Path): + filepath = filepath.basename + else: + filepath = os.path.basename(filepath) + name = os.path.splitext(filepath)[0] + name = name.strip('@') + return name