voussoirkit/voussoirkit/sqlhelpers.py
2020-09-15 14:03:41 -07:00

236 lines
7.3 KiB
Python

import re
import types
def delete_filler(pairs):
    '''
    Manually aligning the bindings for DELETE statements is annoying.
    Given a dictionary of {column: value}, return the "WHERE ..." portion of
    the query and the bindings in the correct order.

    Column names are interpolated directly into the SQL text; only the values
    are returned as bindings.

    Example:
        pairs={'test': 'toast', 'ping': 'pong'}
        ->
        returns ('WHERE test = ? AND ping = ?', ['toast', 'pong'])

    In context:
        (qmarks, bindings) = delete_filler(pairs)
        query = f'DELETE FROM table {qmarks}'
        cur.execute(query, bindings)

    Raises ValueError if `pairs` is empty, because a bare "WHERE " with no
    conditions is not a valid clause.
    '''
    if not pairs:
        raise ValueError('No pairs to build the WHERE clause from.')

    items = list(pairs.items())
    conditions = ' AND '.join(f'{key} = ?' for (key, _) in items)
    bindings = [value for (_, value) in items]
    return (f'WHERE {conditions}', bindings)
def insert_filler(column_names, values, require_all=True):
    '''
    Build the placeholder string and ordered bindings for an INSERT.

    Given the table's column names and a dictionary of {column: value},
    return a 2-tuple of the comma-separated question marks and the list of
    values arranged in the same order as `column_names`.

    require_all:
        If True, any column name absent from `values` causes a ValueError
        listing the missing columns.
        If False, absent columns are bound to None instead.

    Example:
        column_names=['id', 'name', 'score'],
        values={'score': 20, 'id': '1111', 'name': 'James'}
        ->
        returns ('?, ?, ?', ['1111', 'James', 20])

    In context:
        (qmarks, bindings) = insert_filler(COLUMN_NAMES, data)
        query = f'INSERT INTO table VALUES({qmarks})'
        cur.execute(query, bindings)
    '''
    # Work on a copy so the caller's dictionary is never modified.
    filled = values.copy()

    absent = [column for column in column_names if column not in filled]
    if absent and require_all:
        raise ValueError(f'Missing columns {absent}.')
    for column in absent:
        filled[column] = None

    placeholders = ', '.join(['?'] * len(column_names))
    ordered = [filled[column] for column in column_names]
    return (placeholders, ordered)
def update_filler(pairs, where_key):
    '''
    Manually aligning the bindings for UPDATE statements is annoying.
    Given a dictionary of {column: value} as well as the name of the column
    to be used as the WHERE, return the "SET ..." portion of the query and the
    bindings in the correct order.

    If the where_key needs to be reassigned also, let its value be either:
    - a 2-tuple where [0] is the current value used for WHERE, and [1] is the
      new value used for SET, or
    - a dictionary {'old': current_value, 'new': new_value}.

    Example:
        pairs={'id': '1111', 'name': 'James', 'score': 20},
        where_key='id'
        ->
        returns ('SET name = ?, score = ? WHERE id == ?', ['James', 20, '1111'])

    Example:
        pairs={'filepath': ('/oldplace', '/newplace')},
        where_key='filepath'
        ->
        returns ('SET filepath = ? WHERE filepath == ?', ['/newplace', '/oldplace'])

    Example:
        pairs={'filepath': {'old': '/oldplace', 'new': '/newplace'}},
        where_key='filepath'
        ->
        returns ('SET filepath = ? WHERE filepath == ?', ['/newplace', '/oldplace'])

    In context:
        (qmarks, bindings) = update_filler(data, where_key)
        query = f'UPDATE table {qmarks}'
        cur.execute(query, bindings)

    Raises KeyError if `where_key` is not in `pairs`, and ValueError if there
    is nothing left to SET once the where_key has been taken out.
    '''
    # Work on a copy so the caller's dictionary is never modified.
    pairs = pairs.copy()
    where_value = pairs.pop(where_key)

    if isinstance(where_value, tuple):
        (where_value, pairs[where_key]) = where_value
    elif isinstance(where_value, dict):
        # Read 'new' before rebinding where_value to the old value; doing it
        # the other way around would index the old value instead of the dict.
        pairs[where_key] = where_value['new']
        where_value = where_value['old']

    if not pairs:
        raise ValueError('No pairs left after where_key.')

    items = list(pairs.items())
    setters = ', '.join(f'{key} = ?' for (key, _) in items)
    bindings = [value for (_, value) in items]
    # The WHERE placeholder comes last in the query, so its binding is last.
    bindings.append(where_value)

    qmarks = f'SET {setters} WHERE {where_key} == ?'
    return (qmarks, bindings)
def hex_byte(byte):
    '''
    Render a single byte value as exactly two lowercase hex digits,
    '00' through 'ff'.

    Raises ValueError, carrying the offending value, when `byte` is not a
    member of range(0, 256).
    '''
    if byte in range(256):
        # hex() yields e.g. '0xa'; drop the '0x' prefix and pad to two digits.
        digits = hex(byte)[2:]
        return digits.zfill(2)
    raise ValueError(byte)
