557 lines
17 KiB
Python
557 lines
17 KiB
Python
'''
|
|
This module provides the ExpressionTree class, which parses a query expression
|
|
like "a AND (b OR c)" and then evaluates whether an input satisfies the query.
|
|
|
|
Basic usage:
|
|
tree = expressionmatch.ExpressionTree.parse('a AND (b OR c)')
|
|
tree.evaluate('a b')
|
|
tree.evaluate('a c')
|
|
tree.evaluate('b c')
|
|
|
|
The available operators are:
|
|
a AND b
|
|
a OR b
|
|
a XOR b
|
|
NOT a
|
|
|
|
where a and b can be single tokens or a parenthesized group of tokens.
|
|
|
|
The operators must be capitalized as seen and can be enclosed in quotes if you
|
|
need to literally match the word "AND", etc.
|
|
|
|
If the tokens contain spaces, they must be enclosed in quotation marks:
|
|
tree = expressionmatch.ExpressionTree.parse('"mark hamill" OR "harrison ford"')
|
|
'''
|
|
from voussoirkit import sentinel
|
|
|
|
BINARY_OPERATORS = {'AND', 'OR', 'XOR'}
|
|
UNARY_OPERATORS = {'NOT'}
|
|
PRECEDENCE = ['NOT', 'AND', 'XOR', 'OR']
|
|
OPERATORS = BINARY_OPERATORS | UNARY_OPERATORS
|
|
|
|
# These sentinels help the parser distinguish between parens used for token
|
|
# grouping and parens that have been escaped by the user and should remain
|
|
# as strings.
|
|
PAREN_OPEN = sentinel.Sentinel('PAREN_OPEN')
|
|
PAREN_CLOSE = sentinel.Sentinel('PAREN_CLOSE')
|
|
|
|
DEFAULT_MATCH_FUNCTION = str.__contains__
|
|
|
|
MESSAGE_WRITE_YOUR_OWN_MATCHER = '''
|
|
The default match function is {function}.
|
|
Consider passing your own `match_function`, which accepts two
|
|
positional arguments:
|
|
1. The object being tested.
|
|
2. The Expression token, a string.
|
|
'''.strip()
|
|
|
|
def func_and(values):
|
|
return all(values)
|
|
|
|
def func_or(values):
|
|
return any(values)
|
|
|
|
def func_xor(values):
|
|
values = list(values)
|
|
return values.count(True) % 2 == 1
|
|
|
|
def func_not(value):
|
|
value = list(value)
|
|
if len(value) != 1:
|
|
raise ValueError('NOT only takes 1 value')
|
|
return not value[0]
|
|
|
|
OPERATOR_FUNCTIONS = {
|
|
'AND': func_and,
|
|
'OR': func_or,
|
|
'XOR': func_xor,
|
|
'NOT': func_not,
|
|
}
|
|
|
|
class NoTokens(Exception):
|
|
pass
|
|
|
|
class ExpressionTree:
|
|
def __init__(self, token, parent=None):
|
|
'''
|
|
This constructor is for each individual node of the tree.
|
|
End-users should probably call ExpressionTree.parse instead of this
|
|
constructor.
|
|
'''
|
|
self.children = []
|
|
self.parent = parent
|
|
self.token = token
|
|
|
|
def __str__(self):
|
|
if self.token is None:
|
|
return '""'
|
|
|
|
self_token = str(self.token)
|
|
|
|
if self_token not in OPERATORS:
|
|
t = self_token
|
|
t = t.replace('"', '\\"')
|
|
t = t.replace('(', '\\(')
|
|
t = t.replace(')', '\\)')
|
|
if ' ' in t:
|
|
t = f'"{t}"'
|
|
return t
|
|
|
|
if len(self.children) == 1:
|
|
child = self.children[0]
|
|
childstring = str(child)
|
|
if child.token in OPERATORS:
|
|
return f'{self_token}({childstring})'
|
|
return f'{self_token} {childstring}'
|
|
|
|
children = []
|
|
for child in self.children:
|
|
childstring = str(child)
|
|
if child.token in OPERATORS:
|
|
childstring = f'({childstring})'
|
|
children.append(childstring)
|
|
|
|
if len(children) == 1:
|
|
return f'{self_token} {children[0]}'
|
|
|
|
s = f' {self_token} '
|
|
s = s.join(children)
|
|
return s
|
|
|
|
@classmethod
|
|
def parse(cls, tokens):
|
|
'''
|
|
Create an ExpressionTree from the given query string or list of tokens.
|
|
'''
|
|
if isinstance(tokens, str):
|
|
tokens = tokenize(tokens)
|
|
|
|
if tokens == []:
|
|
raise NoTokens()
|
|
|
|
if isinstance(tokens[0], list):
|
|
current = cls.parse(tokens[0])
|
|
else:
|
|
current = cls(token=tokens[0])
|
|
|
|
for token in tokens[1:]:
|
|
if isinstance(token, list):
|
|
new = cls.parse(token)
|
|
else:
|
|
new = cls(token=token)
|
|
|
|
if 0 == 1:
|
|
pass
|
|
|
|
elif current.token not in OPERATORS:
|
|
if new.token in BINARY_OPERATORS:
|
|
if len(new.children) == 0:
|
|
new.children.append(current)
|
|
current.parent = new
|
|
current = new
|
|
else:
|
|
raise Exception(f'Expected binary operator, got {new.token}.')
|
|
|
|
elif current.token in BINARY_OPERATORS:
|
|
if new.token in BINARY_OPERATORS:
|
|
if new.token == current.token:
|
|
for child in new.children:
|
|
child.parent = current
|
|
current.children.extend(new.children)
|
|
else:
|
|
if len(new.children) == 0:
|
|
new.children.append(current)
|
|
current.parent = new
|
|
current = new
|
|
else:
|
|
current.children.append(new)
|
|
new.parent = current
|
|
|
|
elif new.token in UNARY_OPERATORS:
|
|
if len(new.children) == 0:
|
|
current.children.append(new)
|
|
new.parent = current
|
|
current = new
|
|
else:
|
|
current.children.append(new)
|
|
new.parent = current
|
|
|
|
elif new.token not in OPERATORS:
|
|
if len(current.children) > 0:
|
|
current.children.append(new)
|
|
new.parent = current
|
|
else:
|
|
raise Exception('Expected current children > 0.')
|
|
|
|
elif current.token in UNARY_OPERATORS:
|
|
if len(current.children) == 0:
|
|
current.children.append(new)
|
|
new.parent = current
|
|
if current.parent is not None:
|
|
current = current.parent
|
|
elif new.token in BINARY_OPERATORS:
|
|
if len(new.children) == 0:
|
|
new.children.append(current)
|
|
current.parent = new
|
|
current = new
|
|
else:
|
|
current.children.append(new)
|
|
new.parent = current
|
|
if current.parent is not None:
|
|
current = current.parent
|
|
else:
|
|
raise Exception('Expected new to be my operand or parent binary.')
|
|
|
|
current = current.rootmost()
|
|
return current
|
|
|
|
def _evaluate(self, text, match_function=None):
|
|
if self.token not in OPERATORS:
|
|
if match_function is None:
|
|
match_function = DEFAULT_MATCH_FUNCTION
|
|
|
|
value = match_function(text, self.token)
|
|
return value
|
|
|
|
operator_function = OPERATOR_FUNCTIONS[self.token]
|
|
children = (child.evaluate(text, match_function=match_function) for child in self.children)
|
|
return operator_function(children)
|
|
|
|
def diagram(self):
|
|
if self.token is None:
|
|
return '""'
|
|
t = self.token
|
|
if ' ' in t:
|
|
t = f'"{t}"'
|
|
|
|
output = t
|
|
indent = 1
|
|
for child in self.children:
|
|
child = child.diagram()
|
|
for line in child.splitlines():
|
|
output += (' ' * indent)
|
|
output += line + '\n'
|
|
indent = len(t) + 1
|
|
output = output.strip()
|
|
|
|
return output
|
|
|
|
def evaluate(self, text, match_function=None):
|
|
if match_function is None:
|
|
match_function = DEFAULT_MATCH_FUNCTION
|
|
|
|
try:
|
|
return self._evaluate(text, match_function)
|
|
except Exception as e:
|
|
if match_function is DEFAULT_MATCH_FUNCTION:
|
|
message = MESSAGE_WRITE_YOUR_OWN_MATCHER.format(function=DEFAULT_MATCH_FUNCTION)
|
|
override = Exception(message)
|
|
raise override from e
|
|
raise e
|
|
|
|
@property
|
|
def is_leaf(self):
|
|
return self.token not in OPERATORS
|
|
|
|
@property
|
|
def is_root(self):
|
|
return self.parent is None
|
|
|
|
def map(self, function):
|
|
'''
|
|
Apply this function to all of the operands.
|
|
'''
|
|
for node in self.walk_leaves():
|
|
node.token = function(node.token)
|
|
|
|
def prune(self):
|
|
'''
|
|
Remove any nodes where `token` is None.
|
|
'''
|
|
self.children = [child for child in self.children if child.token is not None]
|
|
|
|
for child in self.children:
|
|
child.prune()
|
|
|
|
if self.token in OPERATORS and len(self.children) == 0:
|
|
self.token = None
|
|
if self.parent is not None:
|
|
self.parent.children.remove(self)
|
|
|
|
def rootmost(self):
|
|
current = self
|
|
while current.parent is not None:
|
|
current = current.parent
|
|
return current
|
|
|
|
def walk(self):
|
|
yield self
|
|
for child in self.children:
|
|
yield from child.walk()
|
|
|
|
def walk_leaves(self):
|
|
for node in self.walk():
|
|
if node.is_leaf:
|
|
yield node
|
|
|
|
def implied_tokens(tokens):
|
|
'''
|
|
This function returns a new list of tokens which has all of the implied
|
|
tokens added explicitly and meaningless tokens removed, by the
|
|
following rules:
|
|
|
|
1. If two operands are directly next to each other, or an operand is
|
|
followed by a unary operator, it is implied that there is an AND
|
|
between them.
|
|
'1 2' -> '1 AND 2'
|
|
'1 NOT 2' -> '1 AND NOT 2'
|
|
|
|
2. If an expression begins or ends with an invalid operator, remove it.
|
|
'AND 2' -> '2'
|
|
'2 AND' -> '2'
|
|
|
|
3. If a parenthetical term contains only 1 item, the parentheses can be removed.
|
|
'(a)' -> 'a'
|
|
'(NOT a)' -> 'NOT a'
|
|
'(a OR)' -> '(a)' (by rule 2) -> 'a'
|
|
|
|
4. If two operators are next to each other, except for binary-unary,
|
|
keep only the first.
|
|
'1 OR AND 2' -> '1 OR 2'
|
|
'1 NOT AND 2' -> '1 AND NOT AND 2' (by rule 1) -> '1 AND NOT 2'
|
|
'NOT NOT 1' -> 'NOT 1'
|
|
'1 AND NOT NOT 2' -> '1 AND NOT 2'
|
|
'''
|
|
final_tokens = []
|
|
has_operand = False
|
|
has_binary_operator = False
|
|
has_unary_operator = False
|
|
|
|
if len(tokens) == 1 and not isinstance(tokens[0], str):
|
|
# [['A' 'AND' 'B']] -> ['A' 'AND' 'B']
|
|
tokens = tokens[0]
|
|
|
|
for token in tokens:
|
|
skip_this = False
|
|
while isinstance(token, (list, tuple)):
|
|
if len(token) == 0:
|
|
# Delete empty parentheses.
|
|
skip_this = True
|
|
break
|
|
if len(token) == 1:
|
|
# Take singular terms out of their parentheses.
|
|
token = token[0]
|
|
else:
|
|
previous = token
|
|
token = implied_tokens(token)
|
|
if previous == token:
|
|
break
|
|
|
|
if skip_this:
|
|
continue
|
|
|
|
if isinstance(token, str) and token in OPERATORS:
|
|
this_binary = token in BINARY_OPERATORS
|
|
this_unary = not this_binary
|
|
|
|
# 'NOT AND' and 'AND AND' are malformed...
|
|
if this_binary and (has_binary_operator or has_unary_operator):
|
|
continue
|
|
# ...'NOT NOT' is malformed...
|
|
if this_unary and has_unary_operator:
|
|
continue
|
|
# ...but AND NOT is okay.
|
|
|
|
# 'AND test' is malformed
|
|
if this_binary and not has_operand:
|
|
continue
|
|
|
|
if this_unary and has_operand:
|
|
final_tokens.append('AND')
|
|
|
|
has_unary_operator = this_unary
|
|
has_binary_operator = this_binary
|
|
has_operand = False
|
|
|
|
else:
|
|
if has_operand:
|
|
final_tokens.append('AND')
|
|
has_unary_operator = False
|
|
has_binary_operator = False
|
|
has_operand = True
|
|
|
|
final_tokens.append(token)
|
|
|
|
if has_binary_operator or has_unary_operator:
|
|
final_tokens.pop(-1)
|
|
|
|
return final_tokens
|
|
|
|
def order_operations(tokens):
|
|
for (index, token) in enumerate(tokens):
|
|
if isinstance(token, list):
|
|
tokens[index] = order_operations(token)
|
|
|
|
if len(tokens) < 5:
|
|
return tokens
|
|
|
|
index = 0
|
|
slice_start = None
|
|
slice_end = None
|
|
precedence_stack = []
|
|
while index < len(tokens):
|
|
token = tokens[index]
|
|
try:
|
|
precedence = PRECEDENCE.index(token)
|
|
except ValueError:
|
|
precedence = None
|
|
|
|
if precedence is None:
|
|
index += 1
|
|
continue
|
|
precedence_stack.append(precedence)
|
|
|
|
if token in UNARY_OPERATORS:
|
|
slice_start = index
|
|
slice_end = index + 2
|
|
|
|
elif len(precedence_stack) > 1:
|
|
if precedence_stack[-1] < precedence_stack[-2]:
|
|
slice_start = index - 1
|
|
slice_end = None
|
|
elif precedence_stack[-2] < precedence_stack[-1]:
|
|
slice_end = index
|
|
|
|
if slice_start is None or slice_end is None:
|
|
index += 1
|
|
continue
|
|
|
|
tokens[slice_start:slice_end] = [tokens[slice_start:slice_end]]
|
|
slice_start = None
|
|
slice_end = None
|
|
for x in range(2):
|
|
if not precedence_stack:
|
|
break
|
|
|
|
delete = precedence_stack[-1]
|
|
while precedence_stack and precedence_stack[-1] == delete:
|
|
index -= 1
|
|
precedence_stack.pop(-1)
|
|
|
|
index += 1
|
|
|
|
if slice_start is not None:
|
|
slice_end = len(tokens)
|
|
tokens[slice_start:slice_end] = [tokens[slice_start:slice_end]]
|
|
|
|
return tokens
|
|
|
|
def sublist_tokens(tokens, _from_index=0, depth=0):
|
|
'''
|
|
Given a list of tokens, replace parentheses with actual sublists.
|
|
['1', 'AND', '(', '3', 'OR', '4', ')'] ->
|
|
['1', 'AND', ['3', 'OR', '4']]
|
|
|
|
Unclosed parentheses are automatically closed at the end.
|
|
'''
|
|
final_tokens = []
|
|
index = _from_index
|
|
while index < len(tokens):
|
|
token = tokens[index]
|
|
index += 1
|
|
if token is PAREN_OPEN:
|
|
(token, index) = sublist_tokens(tokens, _from_index=index, depth=depth+1)
|
|
if token is PAREN_CLOSE:
|
|
break
|
|
final_tokens.append(token)
|
|
if _from_index == 0:
|
|
return final_tokens
|
|
else:
|
|
return (final_tokens, index)
|
|
|
|
def tokenize(expression):
|
|
'''
|
|
Break the string into a list of tokens. Spaces are the delimiter unless
|
|
they are inside quotation marks.
|
|
|
|
Quotation marks and parentheses can be escaped by preceeding with a
|
|
backslash '\\'.
|
|
|
|
Opening and closing parentheses are put into their own token unless
|
|
escaped / quoted.
|
|
|
|
Extraneous closing parentheses are ignored completely.
|
|
|
|
'1 AND(4 OR "5 6") OR \\(test\\)' ->
|
|
['1', 'AND', '(', '4', 'OR', '5 6', ')', 'OR', '\\(test\\)']
|
|
'''
|
|
current_word = []
|
|
in_escape = False
|
|
in_quotes = False
|
|
paren_depth = 0
|
|
tokens = []
|
|
for character in expression:
|
|
if in_escape:
|
|
in_escape = False
|
|
|
|
elif character in {'(', ')'} and not in_quotes:
|
|
if character == '(':
|
|
sentinel = PAREN_OPEN
|
|
paren_depth += 1
|
|
elif character == ')':
|
|
sentinel = PAREN_CLOSE
|
|
paren_depth -= 1
|
|
|
|
if paren_depth >= 0:
|
|
tokens.append(''.join(current_word))
|
|
tokens.append(sentinel)
|
|
current_word.clear()
|
|
continue
|
|
else:
|
|
continue
|
|
|
|
elif character == '\\':
|
|
in_escape = True
|
|
continue
|
|
|
|
elif character == '"':
|
|
in_quotes = not in_quotes
|
|
continue
|
|
|
|
elif character.isspace() and not in_quotes:
|
|
tokens.append(''.join(current_word))
|
|
current_word.clear()
|
|
continue
|
|
|
|
current_word.append(character)
|
|
|
|
tokens.append(''.join(current_word))
|
|
tokens = [w for w in tokens if w != '']
|
|
tokens = sublist_tokens(tokens)
|
|
tokens = implied_tokens(tokens)
|
|
tokens = order_operations(tokens)
|
|
return tokens
|
|
|
|
if __name__ == '__main__':
|
|
tests = [
|
|
'[sci-fi] OR [pg-13]',
|
|
'([sci-fi] OR [war]) AND [r]',
|
|
'[r] XOR [sci-fi]',
|
|
'"[mark hamill]" "[harrison ford]"',
|
|
]
|
|
teststrings = {
|
|
'Star Wars': '[harrison ford] [george lucas] [sci-fi] [pg] [carrie fisher] [mark hamill] [space]',
|
|
'Blade Runner': '[harrison ford] [ridley scott] [neo-noir] [dystopian] [sci-fi] [r]',
|
|
'Indiana Jones': '[harrison ford] [steven spielberg] [adventure] [pg-13]',
|
|
'Apocalypse Now': '[harrison ford] [francis coppola] [r] [war] [drama]'
|
|
}
|
|
for test in tests:
|
|
print('start:', test)
|
|
tokens = tokenize(test)
|
|
print('implied:', tokens)
|
|
etree = ExpressionTree.parse(tokens)
|
|
print('tree:', etree)
|
|
print(etree.diagram())
|
|
for (name, teststring) in teststrings.items():
|
|
print('Matches', name, ':', etree.evaluate(teststring))
|
|
print()
|