Added option to use parameters that overwrite config.ini data + Fixed None issue for some languages

This commit is contained in:
Luka 2023-01-12 10:52:36 +01:00
parent 46e9467fc2
commit 91274a36af
2 changed files with 77 additions and 31 deletions

View File

@ -134,7 +134,8 @@ def create_trees(input_path, internal_saves, feats_detailed_dict={}, save=True):
# TODO check if 5th place is always there for feats # TODO check if 5th place is always there for feats
feats = token._fields[5] feats = token._fields[5]
node = Tree(int(token.id), token.form, token.lemma, token.upos, token.xpos, token.deprel, feats, token.feats, form_dict, token_form = token.form if token.form is not None else '_'
node = Tree(int(token.id), token_form, token.lemma, token.upos, token.xpos, token.deprel, feats, token.feats, form_dict,
lemma_dict, upos_dict, xpos_dict, deprel_dict, feats_dict, feats_detailed_dict, token.head) lemma_dict, upos_dict, xpos_dict, deprel_dict, feats_dict, feats_detailed_dict, token.head)
token_nodes.append(node) token_nodes.append(node)
if token.deprel == 'root': if token.deprel == 'root':
@ -344,8 +345,9 @@ def count_trees(cpu_cores, all_trees, query_tree, create_output_string_functs, f
else: else:
result_dict[key] = {'object': r, 'number': 1} result_dict[key] = {'object': r, 'number': 1}
def read_filters(config, feats_detailed_list): def read_filters(config, args, feats_detailed_list):
tree_size_range = config.get('settings', 'tree_size', fallback='0').split('-') tree_size = config.get('settings', 'tree_size', fallback='0') if not args.tree_size else args.tree_size
tree_size_range = tree_size.split('-')
tree_size_range = [int(r) for r in tree_size_range] tree_size_range = [int(r) for r in tree_size_range]
if tree_size_range[0] > 1: if tree_size_range[0] > 1:
@ -356,14 +358,16 @@ def read_filters(config, feats_detailed_list):
for i in range(tree_size_range[0], tree_size_range[1] + 1): for i in range(tree_size_range[0], tree_size_range[1] + 1):
query_tree.extend(create_ngrams_query_trees(i, [{}])) query_tree.extend(create_ngrams_query_trees(i, [{}]))
else: else:
query_tree = [decode_query('(' + config.get('settings', 'query') + ')', '', feats_detailed_list)] query = config.get('settings', 'query') if not args.query else args.query
query_tree = [decode_query('(' + query + ')', '', feats_detailed_list)]
# set filters # set filters
node_types = config.get('settings', 'node_type').split('+') node_type = config.get('settings', 'node_type') if not args.node_type else args.node_type
node_types = node_type.split('+')
create_output_string_functs = [] create_output_string_functs = []
for node_type in node_types: for node_type in node_types:
assert node_type in ['deprel', 'lemma', 'upos', 'xpos', 'form', 'feats'], '"node_type" is not set up correctly' assert node_type in ['deprel', 'lemma', 'upos', 'xpos', 'form', 'feats'], '"node_type" is not set up correctly'
cpu_cores = config.getint('settings', 'cpu_cores') cpu_cores = config.getint('settings', 'cpu_cores') if not args.cpu_cores else args.cpu_cores
if node_type == 'deprel': if node_type == 'deprel':
create_output_string_funct = create_output_string_deprel create_output_string_funct = create_output_string_deprel
elif node_type == 'lemma': elif node_type == 'lemma':
@ -379,21 +383,25 @@ def read_filters(config, feats_detailed_list):
create_output_string_functs.append(create_output_string_funct) create_output_string_functs.append(create_output_string_funct)
filters = {} filters = {}
filters['internal_saves'] = config.get('settings', 'internal_saves') filters['internal_saves'] = config.get('settings', 'internal_saves') if not args.internal_saves else args.internal_saves
filters['input'] = config.get('settings', 'input') filters['input'] = config.get('settings', 'input') if not args.input else args.input
filters['node_order'] = config.get('settings', 'node_order') == 'fixed' node_order = config.get('settings', 'node_order') if not args.node_order else args.node_order
filters['node_order'] = node_order == 'fixed'
# filters['caching'] = config.getboolean('settings', 'caching') # filters['caching'] = config.getboolean('settings', 'caching')
filters['dependency_type'] = config.get('settings', 'dependency_type') == 'labeled' dependency_type = config.get('settings', 'dependency_type') if not args.dependency_type else args.dependency_type
filters['dependency_type'] = dependency_type == 'labeled'
if config.has_option('settings', 'label_whitelist'): if config.has_option('settings', 'label_whitelist'):
filters['label_whitelist'] = config.get('settings', 'label_whitelist').split('|') label_whitelist = config.get('settings', 'label_whitelist') if not args.label_whitelist else args.label_whitelist
filters['label_whitelist'] = label_whitelist.split('|')
else: else:
filters['label_whitelist'] = [] filters['label_whitelist'] = []
if config.has_option('settings', 'root_whitelist'): root_whitelist = config.get('settings', 'root_whitelist') if not args.root_whitelist else args.root_whitelist
if root_whitelist:
# test # test
filters['root_whitelist'] = [] filters['root_whitelist'] = []
for option in config.get('settings', 'root_whitelist').split('|'): for option in root_whitelist.split('|'):
attribute_dict = {} attribute_dict = {}
for attribute in option.split('&'): for attribute in option.split('&'):
value = attribute.split('=') value = attribute.split('=')
@ -402,12 +410,13 @@ def read_filters(config, feats_detailed_list):
else: else:
filters['root_whitelist'] = [] filters['root_whitelist'] = []
filters['complete_tree_type'] = config.get('settings', 'tree_type') == 'complete' tree_type = config.get('settings', 'tree_type') if not args.tree_type else args.tree_type
filters['association_measures'] = config.getboolean('settings', 'association_measures') filters['complete_tree_type'] = tree_type == 'complete'
filters['nodes_number'] = config.getboolean('settings', 'nodes_number') filters['association_measures'] = config.getboolean('settings', 'association_measures') if not args.association_measures else args.association_measures
filters['frequency_threshold'] = config.getfloat('settings', 'frequency_threshold', fallback=0) filters['nodes_number'] = config.getboolean('settings', 'nodes_number') if not args.nodes_number else args.nodes_number
filters['lines_threshold'] = config.getint('settings', 'lines_threshold', fallback=0) filters['frequency_threshold'] = config.getfloat('settings', 'frequency_threshold', fallback=0) if not args.frequency_threshold else args.frequency_threshold
filters['print_root'] = config.getboolean('settings', 'print_root') filters['lines_threshold'] = config.getint('settings', 'lines_threshold', fallback=0) if not args.lines_threshold else args.lines_threshold
filters['print_root'] = config.getboolean('settings', 'print_root') if not args.print_root else args.print_root
return filters, query_tree, create_output_string_functs, cpu_cores, tree_size_range, node_types return filters, query_tree, create_output_string_functs, cpu_cores, tree_size_range, node_types
@ -415,23 +424,41 @@ def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters ## Required parameters
parser.add_argument("--config_file", parser.add_argument("--config_file", default=None, type=str, required=True, help="The input config file.")
default=None, parser.add_argument("--input", default=None, type=str, help="The input file/folder.")
type=str, parser.add_argument("--output", default=None, type=str, help="The output file.")
required=True, parser.add_argument("--internal_saves", default=None, type=str, help="Location for internal_saves.")
help="The input config file.") parser.add_argument("--cpu_cores", default=None, type=int, help="Number of cores used.")
parser.add_argument("--tree_size", default=None, type=int, help="Size of trees.")
parser.add_argument("--tree_type", default=None, type=str, help="Tree type.")
parser.add_argument("--dependency_type", default=None, type=str, help="Dependency type.")
parser.add_argument("--node_order", default=None, type=str, help="Order of node.")
parser.add_argument("--node_type", default=None, type=str, help="Type of node.")
parser.add_argument("--label_whitelist", default=None, type=str, help="Label whitelist.")
parser.add_argument("--root_whitelist", default=None, type=str, help="Root whitelist.")
parser.add_argument("--query", default=None, type=str, help="Query.")
parser.add_argument("--lines_threshold", default=None, type=str, help="Lines treshold.")
parser.add_argument("--frequency_threshold", default=None, type=int, help="Frequency threshold.")
parser.add_argument("--association_measures", default=None, type=bool, help="Association measures.")
parser.add_argument("--print_root", default=None, type=bool, help="Print root.")
parser.add_argument("--nodes_number", default=None, type=bool, help="Nodes number.")
parser.add_argument("--continuation_processing", default=None, type=bool, help="Nodes number.")
args = parser.parse_args() args = parser.parse_args()
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read(args.config_file) config.read(args.config_file)
internal_saves = config.get('settings', 'internal_saves') internal_saves = config.get('settings', 'internal_saves') if not args.internal_saves else args.internal_saves
input_path = config.get('settings', 'input') input_path = config.get('settings', 'input') if not args.input else args.input
if os.path.isdir(input_path): if os.path.isdir(input_path):
checkpoint_path = Path(internal_saves, 'checkpoint.pkl') checkpoint_path = Path(internal_saves, 'checkpoint.pkl')
continuation_processing = config.getboolean('settings', 'continuation_processing', fallback=False) continuation_processing = config.getboolean('settings', 'continuation_processing', fallback=False) if not args.continuation_processing else args.input
if not checkpoint_path.exists() or not continuation_processing: if not checkpoint_path.exists() or not continuation_processing:
already_processed = set() already_processed = set()
@ -461,7 +488,7 @@ def main():
corpus_size += sub_corpus_size corpus_size += sub_corpus_size
filters, query_tree, create_output_string_functs, cpu_cores, tree_size_range, node_types = read_filters( filters, query_tree, create_output_string_functs, cpu_cores, tree_size_range, node_types = read_filters(
config, feats_detailed_list) config, args, feats_detailed_list)
count_trees(cpu_cores, all_trees, query_tree, create_output_string_functs, filters, unigrams_dict, count_trees(cpu_cores, all_trees, query_tree, create_output_string_functs, filters, unigrams_dict,
@ -491,7 +518,7 @@ def main():
result_dict = {} result_dict = {}
unigrams_dict = {} unigrams_dict = {}
filters, query_tree, create_output_string_functs, cpu_cores, tree_size_range, node_types = read_filters(config, feats_detailed_list) filters, query_tree, create_output_string_functs, cpu_cores, tree_size_range, node_types = read_filters(config, args, feats_detailed_list)
start_exe_time = time.time() start_exe_time = time.time()
count_trees(cpu_cores, all_trees, query_tree, create_output_string_functs, filters, unigrams_dict, result_dict) count_trees(cpu_cores, all_trees, query_tree, create_output_string_functs, filters, unigrams_dict, result_dict)
@ -500,13 +527,15 @@ def main():
print("--- %s seconds ---" % (time.time() - start_exe_time)) print("--- %s seconds ---" % (time.time() - start_exe_time))
sorted_list = sorted(result_dict.items(), key=lambda x: x[1]['number'], reverse=True) sorted_list = sorted(result_dict.items(), key=lambda x: x[1]['number'], reverse=True)
with open(config.get('settings', 'output'), "w", newline="", encoding="utf-8") as f: output = config.get('settings', 'output') if not args.output else args.output
with open(output, "w", newline="", encoding="utf-8") as f:
# header - use every second space as a split # header - use every second space as a split
writer = csv.writer(f, delimiter='\t') writer = csv.writer(f, delimiter='\t')
if tree_size_range[-1]: if tree_size_range[-1]:
len_words = tree_size_range[-1] len_words = tree_size_range[-1]
else: else:
len_words = int(len(config.get('settings', 'query').split(" "))/2 + 1) query = config.get('settings', 'query') if not args.query else args.query
len_words = int(len(query.split(" "))/2 + 1)
header = ["Structure"] + ["Node " + string.ascii_uppercase[i] + "-" + node_type for i in range(len_words) for node_type in node_types] + ['Absolute frequency'] header = ["Structure"] + ["Node " + string.ascii_uppercase[i] + "-" + node_type for i in range(len_words) for node_type in node_types] + ['Absolute frequency']
header += ['Relative frequency'] header += ['Relative frequency']
if filters['node_order']: if filters['node_order']:

17
run-multiple-depparse.py Normal file
View File

@ -0,0 +1,17 @@
"""Run dependency-parsetree.py once for every ``.conllu`` file under ``input_path``.

Each input file gets an output file of the same name inside
``output_path/<name of the file's parent folder>/``.  Inputs whose output
file already exists are skipped, so an interrupted batch can simply be
started again.
"""
import os
import subprocess
from pathlib import Path

input_path = '/home/luka/Development/STARK/data/ud-treebanks-v2.11/'
output_path = '/home/luka/Development/STARK/results/ud-treebanks-v2.11_B/'
script_path = '/home/luka/Development/STARK/dependency-parsetree.py'
# Relative path: resolved against the working directory the batch is started from.
config_path = 'data/B_test-all-treebanks_3_completed_unlabeled_fixed_form_root=NOUN_5.ini'
# Alternative configuration used previously: 'config.ini'


def build_command(conllu_path, output_file, script=script_path, config=config_path):
    """Return the argument list that processes one ``.conllu`` file.

    The command is a list rather than a single shell string, so paths that
    contain spaces or shell metacharacters reach the child process unchanged.

    :param conllu_path: input file, passed to the child as ``--input``
    :param output_file: result file, passed to the child as ``--output``
    :param script: script to execute
    :param config: value passed to the child as ``--config_file``
    :return: list of strings suitable for ``subprocess.run``
    """
    return [
        'python', str(script),
        '--config_file', str(config),
        '--input', str(conllu_path),
        '--output', str(output_file),
    ]


def iter_pending_jobs(input_root, output_root):
    """Yield ``(conllu_path, output_file)`` for every input without an output yet.

    Entries of ``input_root`` are visited in sorted order and each one is
    searched recursively for ``*.conllu`` files.  The output folder of every
    visited file is created if it is missing, even when the file itself is
    skipped because its output already exists.

    :param input_root: folder whose entries are searched for ``.conllu`` files
    :param output_root: folder under which the per-folder results are stored
    """
    for entry in sorted(os.listdir(input_root)):
        for conllu_path in sorted(Path(input_root, entry).glob('**/*.conllu')):
            # Results are grouped by the folder that directly contains the input file.
            folder_name = os.path.join(output_root, conllu_path.parts[-2])
            file_name = os.path.join(folder_name, conllu_path.name)
            os.makedirs(folder_name, exist_ok=True)
            if not os.path.exists(file_name):
                yield conllu_path, file_name


def main(input_root=input_path, output_root=output_path):
    """Process every pending ``.conllu`` file, one child process at a time."""
    for conllu_path, file_name in iter_pending_jobs(input_root, output_root):
        # The exit status is deliberately not checked: one failing file
        # must not stop the remaining files from being processed.
        subprocess.run(build_command(conllu_path, file_name))


if __name__ == '__main__':
    main()