Added other metrics to --compare + Updated code for pyconll==3.1.0 + Added option to calculate on one node queries

This commit is contained in:
Luka 2023-01-31 14:49:47 +01:00
parent 91a7ddd84c
commit 40aaffa632
5 changed files with 15 additions and 16 deletions

View File

@ -8,7 +8,7 @@ from generic import generate_key
class Tree(object):
def __init__(self, index, form, lemma, upos, xpos, deprel, feats, feats_detailed, form_dict, lemma_dict, upos_dict, xpos_dict, deprel_dict, feats_dict, feats_detailed_dict, head):
def __init__(self, index, form, lemma, upos, xpos, deprel, feats_detailed, form_dict, lemma_dict, upos_dict, xpos_dict, deprel_dict, feats_dict, feats_detailed_dict, head):
if not hasattr(self, 'feats'):
self.feats_detailed = {}
@ -27,9 +27,6 @@ class Tree(object):
if deprel not in deprel_dict:
deprel_dict[deprel] = Value(deprel)
self.deprel = deprel_dict[deprel]
if feats not in feats_dict:
feats_dict[feats] = Value(feats)
self.feats = feats_dict[feats]
for feat in feats_detailed.keys():
if feat not in feats_detailed_dict:
feats_detailed_dict[feat] = {}
@ -81,7 +78,6 @@ class Tree(object):
filter_passed = filter_passed and \
('deprel' not in option or option['deprel'] == self.deprel.get_value()) and \
('feats' not in option or option['feats'] == self.feats.get_value()) and \
('form' not in option or option['form'] == self.form.get_value()) and \
('lemma' not in option or option['lemma'] == self.lemma.get_value()) and \
('upos' not in option or option['upos'] == self.upos.get_value())
@ -100,7 +96,6 @@ class Tree(object):
('upos' not in query_tree or query_tree['upos'] == self.upos.get_value()) and \
('xpos' not in query_tree or query_tree['xpos'] == self.xpos.get_value()) and \
('deprel' not in query_tree or query_tree['deprel'] == self.deprel.get_value()) and \
('feats' not in query_tree or query_tree['feats'] == self.feats.get_value()) and \
(not filters['complete_tree_type'] or (len(self.children) == 0 and 'children' not in query_tree) or ('children' in query_tree and len(self.children) == len(query_tree['children']))) and \
self.fits_static_requirements_feats(query_tree)

View File

@ -23,7 +23,7 @@ def create_output_string_deprel(tree):
return tree.deprel.get_value()
def create_output_string_lemma(tree):
return tree.lemma.get_value()
return tree.lemma.get_value() if tree.lemma.get_value() is not None else '_'
def create_output_string_upos(tree):
return tree.upos.get_value()

View File

@ -1 +1 @@
pyconll==2.1.1
pyconll==3.1.0

View File

@ -14,4 +14,4 @@ for path in sorted(os.listdir(input_path)):
if not os.path.exists(folder_name):
os.makedirs(folder_name)
if not os.path.exists(file_name):
os.system("python /home/luka/Development/STARK/dependency-parsetree.py --config_file " + config_path + " --input " + str(path) + " --output " + file_name)
os.system("python /home/luka/Development/STARK/stark.py --config_file " + config_path + " --input " + str(path) + " --output " + file_name)

View File

@ -134,9 +134,8 @@ def create_trees(input_path, internal_saves, feats_detailed_dict={}, save=True):
continue
# TODO check if 5th place is always there for feats
feats = token._fields[5]
token_form = token.form if token.form is not None else '_'
node = Tree(int(token.id), token_form, token.lemma, token.upos, token.xpos, token.deprel, feats, token.feats, form_dict,
node = Tree(int(token.id), token_form, token.lemma, token.upos, token.xpos, token.deprel, token.feats, form_dict,
lemma_dict, upos_dict, xpos_dict, deprel_dict, feats_dict, feats_detailed_dict, token.head)
token_nodes.append(node)
if token.deprel == 'root':
@ -351,7 +350,7 @@ def read_filters(config, args, feats_detailed_list):
tree_size_range = tree_size.split('-')
tree_size_range = [int(r) for r in tree_size_range]
if tree_size_range[0] > 1:
if tree_size_range[0] > 0:
if len(tree_size_range) == 1:
query_tree = create_ngrams_query_trees(tree_size_range[0], [{}])
elif len(tree_size_range) == 2:
@ -505,8 +504,12 @@ def get_keyness(abs_freq_A, abs_freq_B, count_A, count_B):
E2 = count_B * (abs_freq_A + abs_freq_B) / (count_A + count_B)
LL = 2 * ((abs_freq_A * math.log(abs_freq_A / E1)) + (abs_freq_B * math.log(abs_freq_B / E2))) if abs_freq_B > 0 else 'NaN'
BIC = LL - math.log(count_A + count_B) if abs_freq_B > 0 else 'NaN'
log_ratio = math.log(((abs_freq_A/count_A)/(abs_freq_B/count_B)), 2) if abs_freq_B > 0 else 'NaN'
OR = (abs_freq_A/(count_A-abs_freq_A)) / (abs_freq_B/(count_B-abs_freq_B)) if abs_freq_B > 0 else 'NaN'
diff = (((abs_freq_A/count_A)*1000000 - (abs_freq_B/count_B)*1000000)*100) / ((abs_freq_B/count_B)*1000000) if abs_freq_B > 0 else 'NaN'
return [LL]
return [abs_freq_B, abs_freq_B/count_B, LL, BIC, log_ratio, OR, diff]
def main():
@ -519,7 +522,7 @@ def main():
parser.add_argument("--internal_saves", default=None, type=str, help="Location for internal_saves.")
parser.add_argument("--cpu_cores", default=None, type=int, help="Number of cores used.")
parser.add_argument("--tree_size", default=None, type=int, help="Size of trees.")
parser.add_argument("--tree_size", default=None, type=str, help="Size of trees.")
parser.add_argument("--tree_type", default=None, type=str, help="Tree type.")
parser.add_argument("--dependency_type", default=None, type=str, help="Dependency type.")
parser.add_argument("--node_order", default=None, type=str, help="Order of node.")
@ -574,7 +577,7 @@ def main():
if filters['association_measures']:
header += ['MI', 'MI3', 'Dice', 'logDice', 't-score', 'simple-LL']
if args.compare:
header += ['LL']
header += ['Other absolute frequency', 'Other relative frequency', 'LL', 'BIC', 'Log ratio', 'OR', '%DIFF']
writer.writerow(header)
if filters['lines_threshold']:
@ -587,7 +590,8 @@ def main():
if filters['frequency_threshold'] and filters['frequency_threshold'] > v['number']:
break
words_only = [word_att for word in v['object'].array for word_att in word] + ['' for i in range((tree_size_range[-1] - len(v['object'].array)) * len(v['object'].array[0]))]
row = [v['object'].get_key()[1:-1]] + words_only + [str(v['number'])]
key = [v['object'].get_key()[1:-1]] if v['object'].get_key()[0] == '(' and v['object'].get_key()[-1] == ')' else [v['object'].get_key()]
row = key + words_only + [str(v['number'])]
row += ['%.4f' % relative_frequency]
if filters['node_order']:
row += [v['object'].order]