voussoirkit/voussoirkit/sqlhelpers.py
Ethan Dalool 54fa46c4f9
Use named column inserts, but remove require_all parameter.
I am yet again bumping into the limits of extract_table_column_map
as it does not use a real parser and does not recognize generated
columns, which causes issues when I use that column list to enforce
an insert.

It probably would be better to move even farther away from extracting
columns and just asking the database instead.

Anyway we get slightly improved ergonomics of insert_filler at the cost
of not being able to enforce all columns within this function.
2022-08-13 07:23:53 -07:00

286 lines
9 KiB
Python

'''
sqlhelpers
==========
This module provides functions for SQL string manipulation that I need often.
Most importantly, creating the right number of ? for binding insert / update
statements.
'''
import re
import types
class Inject:
'''
When making UPDATE or DELETE statements, you may wish to use WHERE clauses
that are not simply data values, such as sqlite function calls. When making
your pairs dict for update_filler or delete_filler, you can let the value of
a pair be an Inject class with a string, and that string will be directly
injected into the query.
You should not use an Inject for any user-provided strings, but you could
let the user pick between two different Injects prepared by you.
>>> pairs = {
... 'domain': 'example.com',
... 'url': sqlhelpers.Inject('REPLACE(url, "http://", "https://")'),
... }
>>> sqlhelpers.update_filler(pairs=pairs, where_key='domain')
('SET url = REPLACE(url, "http://", "https://") WHERE domain == ?', ['example.com'])
'''
def __init__(self, string):
self.string = string
def delete_filler(pairs):
'''
Manually aligning the bindings for DELETE statements is annoying.
Given a dictionary of {column: value}, return the "WHERE ..." portion of
the query and the bindings in the correct order.
>>> pairs={'test': 'toast', 'ping': 'pong'}
>>> delete_filler(pairs)
('WHERE test == ? AND ping == ?', ['toast', 'pong'])
In context:
(qmarks, bindings) = delete_filler(pairs)
query = f'DELETE FROM table {qmarks}'
cur.execute(query, bindings)
'''
qmarks = []
bindings = []
for (key, value) in pairs.items():
if isinstance(value, Inject):
qmarks.append(f'{key} == {value.string}')
else:
qmarks.append(f'{key} == ?')
bindings.append(value)
qmarks = ' AND '.join(qmarks)
qmarks = f'WHERE {qmarks}'
return (qmarks, bindings)
def insert_filler(pairs):
'''
Manually aligning the bindings for INSERT statements is annoying.
Given a dictionary of {column: value}, return the question marks and the
list of bindings in the right order.
>>> insert_filler({'score': 20, 'id': '1111', 'name': 'James'})
('(id, name, score) VALUES (?, ?, ?)', ['1111', 'James', 20])
In context:
(qmarks, bindings) = insert_filler(pairs)
query = f'INSERT INTO table {qmarks}'
cur.execute(query, bindings)
'''
column_names = []
bindings = []
for (key, value) in pairs.items():
column_names.append(key)
bindings.append(value)
column_names = ', '.join(column_names)
qmarks = '?' * len(pairs)
qmarks = ', '.join(qmarks)
qmarks = f'({column_names}) VALUES ({qmarks})'
return (qmarks, bindings)
def update_filler(pairs, where_key):
'''
Manually aligning the bindings for UPDATE statements is annoying.
Given a dictionary of {column: value} as well as the name of the column
to be used as the WHERE, return the "SET ..." portion of the query and the
bindings in the correct order.
If the where_key needs to be reassigned also, let its value be a 2-tuple
where [0] is the current value used for WHERE, and [1] is the new value
used for SET.
>>> pairs={'id': '1111', 'name': 'James', 'score': 20},
>>> where_key='id'
>>> update_filler(pairs, where_key)
('SET name = ?, score = ? WHERE id == ?', ['James', 20, '1111'])
Example:
>>> pairs={'filepath': ('/oldplace', '/newplace')},
>>> where_key='filepath'
>>> update_filler(pairs, where_key)
('SET filepath = ? WHERE filepath == ?', ['/newplace', '/oldplace'])
In context:
(qmarks, bindings) = update_filler(data, where_key)
query = f'UPDATE table {qmarks}'
cur.execute(query, bindings)
'''
pairs = pairs.copy()
where_value = pairs.pop(where_key)
if isinstance(where_value, tuple):
(where_value, pairs[where_key]) = where_value
if isinstance(where_value, dict):
where_value = where_value['old']
pairs[where_key] = where_value['new']
if len(pairs) == 0:
raise ValueError('No pairs left after where_key.')
qmarks = []
bindings = []
for (key, value) in pairs.items():
if isinstance(value, Inject):
qmarks.append(f'{key} = {value.string}')
else:
qmarks.append(f'{key} = ?')
bindings.append(value)
bindings.append(where_value)
setters = ', '.join(qmarks)
qmarks = 'SET {setters} WHERE {where_key} == ?'
qmarks = qmarks.format(setters=setters, where_key=where_key)
return (qmarks, bindings)
def executescript(conn, script):
'''
The problem with Python's default executescript is that it executes a
commit before running your script. If I wanted a commit I'd write one!
'''
script = _remove_script_comments(script)
statements = re.split(r';(:?\n|$)', script)
statements = (statement.strip() for statement in statements)
statements = (statement for statement in statements if statement)
cur = conn.cursor()
for statement in statements:
cur.execute(statement)
def hex_byte(byte):
'''
Return the hex string for this byte. 00-ff.
'''
if byte not in range(0, 256):
raise ValueError(byte)
return hex(byte)[2:].rjust(2, '0')
def literal(item):
'''
Return a string depicting the SQL literal for this item.
>>> literal(0)
"0"
>>> literal('hello')
"'hello'"
>>> literal(b'hello')
"X'68656c6c6f'"
>>> literal([3, 'hi'])
"(3, 'hi')"
'''
if item is None:
return 'NULL'
elif isinstance(item, bool):
return f'{int(item)}'
elif isinstance(item, int):
return f'{item}'
elif isinstance(item, float):
return f'{item:f}'
elif isinstance(item, str):
item = item.replace("'", "''")
return f"'{item}'"
elif isinstance(item, bytes):
item = ''.join(hex_byte(byte) for byte in item)
return f"X'{item}'"
elif isinstance(item, (list, tuple, set, types.GeneratorType)):
return listify(item)
else:
raise ValueError(f'Unrecognized type {type(item)} {item}.')
def listify(items):
output = ', '.join(literal(item) for item in items)
output = f'({output})'
return output
def _extract_create_table_statements(script):
# script = sqlparse.format(script, strip_comments=True)
# script = re.sub(r'\s*--.+$', '', script, flags=re.MULTILINE)
script = re.sub(r'\n\s*create ', ';\ncreate ', script, flags=re.IGNORECASE)
for statement in script.split(';'):
statement = statement.strip()
if statement.lower().startswith('create table'):
yield statement
def _extract_table_name(create_table_statement):
# CREATE TABLE table_name(...)
table_name = create_table_statement.split('(')[0].strip()
table_name = table_name.split()[-1]
return table_name
def _extract_columns_from_table(create_table_statement):
# CREATE TABLE table_name(column_name TYPE MODIFIERS, ...)
constraints = {'constraint', 'foreign', 'check', 'primary', 'unique'}
column_statements = create_table_statement.split('(')[1].rsplit(')', 1)[0]
column_statements = column_statements.split(',')
column_statements = [x.strip() for x in column_statements]
column_names = [x.split(' ')[0] for x in column_statements]
column_names = [c for c in column_names if c.lower() not in constraints]
return column_names
def _remove_script_comments(script):
lines = []
for line in script.splitlines():
if re.match(r'^\s*--', line):
continue
lines.append(line)
return '\n'.join(lines)
def _reverse_index(columns):
return {column: index for (index, column) in enumerate(columns)}
def extract_table_column_map(script):
'''
Given an entire SQL script containing CREATE TABLE statements, return a
dictionary of the form
{
'table1': [
'column1',
'column2',
],
'table2': [
'column1',
'column2',
],
}
'''
columns = {}
script = _remove_script_comments(script)
create_table_statements = _extract_create_table_statements(script)
for create_table_statement in create_table_statements:
table_name = _extract_table_name(create_table_statement)
columns[table_name] = _extract_columns_from_table(create_table_statement)
return columns
def reverse_table_column_map(table_column_map):
'''
Given the table column map, return a reversed version of the form
{
'table1': {
'column1': 0,
'column2': 1,
},
'table2': {
'column1': 0,
'column2': 1,
},
}
If you have a row of data and you want to access one of the columns, you can
use this map to figure out which tuple index corresponds to the column name.
For example:
row = ('abcd', 'John', 23)
index = INDEX['people']['name']
print(row[index])
'''
return {table: _reverse_index(columns) for (table, columns) in table_column_map.items()}