diff --git a/ycdl/ycdldb.py b/ycdl/ycdldb.py index fe36827..ef6aa41 100644 --- a/ycdl/ycdldb.py +++ b/ycdl/ycdldb.py @@ -7,6 +7,7 @@ from . import exceptions from . import helpers from . import ytapi +from voussoirkit import pathclass from voussoirkit import sqlhelpers @@ -21,10 +22,18 @@ logging.getLogger('requests.packages.urllib3.connectionpool').setLevel(logging.W logging.getLogger('requests.packages.urllib3.util.retry').setLevel(logging.WARNING) DATABASE_VERSION = 4 -DB_INIT = ''' +DB_VERSION_PRAGMA = ''' +PRAGMA user_version = {user_version}; +''' +DB_PRAGMAS = ''' PRAGMA count_changes = OFF; PRAGMA cache_size = 10000; -PRAGMA user_version = {user_version}; +''' +DB_INIT = f''' +BEGIN; +---------------------------------------------------------------------------------------------------- +{DB_PRAGMAS} +{DB_VERSION_PRAGMA} CREATE TABLE IF NOT EXISTS channels( id TEXT, name TEXT, @@ -49,6 +58,8 @@ CREATE INDEX IF NOT EXISTS index_video_author_download on videos(author_id, down CREATE INDEX IF NOT EXISTS index_video_id on videos(id); CREATE INDEX IF NOT EXISTS index_video_published on videos(published); CREATE INDEX IF NOT EXISTS index_video_download on videos(download); +---------------------------------------------------------------------------------------------------- +COMMIT; '''.format(user_version=DATABASE_VERSION) SQL_CHANNEL_COLUMNS = [ @@ -84,29 +95,53 @@ def assert_is_abspath(path): class YCDLDB: - def __init__(self, youtube, database_filename=None, youtube_dl_function=None): + def __init__( + self, + youtube, + database_filename=None, + youtube_dl_function=None, + skip_version_check=False, + ): self.youtube = youtube if database_filename is None: database_filename = DEFAULT_DBNAME - existing_database = os.path.exists(database_filename) + self.database_filepath = pathclass.Path(database_filename) + existing_database = self.database_filepath.exists self.sql = sqlite3.connect(database_filename) self.cur = self.sql.cursor() if existing_database: - self.cur.execute('PRAGMA user_version') - existing_version = self.cur.fetchone()[0] - if existing_version != DATABASE_VERSION: - raise exceptions.DatabaseOutOfDate(current=existing_version, new=DATABASE_VERSION) + if not skip_version_check: + self._check_version() + self._load_pragmas() + else: + self._first_time_setup() if youtube_dl_function: self.youtube_dl_function = youtube_dl_function else: self.youtube_dl_function = YOUTUBE_DL_COMMAND - statements = DB_INIT.split(';') - for statement in statements: - self.cur.execute(statement) + def _check_version(self): + ''' + Compare database's user_version against DATABASE_VERSION, + raising exceptions.DatabaseOutOfDate if not correct. + ''' + existing = self.sql.execute('PRAGMA user_version').fetchone()[0] + if existing != DATABASE_VERSION: + raise exceptions.DatabaseOutOfDate( + existing=existing, + new=DATABASE_VERSION, + filepath=self.database_filepath, + ) + + def _first_time_setup(self): + self.sql.executescript(DB_INIT) + self.sql.commit() + + def _load_pragmas(self): + self.sql.executescript(DB_PRAGMAS) self.sql.commit() def add_channel(