diff --git a/etiquette/constants.py b/etiquette/constants.py index 2f96e43..23168c6 100644 --- a/etiquette/constants.py +++ b/etiquette/constants.py @@ -204,6 +204,14 @@ def _extract_columns_from_table(create_table_statement): column_names = [c for c in column_names if c.lower() != 'foreign'] return column_names +def _extract_table_column_map(script): + columns = {} + table_statements = _extract_table_statements(script) + for table_statement in table_statements: + table_name = _extract_table_name(table_statement) + columns[table_name] = _extract_columns_from_table(table_statement) + return columns + def _reverse_index(columns): ''' Given an iterable, return a dictionary where the key is the item and the @@ -213,12 +221,7 @@ def _reverse_index(columns): ''' return {column: index for (index, column) in enumerate(columns)} -SQL_COLUMNS = {} -for table_statement in _extract_table_statements(DB_INIT): - table_name = _extract_table_name(table_statement) - columns = _extract_columns_from_table(table_statement) - SQL_COLUMNS[table_name] = columns - +SQL_COLUMNS = _extract_table_column_map(DB_INIT) SQL_INDEX = {table: _reverse_index(columns) for (table, columns) in SQL_COLUMNS.items()} ALLOWED_ORDERBY_COLUMNS = [