diff --git a/utilities/database_upgrader.py b/utilities/database_upgrader.py index bd41f0d..bf50124 100644 --- a/utilities/database_upgrader.py +++ b/utilities/database_upgrader.py @@ -3,14 +3,14 @@ import os import sqlite3 import sys +import bot import ycdl -def upgrade_1_to_2(sql): +def upgrade_1_to_2(ycdldb): ''' In this version, the duration column was added. ''' - cur = sql.cursor() - cur.executescript(''' + ycdldb.sql.executescript(''' ALTER TABLE videos RENAME TO videos_old; CREATE TABLE videos( id TEXT, @@ -35,20 +35,18 @@ def upgrade_1_to_2(sql): DROP TABLE videos_old; ''') -def upgrade_2_to_3(sql): +def upgrade_2_to_3(ycdldb): ''' In this version, a column `automark` was added to the channels table, where you can set channels to automatically mark videos as ignored or downloaded. ''' - cur = sql.cursor() - cur.execute('ALTER TABLE channels ADD COLUMN automark TEXT') + ycdldb.sql.execute('ALTER TABLE channels ADD COLUMN automark TEXT') -def upgrade_3_to_4(sql): +def upgrade_3_to_4(ycdldb): ''' In this version, the views column was added. ''' - cur = sql.cursor() - cur.executescript(''' + ycdldb.sql.executescript(''' ALTER TABLE videos RENAME TO videos_old; CREATE TABLE videos( id TEXT, @@ -75,18 +73,20 @@ def upgrade_3_to_4(sql): DROP TABLE videos_old; ''') -def upgrade_all(database_filepath): + +def upgrade_all(data_directory): ''' - Given the directory containing a phototagger database, apply all of the + Given the directory containing a ycdl database, apply all of the needed upgrade_x_to_y functions in order. ''' - sql = sqlite3.connect(database_filepath) + youtube = ycdl.ytapi.Youtube(bot.get_youtube_key()) + ycdldb = ycdl.ycdldb.YCDLDB(youtube, data_directory, skip_version_check=True) - cur = sql.cursor() + cur = ycdldb.sql.cursor() cur.execute('PRAGMA user_version') current_version = cur.fetchone()[0] - needed_version = ycdl.ycdl.DATABASE_VERSION + needed_version = ycdl.constants.DATABASE_VERSION if current_version == needed_version: print('Already up to date with version %d.' % needed_version) @@ -96,20 +96,20 @@ def upgrade_all(database_filepath): print('Upgrading from %d to %d.' % (current_version, version_number)) upgrade_function = 'upgrade_%d_to_%d' % (current_version, version_number) upgrade_function = eval(upgrade_function) - upgrade_function(sql) - sql.cursor().execute('PRAGMA user_version = %d' % version_number) - sql.commit() + upgrade_function(ycdldb) + ycdldb.sql.cursor().execute('PRAGMA user_version = %d' % version_number) + ycdldb.commit() current_version = version_number print('Upgrades finished.') def upgrade_all_argparse(args): - return upgrade_all(database_filepath=args.database_filepath) + return upgrade_all(data_directory=args.data_directory) def main(argv): parser = argparse.ArgumentParser() - parser.add_argument('database_filepath') + parser.add_argument('data_directory') parser.set_defaults(func=upgrade_all_argparse) args = parser.parse_args(argv)