Add extract_table_column_map and reverse_table_column_map.

This commit is contained in:
Ethan Dalool 2020-02-06 20:18:42 -08:00
parent 593ff020dc
commit abe73ae24d

View file

@ -160,3 +160,74 @@ def listify(items):
output = ', '.join(literal(item) for item in items) output = ', '.join(literal(item) for item in items)
output = f'({output})' output = f'({output})'
return output return output
def _extract_create_table_statements(script):
for statement in script.split(';'):
statement = statement.strip()
statement = statement.strip('-')
statement = statement.strip()
if statement.lower().startswith('create table'):
yield statement
def _extract_table_name(create_table_statement):
# CREATE TABLE table_name(...)
table_name = create_table_statement.split('(')[0].strip()
table_name = table_name.split()[-1]
return table_name
def _extract_columns_from_table(create_table_statement):
# CREATE TABLE table_name(column_name TYPE MODIFIERS, ...)
constraints = {'constraint', 'foreign', 'check', 'primary', 'unique'}
column_names = create_table_statement.split('(')[1].rsplit(')', 1)[0]
column_names = column_names.split(',')
column_names = [x.strip() for x in column_names]
column_names = [x.split(' ')[0] for x in column_names]
column_names = [c for c in column_names if c.lower() not in constraints]
return column_names
def _reverse_index(columns):
return {column: index for (index, column) in enumerate(columns)}
def extract_table_column_map(script):
'''
Given an entire SQL script containing CREATE TABLE statements, return a
dictionary of the form
{
'table1': [
'column1',
'column2',
],
'table2': [
'column1',
'column2',
],
}
'''
columns = {}
create_table_statements = _extract_create_table_statements(script)
for create_table_statement in create_table_statements:
table_name = _extract_table_name(create_table_statement)
columns[table_name] = _extract_columns_from_table(create_table_statement)
return columns
def reverse_table_column_map(table_column_map):
'''
Given the table column map, return a reversed version of the form
{
'table1': {
'column1': 0,
'column2': 1,
},
'table2': {
'column1': 0,
'column2': 1,
},
}
If you have a row of data and you want to access one of the columns, you can
use this map to figure out which tuple index corresponds to the column name.
For example:
row = ('abcd', 'John', 23)
index = INDEX['people']['name']
print(row[index])
'''
return {table: _reverse_index(columns) for (table, columns) in table_column_map.items()}