diff --git a/timesearch/exceptions.py b/timesearch/exceptions.py index c9f03c4..9bd74b5 100644 --- a/timesearch/exceptions.py +++ b/timesearch/exceptions.py @@ -15,5 +15,15 @@ class TimesearchException(Exception): def __str__(self): return self.error_message +OUTOFDATE = ''' +Database is out of date. {current} should be {new}. +Please use utilities\\database_upgrader.py +'''.strip() +class DatabaseOutOfDate(TimesearchException): + ''' + Raised by TSDB __init__ if the user's database is behind. + ''' + error_message = OUTOFDATE + class DatabaseNotFound(TimesearchException, FileNotFoundError): error_message = 'Database file not found: "{}"' diff --git a/timesearch/tsdb.py b/timesearch/tsdb.py index 2b9d052..0fe0227 100644 --- a/timesearch/tsdb.py +++ b/timesearch/tsdb.py @@ -73,10 +73,6 @@ CREATE TABLE IF NOT EXISTS comments( CREATE INDEX IF NOT EXISTS comment_index ON comments(idstr); '''.format(user_version=DATABASE_VERSION) -ERROR_DATABASE_OUTOFDATE = ''' -Database is out of date. {current} should be {new}. -'''.strip() - DEFAULT_CONFIG = { } @@ -143,9 +139,7 @@ class TSDB: self.cur.execute('PRAGMA user_version') existing_version = self.cur.fetchone()[0] if existing_version > 0 and existing_version != DATABASE_VERSION: - message = ERROR_DATABASE_OUTOFDATE - message = message.format(current=existing_version, new=DATABASE_VERSION) - raise ValueError(message) + raise exceptions.DatabaseOutOfDate(current=existing_version, new=DATABASE_VERSION) statements = DB_INIT.split(';') for statement in statements: diff --git a/utilities/database_upgrader.py b/utilities/database_upgrader.py new file mode 100644 index 0000000..73b5f65 --- /dev/null +++ b/utilities/database_upgrader.py @@ -0,0 +1,52 @@ +import argparse +import os +import sqlite3 +import sys + +import timesearch.tsdb + + +def upgrade_all(database_filename): + ''' + Given the filename of a database, apply all of the needed + upgrade_x_to_y functions in order. + ''' + if not os.path.isfile(database_filename): + raise FileNotFoundError(database_filename) + + sql = sqlite3.connect(database_filename) + cur = sql.cursor() + + cur.execute('PRAGMA user_version') + current_version = cur.fetchone()[0] + needed_version = timesearch.tsdb.DATABASE_VERSION + + if current_version == needed_version: + print('Already up to date with version %d.' % needed_version) + return + + for version_number in range(current_version + 1, needed_version + 1): + print('Upgrading from %d to %d' % (current_version, version_number)) + upgrade_function = 'upgrade_%d_to_%d' % (current_version, version_number) + upgrade_function = eval(upgrade_function) + upgrade_function(sql) + sql.cursor().execute('PRAGMA user_version = %d' % version_number) + sql.commit() + current_version = version_number + print('Upgrades finished.') + + +def upgrade_all_argparse(args): + return upgrade_all(database_filename=args.database_filename) + +def main(argv): + parser = argparse.ArgumentParser() + + parser.add_argument('database_filename') + parser.set_defaults(func=upgrade_all_argparse) + + args = parser.parse_args(argv) + args.func(args) + +if __name__ == '__main__': + main(sys.argv[1:])