Added label_whitelist

This commit is contained in:
Luka 2019-11-06 16:59:10 +01:00
parent d8b740a1e5
commit 2efae2e2de
2 changed files with 9 additions and 1 deletions

View File

@ -62,6 +62,10 @@ class Tree(object):
# return True
def fits_temporary_requirements(self, filters):
return not filters['label_whitelist'] or self.deprel.get_value() in filters['label_whitelist']
def fits_static_requirements(self, query_tree):
return ('form' not in query_tree or query_tree['form'] == self.form.get_value()) and \
('lemma' not in query_tree or query_tree['lemma'] == self.lemma.get_value()) and \
@ -306,7 +310,7 @@ class Tree(object):
active_temporary_query_trees = []
successful_temporary_queries = []
for i, temporary_query_tree in enumerate(temporary_query_trees):
if self.fits_static_requirements(temporary_query_tree):
if self.fits_static_requirements(temporary_query_tree) and self.fits_temporary_requirements(filters):
active_temporary_query_trees.append(temporary_query_tree)
successful_temporary_queries.append(i)
# if 'l_children' in temporary_query_tree and 'r_children' in temporary_query_tree:

View File

@ -271,6 +271,10 @@ def main():
filters = {}
filters['node_order'] = config.get('settings', 'node_order') == 'fixed'
filters['dependency_type'] = config.get('settings', 'dependency_type') == 'labeled'
if config.has_option('settings', 'label_whitelist'):
filters['label_whitelist'] = config.get('settings', 'label_whitelist').split('|')
else:
filters['label_whitelist'] = []
# for tree in all_trees[2:]:
# for tree in all_trees[1205:]: