diff --git a/wani.py b/wani.py index 0a12803..319ebc0 100644 --- a/wani.py +++ b/wani.py @@ -1,7 +1,7 @@ from xml.etree import ElementTree import re from enum import Enum -from collections import defaultdict, namedtuple +from collections import defaultdict, namedtuple, Counter import sys import logging import argparse @@ -218,7 +218,7 @@ class ComponentRendition: return self.rendition is rendition @staticmethod - def set_representations(matches, structure, word_renderer, lemma_msds): + def set_representations(matches, structure, word_renderer): representations = { c.idx: [[], None] if c.representation.isit(Rendition.WordForm) else [True, ""] for c in structure.components @@ -227,42 +227,39 @@ class ComponentRendition: word_component_id = {} def render_all(component_id, lst, _bw): - rep = "/".join([w.text for w in set(lst)]) if len(lst) > 0 else None + rep = "/".join(set([w.text for w in set(lst)])) if len(lst) > 0 else None matches.representations[component_id] = rep def render_form(component_id, lst, backup_word): if backup_word is not None: lst.append(backup_word) + text_forms = {} + msd_lemma_txt_triplets = Counter([(w.msd, w.lemma, w.text) for w in lst]) + for (msd, lemma, text), _n in reversed(msd_lemma_txt_triplets.most_common()): + text_forms[(msd, lemma)] = text + lst_ctr = [] for word in lst: lst_ctr.append((word.msd, word.lemma)) sorted_lst = sorted(set(lst_ctr), key=lst.count) - if len(lst) > 3: - a = 3 - for word_msd, word_lemma in sorted_lst: if component_id in found_agreements: - other_component_id, other_word, agreements = found_agreements[component_id] - agr = are_agreements_ok(word_msd, other_word.lemma, other_word.msd, agreements) + other_component_id, other_word, agreements, other_texts = found_agreements[component_id] + agr = are_agreements_ok(word_msd, other_word.lemma, other_word.msd, agreements, other_texts) if agr is None: continue matches.representations[other_component_id] = agr if word_lemma is not None: - matches.representations[component_id] = word_renderer.render(word_lemma, word_msd) + matches.representations[component_id] = text_forms[(msd, lemma)] #word_renderer.render(word_lemma, word_msd) break - - # othw = are_agreements_ok(word, found_agreements) - # if othw is not None: - # matches.representations[word_component_id[othw.id]] = othw.most_frequent_text(word_renderer) - # return - def are_agreements_ok(w1_msd, ow_lemma, ow_msd, agreements): - for w2_msd, w2_txt in word_renderer.available_words(ow_lemma): + def are_agreements_ok(w1_msd, ow_lemma, ow_msd, agreements, ow_texts): + for w2_msd, w2_txt in word_renderer.available_words(ow_lemma, ow_texts): if ow_msd[0] != w2_msd[0]: continue @@ -333,22 +330,22 @@ class ComponentRendition: else: assert(rep.isit(Rendition.WordForm)) wf_type, more = rep.more + add = True if wf_type is WordFormSelection.Msd: add = check_msd(w, more) func = render_form elif wf_type is WordFormSelection.All: - add = True func = render_all elif wf_type is WordFormSelection.Any: - add = True func = render_form else: assert(wf_type is WordFormSelection.Agreement) other_w, agreements = more - found_agreements[other_w] = (w_id, w, agreements) + if other_w not in found_agreements: + found_agreements[other_w] = (w_id, w, agreements, []) - add = True + found_agreements[other_w][-1].append((w.msd, w.text)) func = lambda *x: None representations[w_id][1] = func @@ -370,7 +367,8 @@ class ComponentRendition: if type(data[1]) is str: matches.representations[w_id] = None if data[0] else data[1] else: - backup_word = lemma_only_word(lemma_msds[w.msd[0]]) if w.msd[0] in lemma_msds else None + backup_msd = word_renderer.get_lemma_msd(w.lemma) + backup_word = lemma_only_word(backup_msd) data[1](str(w_id), data[0], backup_word) def __str__(self): @@ -853,8 +851,11 @@ def get_msd(comp): raise NotImplementedError("MSD?") def lemma_only_word(msd): - WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma') - return WordLemma(msd=msd, most_frequent_text=lambda *x: None, lemma=None) + if msd is None: + return None + else: + WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma text') + return WordLemma(msd=msd, most_frequent_text=lambda *x: None, lemma=None, text=None) class Word: def __init__(self, xml, do_msd_translate): @@ -895,11 +896,12 @@ class WordMsdRenderer: self.all_words = [] self.rendered_words = {} self.frequent_words = {} + self.lemma_msd = {} def add_words(self, words): self.all_words.extend(words) - def generate_renders(self): + def generate_renders(self, lemma_features): data = defaultdict(lambda: defaultdict(list)) for w in self.all_words: data[w.lemma][w.msd].append(w.text) @@ -907,6 +909,7 @@ class WordMsdRenderer: for lemma, ld in data.items(): self.rendered_words[lemma] = {} freq_words = defaultdict(int) + common_msd = "*" * 10 for msd, texts in ld.items(): rep = max(set(texts), key=texts.count) @@ -914,22 +917,54 @@ class WordMsdRenderer: for txt in texts: freq_words[(msd, txt)] += 1 + + common_msd = self.merge_msd(common_msd, msd) + + self.lemma_msd[lemma] = common_msd self.frequent_words[lemma] = [] for (msd, txt), n in sorted(freq_words.items(), key=lambda x: -x[1]): self.frequent_words[lemma].append((msd, txt, n)) + for lemma in self.lemma_msd.keys(): + cmsd = self.lemma_msd[lemma] + if cmsd[0] in lemma_features: + self.lemma_msd[lemma] = "".join( + l1 if l1 != "-" else l2 for l1, l2 in zip(lemma_features[cmsd[0]], cmsd) + ) + + @staticmethod + def merge_msd(common_msd, new_msd): + def merge_letter(l1, l2): + if l1 == "*": + return l2 + elif l1 != l2: + return "-" + else: + return l1 + + return "".join(merge_letter(l1, l2) for l1, l2 in zip(common_msd, new_msd)) + def render(self, lemma, msd): if lemma in self.rendered_words: if msd in self.rendered_words[lemma]: return self.rendered_words[lemma][msd][0] - def available_words(self, lemma): + def available_words(self, lemma, existing_texts): + counted_texts = Counter(existing_texts) + for (msd, text), n in counted_texts.most_common(): + yield (msd, text) + if lemma in self.frequent_words: - # print("--") for msd, text, _ in self.frequent_words[lemma]: - # print(lemma, msd, text, _) - yield (msd, text) + if (msd, text) not in counted_texts: + yield (msd, text) + + def get_lemma_msd(self, lemma): + if lemma in self.lemma_msd and self.lemma_msd[lemma][0] != '-': + return self.lemma_msd[lemma] + else: + return None def is_root_id(id_): return len(id_.split('.')) == 3 @@ -1160,13 +1195,11 @@ class ColocationIds: if group: break - def set_representations(self, structures, word_renderer, lemma_msds): + def set_representations(self, structures, word_renderer): components_dict = {structure.id: structure for structure in structures} idx = 1 for _1, sm in tqdm(self.data.items()): - if idx == 120: - a = 3 - ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer, lemma_msds) + ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer) idx += 1 @@ -1246,9 +1279,9 @@ def main(input_file, structures_file, args): word_renderer.add_words(words) # get word renders for lemma/msd - word_renderer.generate_renders() + word_renderer.generate_renders(lemma_msds) # figure out representations! - colocation_ids.set_representations(structures, word_renderer, lemma_msds) + colocation_ids.set_representations(structures, word_renderer) if args.all: Writer.make_all_writer(args).write_out(structures, colocation_ids)