def literal(item):
'''
Return a string depicting the SQL literal for this item.
Example:
0 -> "0"
'hello' -> "'hello'"
b'hello' -> "X'68656c6c6f'"
[3, 'hi'] -> "(3, 'hi')"
'''
if item is None:
return 'NULL'
elif isinstance(item, bool):
return f'{int(item)}'
elif isinstance(item, int):
return f'{item}'
elif isinstance(item, float):
return f'{item:f}'
elif isinstance(item, str):
item = item.replace("'", "''")
return f"'{item}'"
elif isinstance(item, bytes):
item = ''.join(hex_byte(byte) for byte in item)
return f"X'{item}'"
elif isinstance(item, (list, tuple, set, types.GeneratorType)):
return listify(item)
else:
raise ValueError(f'Unrecognized type {type(item)} {item}.')
def listify(items):
    '''
    Return the SQL list text for an iterable: each element converted with
    literal(), joined by ", ", and wrapped in parentheses.

    Example:
        [3, 'hi'] -> "(3, 'hi')"
        [] -> "()"
    '''
    rendered = [literal(item) for item in items]
    return '(' + ', '.join(rendered) + ')'
def _extract_create_table_statements(script):
    '''
    Yield every statement in the script that starts with "create table"
    (case-insensitive), stripped of surrounding whitespace and without the
    semicolon separator.
    '''
    # Insert a semicolon ahead of every line that begins with "create ", so
    # such a line always starts a new statement when splitting below.
    # NOTE(review): presumably this guards against scripts that omit the
    # semicolon between statements -- confirm. SQL comments are not stripped.
    separated = re.sub(r'\n\s*create ', ';\ncreate ', script, flags=re.IGNORECASE)
    candidates = (chunk.strip() for chunk in separated.split(';'))
    yield from (
        candidate for candidate in candidates
        if candidate.lower().startswith('create table')
    )
def _extract_table_name(create_table_statement):
    '''
    Return the table name from a statement shaped like
    CREATE TABLE table_name(...): the last whitespace-separated word that
    appears before the first opening parenthesis.
    '''
    (preamble, _, _) = create_table_statement.partition('(')
    words = preamble.split()
    return words[-1]
def _extract_columns_from_table(create_table_statement):
    '''
    Return the list of column names declared in a statement shaped like
    CREATE TABLE table_name(column_name TYPE MODIFIERS, ...).

    The column list is everything between the first "(" and the last ")".
    It is split on commas that are not nested inside parentheses, so types
    such as VARCHAR(20) or DECIMAL(10, 2) and table constraints such as
    PRIMARY KEY(a, b) do not break up or cut short the list.

    Entries whose first word is a table-constraint keyword (CONSTRAINT,
    FOREIGN, CHECK, PRIMARY, UNIQUE) are skipped, as are empty entries.

    Commas or parentheses inside quoted identifiers or string literals are
    not given special treatment.

    Raises IndexError if the statement contains no "(".
    '''
    constraints = {'constraint', 'foreign', 'check', 'primary', 'unique'}
    body = create_table_statement.split('(', 1)[1].rsplit(')', 1)[0]

    # Split on top-level commas only, tracking parenthesis depth.
    definitions = []
    current = []
    depth = 0
    for char in body:
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
        if char == ',' and depth == 0:
            definitions.append(''.join(current))
            current = []
        else:
            current.append(char)
    definitions.append(''.join(current))

    column_names = []
    for definition in definitions:
        definition = definition.strip()
        if not definition:
            continue
        # The name ends at the first whitespace, or at "(" for constraints
        # written without a space, like UNIQUE(a, b).
        name = re.split(r'[\s(]', definition, maxsplit=1)[0]
        if name.lower() not in constraints:
            column_names.append(name)
    return column_names
def _reverse_index(columns):
    '''
    Map each column name to its position in `columns`.

    Example:
        ['id', 'name'] -> {'id': 0, 'name': 1}
    '''
    positions = {}
    for (position, column) in enumerate(columns):
        positions[column] = position
    return positions
def extract_table_column_map(script):
    '''
    Parse an entire SQL script and collect the column names of every
    CREATE TABLE statement in it.

    Returns a dictionary keyed by table name, where each value is the list
    of that table's column names:
    {
        'table1': [
            'column1',
            'column2',
        ],
        'table2': [
            'column1',
            'column2',
        ],
    }

    If the same table name appears more than once, the last statement wins.
    '''
    return {
        _extract_table_name(statement): _extract_columns_from_table(statement)
        for statement in _extract_create_table_statements(script)
    }
def reverse_table_column_map(table_column_map):
    '''
    Turn a {table: [column, ...]} map into a {table: {column: index}} map:
    {
        'table1': {
            'column1': 0,
            'column2': 1,
        },
        'table2': {
            'column1': 0,
            'column2': 1,
        },
    }

    This lets you look up which tuple index of a row holds a given column.

    For example:
        row = ('abcd', 'John', 23)
        index = INDEX['people']['name']
        print(row[index])
    '''
    reversed_map = {}
    for (table, columns) in table_column_map.items():
        reversed_map[table] = _reverse_index(columns)
    return reversed_map