Make the SQL_COLUMNS code a bit easier to read.

This commit is contained in:
voussoir 2018-05-06 10:26:05 -07:00
parent db28b6819c
commit 11fda94968

View file

@ -192,15 +192,20 @@ def _extract_columns(create_table_statement):
column_names = [c for c in column_names if c.lower() != 'foreign']
return column_names
SQL_COLUMNS = {}
for statement in DB_INIT.split(';'):
if 'create table' not in statement.lower():
continue
def _extract_table_statements(script):
for statement in script.split(';'):
if 'create table' not in statement.lower():
continue
table_name = statement.split('(')[0].strip().split(' ')[-1]
SQL_COLUMNS[table_name] = _extract_columns(statement)
yield statement
def _sql_dictify(columns):
def _extract_table_name(create_table_statement):
# CREATE TABLE table_name(
table_name = create_table_statement.split('(')[0].strip()
table_name = table_name.split()[-1]
return table_name
def _reverse_index(columns):
'''
A dictionary where the key is the item and the value is the index.
Used to convert a stringy name into the correct number to then index into
@ -209,7 +214,11 @@ def _sql_dictify(columns):
'''
return {column: index for (index, column) in enumerate(columns)}
SQL_INDEX = {table: _sql_dictify(columns) for (table, columns) in SQL_COLUMNS.items()}
SQL_COLUMNS = {
_extract_table_name(table): _extract_columns(table)
for table in _extract_table_statements(DB_INIT)
}
SQL_INDEX = {table: _reverse_index(columns) for (table, columns) in SQL_COLUMNS.items()}
ALLOWED_ORDERBY_COLUMNS = [