diff --git a/voussoirkit/sqlhelpers.py b/voussoirkit/sqlhelpers.py index 0b933a1..f73f4da 100644 --- a/voussoirkit/sqlhelpers.py +++ b/voussoirkit/sqlhelpers.py @@ -57,41 +57,30 @@ def delete_filler(pairs): qmarks = f'WHERE {qmarks}' return (qmarks, bindings) -def insert_filler(column_names, values, require_all=True): +def insert_filler(pairs): ''' Manually aligning the bindings for INSERT statements is annoying. - Given the table's column names and a dictionary of {column: value}, - return the question marks and the list of bindings in the right order. + Given a dictionary of {column: value}, return the question marks and the + list of bindings in the right order. - require_all: - If `values` does not contain one of the column names, should we raise - an exception? - Otherwise, that column will simply receive None. - - >>> column_names=['id', 'name', 'score'], - >>> values={'score': 20, 'id': '1111', 'name': 'James'} - >>> insert_filler(column_names, scores) - ('?, ?, ?', ['1111', 'James', 20]) + >>> insert_filler({'score': 20, 'id': '1111', 'name': 'James'}) + ('(id, name, score) VALUES (?, ?, ?)', ['1111', 'James', 20]) In context: - (qmarks, bindings) = insert_filler(COLUMN_NAMES, data) - query = f'INSERT INTO table VALUES({qmarks})' + (qmarks, bindings) = insert_filler(pairs) + query = f'INSERT INTO table {qmarks}' cur.execute(query, bindings) ''' - values = values.copy() - missings = [] - for column in column_names: - if column in values: - continue - if require_all: - missings.append(column) - else: - values[column] = None - if missings: - raise ValueError(f'Missing columns {missings}.') - qmarks = '?' * len(column_names) + column_names = [] + bindings = [] + for (key, value) in pairs.items(): + column_names.append(key) + bindings.append(value) + + column_names = ', '.join(column_names) + qmarks = '?' * len(pairs) qmarks = ', '.join(qmarks) - bindings = [values[column] for column in column_names] + qmarks = f'({column_names}) VALUES ({qmarks})' return (qmarks, bindings) def update_filler(pairs, where_key): @@ -232,10 +221,10 @@ def _extract_table_name(create_table_statement): def _extract_columns_from_table(create_table_statement): # CREATE TABLE table_name(column_name TYPE MODIFIERS, ...) constraints = {'constraint', 'foreign', 'check', 'primary', 'unique'} - column_names = create_table_statement.split('(')[1].rsplit(')', 1)[0] - column_names = column_names.split(',') - column_names = [x.strip() for x in column_names] - column_names = [x.split(' ')[0] for x in column_names] + column_statements = create_table_statement.split('(')[1].rsplit(')', 1)[0] + column_statements = column_statements.split(',') + column_statements = [x.strip() for x in column_statements] + column_names = [x.split(' ')[0] for x in column_statements] column_names = [c for c in column_names if c.lower() not in constraints] return column_names