Updating for lemma representation of word_form. Also cleaning code, adding tqdm,...
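The heart of the change: when no inflected form of a component was seen in the corpus, its representation can now fall back to rendering the bare lemma under a POS-level MSD. A minimal standalone sketch of that fallback idea, with illustrative names and sample data that are not from wani.py:

    # hypothetical sketch; pick_representation and render_db are illustrative only
    def pick_representation(forms, lemma, msd, render_db):
        if forms:  # concrete forms matched in the corpus
            return "/".join(sorted(set(forms)))
        # fallback: render the lemma itself under a lemma-level msd
        return render_db.get((lemma, msd), lemma)

    render_db = {("pes", "Ncmsn"): "pes"}  # (lemma, msd) -> most frequent text
    print(pick_representation([], "pes", "Ncmsn", render_db))              # "pes"
    print(pick_representation(["psa", "psu"], "pes", "Ncmsn", render_db))  # "psa/psu"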

Ozbolt Menegatti 2019-05-24 18:15:21 +02:00
parent 3c669c7901
commit 4c2b5f2b13

wani.py

@@ -1,7 +1,7 @@
 from xml.etree import ElementTree
 import re
 from enum import Enum
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 import sys
 import logging
 import argparse
@@ -12,6 +12,7 @@ import concurrent.futures
 import tempfile
 from msd_translate import MSD_TRANSLATE
+from tqdm import tqdm
 
 MAX_NUM_COMPONENTS = 5
@@ -217,7 +218,7 @@ class ComponentRendition:
         return self.rendition is rendition
 
     @staticmethod
-    def set_representations(matches, structure, word_renderer):
+    def set_representations(matches, structure, word_renderer, lemma_msds):
         representations = {
             c.idx: [[], None] if c.representation.isit(Rendition.WordForm) else [True, ""]
             for c in structure.components
@@ -225,24 +226,34 @@ class ComponentRendition:
         found_agreements = {}
         word_component_id = {}
-        # doprint = structure.id == '1' and matches[0]['1'].text.startswith('evrop') and matches[0]['2'].text.startswith('prv')
-        doprint = False
 
-        def render_all(component_id, lst):
-            matches.representations[component_id] = "/".join([w.text for w in set(lst)])
+        def render_all(component_id, lst, _bw):
+            rep = "/".join([w.text for w in set(lst)]) if len(lst) > 0 else None
+            matches.representations[component_id] = rep
 
-        def render_form(component_id, lst):
-            sorted_lst = sorted(set(lst), key=lst.count)
-            for word in sorted_lst:
+        def render_form(component_id, lst, backup_word):
+            # backup_word: lemma-only placeholder, appended as a last resort
+            if backup_word is not None:
+                lst.append(backup_word)
+
+            lst_ctr = [(word.msd, word.lemma) for word in lst]
+            sorted_lst = sorted(set(lst_ctr), key=lst_ctr.count)
+
+            for word_msd, word_lemma in sorted_lst:
                 if component_id in found_agreements:
                     other_component_id, other_word, agreements = found_agreements[component_id]
-                    print(word.lemma, other_word.lemma, component_id, other_component_id, word.msd, word.msd)
-                    agr = are_agreements_ok(word.msd, other_word.lemma, other_word.msd, agreements)
+                    agr = are_agreements_ok(word_msd, other_word.lemma, other_word.msd, agreements)
                     if agr is None:
                         continue
                     matches.representations[other_component_id] = agr
 
-                matches.representations[word_component_id[word.id]] = word.most_frequent_text(word_renderer)
+                if word_lemma is not None:
+                    matches.representations[component_id] = word_renderer.render(word_lemma, word_msd)
+
                 break
 
         # othw = are_agreements_ok(word, found_agreements)
@@ -255,10 +266,7 @@ class ComponentRendition:
                 if ow_msd[0] != w2_msd[0]:
                     continue
 
-                print(w1_msd, w2_msd)
                 if check_agreement(w1_msd, w2_msd, agreements):
-                    if doprint:
-                        print("GOOD :)")
                     return w2_txt
 
         def check_msd(word, selectors):
@@ -279,7 +287,6 @@ class ComponentRendition:
                 # if not in msd, some strange msd was tried, skipping...
                 if agr_case not in TAGSET[t1]:
                     logging.warning("Cannot do agreement: {} for msd {} not found!".format(agr_case, msd1))
-                    print("BAAAD")
                     return False
 
                 v1 = TAGSET[t1].index(agr_case)
@@ -293,7 +300,6 @@ class ComponentRendition:
                 t2 = msd2[0]
                 if agr_case not in TAGSET[t2]:
                     logging.warning("Cannot do agreement: {} for msd {} not found!".format(agr_case, msd2))
-                    print("BAAAD")
                     return False
 
                 v2 = TAGSET[t2].index(agr_case)
                 if v2 + 1 >= len(msd2):
@@ -328,44 +334,26 @@ class ComponentRendition:
             assert(rep.isit(Rendition.WordForm))
             wf_type, more = rep.more
 
-            if wf_type is WordFormSelection.All:
-                add = True
-                func = render_all
-            elif wf_type is WordFormSelection.Msd:
+            if wf_type is WordFormSelection.Msd:
                 add = check_msd(w, more)
                 func = render_form
+            elif wf_type is WordFormSelection.All:
+                add = True
+                func = render_all
             elif wf_type is WordFormSelection.Any:
                 add = True
                 func = render_form
             else:
                 assert(wf_type is WordFormSelection.Agreement)
                 other_w, agreements = more
-                found_agreements[other_w] = (w_id, w.lemma, agreements)
+                found_agreements[other_w] = (w_id, w, agreements)
                 add = True
                 func = lambda *x: None
 
+            representations[w_id][1] = func
             if add:
                 representations[w_id][0].append(w)
-                representations[w_id][1] = func
-
-            if doprint:
-                print(len(matches), len(found_agreements))
-
-            # for w1i, w2i, agreements in found_agreements:
-            #     w1, w2 = words[w1i], words[w2i]
-            #     if doprint:
-            #         print("? ", w1.msd, w2.msd, end="")
-            #     if w2i not in bad_words:
-            #
-            #         if check_agreement(w1, w2, agreements):
-            #             representations[w1i][0].append(w1.text)
-            #             if doprint:
-            #                 print(" :)")
-            #         elif doprint:
-            #             print(" :(")
-            #     elif doprint:
-            #         print(" :((")
 
         # just need to set representation to first group,
         # but in correct order, agreements last!
@@ -379,19 +367,11 @@ class ComponentRendition:
         for w_id, w in representation_sorted_words:
             data = representations[w_id]
-            if doprint:
-                print([(r.text, r.lemma, r.msd) for r in data[0]])
 
             if type(data[1]) is str:
                 matches.representations[w_id] = None if data[0] else data[1]
-            elif len(data[0]) == 0:
-                matches.representations[w_id] = None
             else:
-                data[1](str(w_id), data[0])
-
-            if doprint:
-                print(matches.representations)
-                print('--')
+                backup_word = lemma_only_word(lemma_msds[w.msd[0]]) if w.msd[0] in lemma_msds else None
+                data[1](str(w_id), data[0], backup_word)
 
     def __str__(self):
         return str(self.rendition)
@@ -468,7 +448,7 @@ def build_morphology_regex(restriction):
                 return False
         return True
 
-    return " ".join(rgx), matcher
+    return rgx, matcher
 
 
 def build_lexis_regex(restriction):
@@ -493,7 +473,8 @@ class Restriction:
         restriction_type = restriction_tag.get('type')
         if restriction_type == "morphology":
             self.type = RestrictionType.Morphology
-            self.present, self.matcher = build_morphology_regex(list(restriction_tag))
+            present, self.matcher = build_morphology_regex(list(restriction_tag))
+            self.present = " ".join(present)
         elif restriction_type == "lexis":
             self.type = RestrictionType.Lexis
             self.present, self.matcher = build_lexis_regex(list(restriction_tag))
@@ -822,10 +803,14 @@ class SyntacticStructure:
         # to_ret.append((m, self.check_agreements(m)))
 
-def build_structures(filename):
-    structures = []
+def load_structures(filename):
     with open(filename, 'r') as fp:
         et = ElementTree.XML(fp.read())
+    return build_structures(et), get_lemma_features(et)
+
+
+def build_structures(et):
+    structures = []
     for structure in et.iter('syntactic_structure'):
         to_append = SyntacticStructure.from_xml(structure)
         if to_append is None:
@@ -833,6 +818,29 @@ def build_structures(et):
         structures.append(to_append)
     return structures
 
+
+def get_lemma_features(et):
+    lf = et.find('lemma_features')
+    if lf is None:
+        return {}
+
+    result = {}
+    for pos in lf.iter('POS'):
+        rgx_list, _ = build_morphology_regex(pos)
+        rgx_str = ""
+        for position in rgx_list:
+            if position == ".":
+                rgx_str += " "
+            elif len(position) == 1:
+                rgx_str += position
+            elif len(position) == 3 and position[0] == "[" and position[2] == "]":
+                rgx_str += position[1]
+            else:
+                raise RuntimeError("Strange rgx for lemma_feature...")
+
+        assert(rgx_str[0].isupper())
+        result[rgx_str[0]] = rgx_str.strip().replace(' ', '-')
+
+    return result
 
 def get_msd(comp):
     d = dict(comp.items())
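To illustrate what get_lemma_features produces: its inner loop collapses build_morphology_regex's per-position atoms into a compact MSD key, with "." turning into "-" placeholders. A standalone rendition of the same collapsing logic (the sample atom list is invented):

    def collapse_rgx(rgx_list):
        # same branching as the loop in get_lemma_features above
        rgx_str = ""
        for position in rgx_list:
            if position == ".":
                rgx_str += " "          # unconstrained position
            elif len(position) == 1:
                rgx_str += position     # fixed attribute value
            elif len(position) == 3 and position[0] == "[" and position[2] == "]":
                rgx_str += position[1]  # one-element character class
            else:
                raise RuntimeError("Strange rgx for lemma_feature...")
        return rgx_str.strip().replace(' ', '-')

    print(collapse_rgx(["G", ".", ".", "n"]))  # -> "G--n", stored under key "G"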
@@ -844,6 +852,10 @@ def get_msd(comp):
         logging.error(d, file=sys.stderr)
         raise NotImplementedError("MSD?")
 
+
+def lemma_only_word(msd):
+    WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma')
+    return WordLemma(msd=msd, most_frequent_text=lambda *x: None, lemma=None)
 
 class Word:
     def __init__(self, xml, do_msd_translate):
         self.lemma = xml.get('lemma')
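lemma_only_word relies on duck typing: render_form only ever touches .msd, .lemma and most_frequent_text(...) on the words it receives, so a three-field namedtuple can sit in the same list as real Word objects. A quick demonstration of that contract (the sample MSD is invented):

    from collections import namedtuple

    WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma')
    backup = WordLemma(msd='Ncm', most_frequent_text=lambda *x: None, lemma=None)

    print(backup.msd)                   # 'Ncm', a POS-level msd from lemma_msds
    print(backup.lemma)                 # None -> no concrete lemma to render
    print(backup.most_frequent_text())  # None, the callable is a stub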
@@ -882,6 +894,7 @@ class WordMsdRenderer:
     def __init__(self):
         self.all_words = []
         self.rendered_words = {}
+        self.frequent_words = {}
 
     def add_words(self, words):
         self.all_words.extend(words)
@@ -893,19 +906,30 @@ class WordMsdRenderer:
         for lemma, ld in data.items():
             self.rendered_words[lemma] = {}
+            freq_words = defaultdict(int)
 
             for msd, texts in ld.items():
                 rep = max(set(texts), key=texts.count)
-                self.rendered_words[lemma][msd] = rep
+                self.rendered_words[lemma][msd] = (rep, len(texts))
+
+                for txt in texts:
+                    freq_words[(msd, txt)] += 1
+
+            self.frequent_words[lemma] = []
+            for (msd, txt), n in sorted(freq_words.items(), key=lambda x: -x[1]):
+                self.frequent_words[lemma].append((msd, txt, n))
 
     def render(self, lemma, msd):
         if lemma in self.rendered_words:
             if msd in self.rendered_words[lemma]:
-                return self.rendered_words[lemma][msd]
+                return self.rendered_words[lemma][msd][0]
 
     def available_words(self, lemma):
-        if lemma in self.rendered_words:
-            for msd in self.rendered_words[lemma].keys():
-                yield (msd, self.rendered_words[lemma][msd])
+        if lemma in self.frequent_words:
+            for msd, text, _ in self.frequent_words[lemma]:
+                yield (msd, text)
 
 def is_root_id(id_):
     return len(id_.split('.')) == 3
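The new frequent_words index is a plain frequency ranking: every (msd, text) pair observed for a lemma is counted, then sorted most-frequent-first, which is the order available_words now yields. The same pattern in isolation, with invented counts:

    from collections import defaultdict

    texts_by_msd = {"Ncmsn": ["pes", "pes", "pes"], "Ncmsg": ["psa"]}  # invented
    freq_words = defaultdict(int)
    for msd, texts in texts_by_msd.items():
        for txt in texts:
            freq_words[(msd, txt)] += 1

    ranked = sorted(freq_words.items(), key=lambda x: -x[1])
    print([(msd, txt, n) for (msd, txt), n in ranked])
    # [('Ncmsn', 'pes', 3), ('Ncmsg', 'psa', 1)]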
@@ -1010,7 +1034,6 @@ class Writer:
         elif self.all:
             return [word.id, word.text, word.lemma, word.msd]
         else:
-            # print("1", word)
             if representation is None:
                 return [word.lemma, word.lemma, "lemma_fallback"]
             else:
@@ -1137,17 +1160,21 @@ class ColocationIds:
             if group:
                 break
 
-    def set_representations(self, structures, word_renderer):
+    def set_representations(self, structures, word_renderer, lemma_msds):
         components_dict = {structure.id: structure for structure in structures}
-        for _1, sm in self.data.items():
-            ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer)
+        for _1, sm in tqdm(self.data.items()):
+            ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer, lemma_msds)
 
 
 def match_file(words, structures):
     matches = {s.id: [] for s in structures}
 
-    for idx, s in enumerate(structures):
-        logging.info("{}/{}: {:7s}".format(idx, len(structures), s.id))
+    for idx, s in tqdm(list(enumerate(structures))):
+        # logging.info("{}/{}: {:7s}".format(idx, len(structures), s.id))
         for w in words:
             mhere = s.match(w)
             logging.debug("  GOT: {}".format(len(mhere)))
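tqdm here replaces the old logging.info progress lines: wrapping an iterable draws a live progress bar on stderr. enumerate(...) alone exposes no length, so match_file materialises it with list(...) first; without a length tqdm still runs but shows only a counter and rate, no bar or ETA. Minimal usage along those lines:

    from tqdm import tqdm
    import time

    for idx, item in tqdm(list(enumerate(range(50))), desc="matching"):
        time.sleep(0.01)  # stand-in for the real per-structure matching work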
@@ -1162,13 +1189,19 @@ def match_file(words, structures):
 
 def main(input_file, structures_file, args):
-    structures = build_structures(structures_file)
+    structures, lemma_msds = load_structures(structures_file)
     for s in structures:
         logging.debug(str(s))
 
     colocation_ids = ColocationIds()
     word_renderer = WordMsdRenderer()
+    # if True:
+    #     with open("match_word.p", "rb") as fp:
+    #         words, matches = pickle.load(fp)
+    #     colocation_ids.add_matches(matches)
+    #     word_renderer.add_words(words)
 
     if args.parallel:
         num_parallel = int(args.parallel)
@@ -1215,7 +1248,7 @@ def main(input_file, structures_file, args):
     # get word renders for lemma/msd
     word_renderer.generate_renders()
     # figure out representations!
-    colocation_ids.set_representations(structures, word_renderer)
+    colocation_ids.set_representations(structures, word_renderer, lemma_msds)
 
     if args.all:
         Writer.make_all_writer(args).write_out(structures, colocation_ids)