diff --git a/wani.py b/wani.py
index 60d8547..0a12803 100644
--- a/wani.py
+++ b/wani.py
@@ -1,7 +1,7 @@
 from xml.etree import ElementTree
 import re
 from enum import Enum
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 import sys
 import logging
 import argparse
@@ -12,6 +12,7 @@ import concurrent.futures
 import tempfile
 
 from msd_translate import MSD_TRANSLATE
+from tqdm import tqdm
 
 
 MAX_NUM_COMPONENTS = 5
@@ -217,7 +218,7 @@ class ComponentRendition:
         return self.rendition is rendition
 
     @staticmethod
-    def set_representations(matches, structure, word_renderer):
+    def set_representations(matches, structure, word_renderer, lemma_msds):
         representations = {
             c.idx: [[], None] if c.representation.isit(Rendition.WordForm) else [True, ""]
             for c in structure.components
@@ -225,24 +226,31 @@ class ComponentRendition:
         found_agreements = {}
         word_component_id = {}
 
-        # doprint = structure.id == '1' and matches[0]['1'].text.startswith('evrop') and matches[0]['2'].text.startswith('prv')
-        doprint = False
-
-        def render_all(component_id, lst):
-            matches.representations[component_id] = "/".join([w.text for w in set(lst)])
+        def render_all(component_id, lst, _bw):
+            rep = "/".join([w.text for w in set(lst)]) if len(lst) > 0 else None
+            matches.representations[component_id] = rep
 
-        def render_form(component_id, lst):
-            sorted_lst = sorted(set(lst), key=lst.count)
-            for word in sorted_lst:
+        def render_form(component_id, lst, backup_word):
+            if backup_word is not None:
+                lst.append(backup_word)
+
+            lst_ctr = []
+            for word in lst:
+                lst_ctr.append((word.msd, word.lemma))
+            sorted_lst = sorted(set(lst_ctr), key=lst_ctr.count)
+
+            for word_msd, word_lemma in sorted_lst:
                 if component_id in found_agreements:
                     other_component_id, other_word, agreements = found_agreements[component_id]
-                    print(word.lemma, other_word.lemma, component_id, other_component_id, word.msd, word.msd)
-                    agr = are_agreements_ok(word.msd, other_word.lemma, other_word.msd, agreements)
+                    agr = are_agreements_ok(word_msd, other_word.lemma, other_word.msd, agreements)
                     if agr is None:
                         continue
+
                     matches.representations[other_component_id] = agr
 
-                matches.representations[word_component_id[word.id]] = word.most_frequent_text(word_renderer)
+                if word_lemma is not None:
+                    matches.representations[component_id] = word_renderer.render(word_lemma, word_msd)
+
                 break
 
         # othw = are_agreements_ok(word, found_agreements)
@@ -255,10 +263,7 @@ class ComponentRendition:
                 if ow_msd[0] != w2_msd[0]:
                     continue
 
-                print(w1_msd, w2_msd)
                 if check_agreement(w1_msd, w2_msd, agreements):
-                    if doprint:
-                        print("GOOD :)")
                     return w2_txt
 
         def check_msd(word, selectors):
@@ -279,7 +284,6 @@ class ComponentRendition:
                 # if not in msd, some strange msd was tries, skipping...
                if agr_case not in TAGSET[t1]:
                    logging.warning("Cannot do agreement: {} for msd {} not found!".format(agr_case, msd1))
-                    print("BAAAD")
                    return False

                v1 = TAGSET[t1].index(agr_case)
@@ -293,7 +297,6 @@ class ComponentRendition:
                t2 = msd2[0]
                if agr_case not in TAGSET[t2]:
                    logging.warning("Cannot do agreement: {} for msd {} not found!".format(agr_case, msd2))
-                    print("BAAAD")
                    return False
                v2 = TAGSET[t2].index(agr_case)
                if v2 + 1 >= len(msd2):
@@ -328,44 +331,26 @@ class ComponentRendition:
            assert(rep.isit(Rendition.WordForm))
            wf_type, more = rep.more

-            if wf_type is WordFormSelection.All:
-                add = True
-                func = render_all
-            elif wf_type is WordFormSelection.Msd:
+            if wf_type is WordFormSelection.Msd:
                add = check_msd(w, more)
                func = render_form
+            elif wf_type is WordFormSelection.All:
+                add = True
+                func = render_all
            elif wf_type is WordFormSelection.Any:
                add = True
                func = render_form
            else:
                assert(wf_type is WordFormSelection.Agreement)
                other_w, agreements = more
-                found_agreements[other_w] = (w_id, w.lemma, agreements)
+                found_agreements[other_w] = (w_id, w, agreements)
+                add = True
                func = lambda *x: None

+            representations[w_id][1] = func
            if add:
                representations[w_id][0].append(w)
-                representations[w_id][1] = func
-
-        if doprint:
-            print(len(matches), len(found_agreements))
-
-        # for w1i, w2i, agreements in found_agreements:
-        #     w1, w2 = words[w1i], words[w2i]
-        #     if doprint:
-        #         print("? ", w1.msd, w2.msd, end="")
-
-        #     if w2i not in bad_words:
-        #
-        #     if check_agreement(w1, w2, agreements):
-        #         representations[w1i][0].append(w1.text)
-        #         if doprint:
-        #             print(" :)")
-        #     elif doprint:
-        #         print(" :(")
-        #     elif doprint:
-        #         print(" :((")


        # just need to set representation to first group,
        # but in correct order, agreements last!
@@ -379,20 +364,12 @@ class ComponentRendition:
        for w_id, w in representation_sorted_words:
            data = representations[w_id]

-            if doprint:
-                print([(r.text, r.lemma, r.msd) for r in data[0]])
-
            if type(data[1]) is str:
                matches.representations[w_id] = None if data[0] else data[1]
-            elif len(data[0]) == 0:
-                matches.representations[w_id] = None
            else:
-                data[1](str(w_id), data[0])
+                backup_word = lemma_only_word(lemma_msds[w.msd[0]]) if w.msd[0] in lemma_msds else None
+                data[1](str(w_id), data[0], backup_word)

-            if doprint:
-                print(matches.representations)
-                print('--')
-
    def __str__(self):
        return str(self.rendition)

@@ -468,7 +445,7 @@ def build_morphology_regex(restriction):
            return False
        return True

-    return " ".join(rgx), matcher
+    return rgx, matcher


def build_lexis_regex(restriction):
@@ -493,7 +470,8 @@ class Restriction:
        restriction_type = restriction_tag.get('type')
        if restriction_type == "morphology":
            self.type = RestrictionType.Morphology
-            self.present, self.matcher = build_morphology_regex(list(restriction_tag))
+            present, self.matcher = build_morphology_regex(list(restriction_tag))
+            self.present = " ".join(present)
        elif restriction_type == "lexis":
            self.type = RestrictionType.Lexis
            self.present, self.matcher = build_lexis_regex(list(restriction_tag))
@@ -822,17 +800,44 @@ class SyntacticStructure:
        #     to_ret.append((m, self.check_agreements(m)))


-def build_structures(filename):
-    structures = []
+def load_structures(filename):
    with open(filename, 'r') as fp:
        et = ElementTree.XML(fp.read())
-        for structure in et.iter('syntactic_structure'):
-            to_append = SyntacticStructure.from_xml(structure)
-            if to_append is None:
-                continue
-            structures.append(to_append)
+
+    return build_structures(et), get_lemma_features(et)
+
+def build_structures(et):
+    structures = []
+    for structure in et.iter('syntactic_structure'):
+        to_append = SyntacticStructure.from_xml(structure)
+        if to_append is None:
+            continue
+        structures.append(to_append)
    return structures

+def get_lemma_features(et):
+    lf = et.find('lemma_features')
+    if lf is None:
+        return {}
+
+    result = {}
+    for pos in lf.iter('POS'):
+        rgx_list, _ = build_morphology_regex(pos)
+        rgx_str = ""
+        for position in rgx_list:
+            if position == ".":
+                rgx_str += " "
+            elif len(position) == 1:
+                rgx_str += position
+            elif len(position) == 3 and position[0] == "[" and position[2] == "]":
+                rgx_str += position[1]
+            else:
+                raise RuntimeError("Strange rgx for lemma_feature...")
+
+        assert(rgx_str[0].isupper())
+        result[rgx_str[0]] = rgx_str.strip().replace(' ', '-')
+
+    return result

def get_msd(comp):
    d = dict(comp.items())
@@ -844,6 +849,10 @@ def get_msd(comp):
        logging.error(d, file=sys.stderr)
        raise NotImplementedError("MSD?")

+def lemma_only_word(msd):
+    WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma')
+    return WordLemma(msd=msd, most_frequent_text=lambda *x: None, lemma=None)
+
class Word:
    def __init__(self, xml, do_msd_translate):
        self.lemma = xml.get('lemma')
@@ -882,6 +891,7 @@ class WordMsdRenderer:
    def __init__(self):
        self.all_words = []
        self.rendered_words = {}
+        self.frequent_words = {}

    def add_words(self, words):
        self.all_words.extend(words)
@@ -893,19 +903,28 @@ class WordMsdRenderer:
        for lemma, ld in data.items():
            self.rendered_words[lemma] = {}
+            freq_words = defaultdict(int)

            for msd, texts in ld.items():
                rep = max(set(texts), key=texts.count)
-                self.rendered_words[lemma][msd] = rep
-
+                self.rendered_words[lemma][msd] = (rep, len(texts))
+
+                for txt in texts:
+                    freq_words[(msd, txt)] += 1
+
+            self.frequent_words[lemma] = []
+            for (msd, txt), n in sorted(freq_words.items(), key=lambda x: -x[1]):
+                self.frequent_words[lemma].append((msd, txt, n))
+
    def render(self, lemma, msd):
        if lemma in self.rendered_words:
            if msd in self.rendered_words[lemma]:
-                return self.rendered_words[lemma][msd]
+                return self.rendered_words[lemma][msd][0]

    def available_words(self, lemma):
-        if lemma in self.rendered_words:
-            for msd in self.rendered_words[lemma].keys():
-                yield (msd, self.rendered_words[lemma][msd])
+        if lemma in self.frequent_words:
+            for msd, text, _ in self.frequent_words[lemma]:
+                yield (msd, text)
+

def is_root_id(id_):
    return len(id_.split('.')) == 3
@@ -1010,7 +1029,6 @@ class Writer:
        elif self.all:
            return [word.id, word.text, word.lemma, word.msd]
        else:
-            # print("1", word)
            if representation is None:
                return [word.lemma, word.lemma, "lemma_fallback"]
            else:
@@ -1137,17 +1155,16 @@ class ColocationIds:
            if group:
                break

-    def set_representations(self, structures, word_renderer):
+    def set_representations(self, structures, word_renderer, lemma_msds):
        components_dict = {structure.id: structure for structure in structures}
-        for _1, sm in self.data.items():
-            ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer)
+        for _1, sm in tqdm(self.data.items()):
+            ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer, lemma_msds)


def match_file(words, structures):
    matches = {s.id: [] for s in structures}

-    for idx, s in enumerate(structures):
-        logging.info("{}/{}: {:7s}".format(idx, len(structures), s.id))
+    for s in tqdm(structures):
        for w in words:
            mhere = s.match(w)
            logging.debug(" GOT: {}".format(len(mhere)))
@@ -1162,13 +1179,13 @@ def match_file(words, structures):


def main(input_file, structures_file, args):
-    structures = build_structures(structures_file)
+    structures, lemma_msds = load_structures(structures_file)
    for s in structures:
        logging.debug(str(s))

    colocation_ids = ColocationIds()
    word_renderer = WordMsdRenderer()

    if args.parallel:
        num_parallel = int(args.parallel)

@@ -1215,7 +1232,7 @@ def main(input_file, structures_file, args):
    # get word renders for lemma/msd
    word_renderer.generate_renders()
    # figure out representations!
-    colocation_ids.set_representations(structures, word_renderer)
+    colocation_ids.set_representations(structures, word_renderer, lemma_msds)

    if args.all:
        Writer.make_all_writer(args).write_out(structures, colocation_ids)
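
Notes (illustrative sketches, not part of the patch):

get_lemma_features collapses the per-position patterns returned by build_morphology_regex into a canonical lemma-level MSD string: "." marks an unspecified position, a bare character is kept as-is, and a one-character class like "[x]" is reduced to "x"; remaining blanks then become "-". A minimal sketch of that collapsing step, using a made-up rgx_list (the real one comes from build_morphology_regex):

    # Hypothetical per-position patterns for a noun MSD.
    rgx_list = ["N", ".", "[m]", "."]

    rgx_str = ""
    for position in rgx_list:
        if position == ".":
            rgx_str += " "               # unspecified position
        elif len(position) == 1:
            rgx_str += position          # fixed character
        elif len(position) == 3 and position[0] == "[" and position[2] == "]":
            rgx_str += position[1]       # single-character class, e.g. "[m]" -> "m"
        else:
            raise RuntimeError("Strange rgx for lemma_feature...")

    # Keyed by the POS letter; inner blanks become '-': {'N': 'N-m'}
    print({rgx_str[0]: rgx_str.strip().replace(' ', '-')})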
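generate_renders now keeps two views of the corpus forms per lemma: rendered_words maps each MSD to its most frequent text plus a count, while frequent_words ranks all (msd, text) pairs globally so that available_words can yield them most-frequent-first. A toy run with invented forms and counts, showing only the data shapes:

    from collections import defaultdict

    # Invented input: texts seen for one lemma, grouped by MSD.
    texts_by_msd = {"Ncmsn": ["pes", "pes", "pes"], "Ncmsg": ["psa"]}

    rendered = {}
    freq_words = defaultdict(int)
    for msd, texts in texts_by_msd.items():
        rendered[msd] = (max(set(texts), key=texts.count), len(texts))
        for txt in texts:
            freq_words[(msd, txt)] += 1

    # Descending by count, as frequent_words is built:
    ranked = [(msd, txt, n) for (msd, txt), n in sorted(freq_words.items(), key=lambda x: -x[1])]
    print(rendered)  # {'Ncmsn': ('pes', 3), 'Ncmsg': ('psa', 1)}
    print(ranked)    # [('Ncmsn', 'pes', 3), ('Ncmsg', 'psa', 1)]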
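lemma_only_word builds a namedtuple that duck-types a Word just enough for render_form: it carries the lemma-level MSD taken from lemma_msds but no lemma. When render_form falls through to this backup candidate, word_lemma is None, so no representation is set and the writer later emits the "lemma_fallback" row. A sketch of the stand-in (names as in the patch; the full Word class is elided):

    from collections import namedtuple

    WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma')

    def lemma_only_word(msd):
        # Provides the attributes render_form reads: .msd and .lemma.
        return WordLemma(msd=msd, most_frequent_text=lambda *x: None, lemma=None)

    backup = lemma_only_word("Ncm")   # hypothetical lemma-level MSD
    print(backup.msd, backup.lemma)   # Ncm None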