diff --git a/voussoirkit/sqlhelpers.py b/voussoirkit/sqlhelpers.py index 327d498..e06c460 100644 --- a/voussoirkit/sqlhelpers.py +++ b/voussoirkit/sqlhelpers.py @@ -160,3 +160,74 @@ def listify(items): output = ', '.join(literal(item) for item in items) output = f'({output})' return output + +def _extract_create_table_statements(script): + for statement in script.split(';'): + statement = statement.strip() + statement = statement.strip('-') + statement = statement.strip() + if statement.lower().startswith('create table'): + yield statement + +def _extract_table_name(create_table_statement): + # CREATE TABLE table_name(...) + table_name = create_table_statement.split('(')[0].strip() + table_name = table_name.split()[-1] + return table_name + +def _extract_columns_from_table(create_table_statement): + # CREATE TABLE table_name(column_name TYPE MODIFIERS, ...) + constraints = {'constraint', 'foreign', 'check', 'primary', 'unique'} + column_names = create_table_statement.split('(')[1].rsplit(')', 1)[0] + column_names = column_names.split(',') + column_names = [x.strip() for x in column_names] + column_names = [x.split(' ')[0] for x in column_names] + column_names = [c for c in column_names if c.lower() not in constraints] + return column_names + +def _reverse_index(columns): + return {column: index for (index, column) in enumerate(columns)} + +def extract_table_column_map(script): + ''' + Given an entire SQL script containing CREATE TABLE statements, return a + dictionary of the form + { + 'table1': [ + 'column1', + 'column2', + ], + 'table2': [ + 'column1', + 'column2', + ], + } + ''' + columns = {} + create_table_statements = _extract_create_table_statements(script) + for create_table_statement in create_table_statements: + table_name = _extract_table_name(create_table_statement) + columns[table_name] = _extract_columns_from_table(create_table_statement) + return columns + +def reverse_table_column_map(table_column_map): + ''' + Given the table column map, return a reversed version of the form + { + 'table1': { + 'column1': 0, + 'column2': 1, + }, + 'table2': { + 'column1': 0, + 'column2': 1, + }, + } + If you have a row of data and you want to access one of the columns, you can + use this map to figure out which tuple index corresponds to the column name. + For example: + row = ('abcd', 'John', 23) + index = INDEX['people']['name'] + print(row[index]) + ''' + return {table: _reverse_index(columns) for (table, columns) in table_column_map.items()}