diff --git a/wani.py b/wani.py index ef49952..8120cd1 100644 --- a/wani.py +++ b/wani.py @@ -136,6 +136,7 @@ class WordFormSelection(Enum): All = 0 Msd = 1 Agreement = 2 + Any = 3 class Order(Enum): FromTo = 0 @@ -179,7 +180,6 @@ class ComponentRendition: self.rendition = r def _set_more(self, m): - assert(self.more is None and m is not None) self.more = m def add_feature(self, feature): @@ -188,6 +188,7 @@ class ComponentRendition: self._set_rendition(Rendition.Lemma) elif feature['rendition'] == "word_form": self._set_rendition(Rendition.WordForm) + self._set_more((WordFormSelection.Any, None)) elif feature['rendition'] == "lexis": self._set_rendition(Rendition.Lexis) self._set_more(feature['string']) @@ -217,15 +218,38 @@ class ComponentRendition: @staticmethod def set_representations(matches, structure): - representations = {c.idx: [True, ""] for c in structure.components} + representations = { + c.idx: [[], None] if c.representation.isit(Rendition.WordForm) else [True, ""] + for c in structure.components + } + representations_to_check = [] + word_component_id = {} + + # doprint = structure.id == '1' and matches[0]['1'].text.startswith('evrop') and matches[0]['2'].text.startswith('prv') + doprint = False def render_all(lst): - return "/".join(set(lst)) + return "/".join([w.text for w in set(lst)]) def render_form(lst): - # find most frequent - return max(set(lst), key=lst.count) + sorted_lst = sorted(set(lst), key=lst.count) + for word in sorted_lst: + othw = are_agreements_ok(word, representations_to_check) + if othw is not None: + if doprint: + print("AOK", othw.text, othw) + + matches.representations[word_component_id[othw.id]] = othw.text + matches.representations[word_component_id[word.id]] = word.text + return + def are_agreements_ok(word, words_to_try): + for w_id, other_word, agreements in words_to_try: + if check_agreement(word, other_word, agreements): + if doprint: + print("GOOD :)") + return other_word + def check_msd(word, selectors): for key, value in selectors.items(): t = word.msd[0] @@ -239,6 +263,9 @@ class ComponentRendition: return True def check_agreement(w1, w2, agreements): + if doprint: + print("CHECK", w1.text, w1, w2.text, w2) + for agr_case in agreements: t1 = w1.msd[0] v1 = TAGSET[t1].index(agr_case) @@ -263,11 +290,12 @@ class ComponentRendition: return True - - for words in matches: + for words in matches.matches: + # first pass, check everything but agreements for w_id, w in words.items(): component = structure.get_component(w_id) rep = component.representation + word_component_id[w.id] = w_id if rep.isit(Rendition.Lemma): representations[w_id][0] = False @@ -281,45 +309,72 @@ class ComponentRendition: # it HAS to be word_form now else: + assert(rep.isit(Rendition.WordForm)) wf_type, more = rep.more - # set correct type first - if type(representations[w_id][1]) is str: - representations[w_id] = ( - [], render_all if wf_type is WordFormSelection.All else render_form - ) - if wf_type is WordFormSelection.All: add = True + func = render_all elif wf_type is WordFormSelection.Msd: add = check_msd(w, more) + func = render_form + elif wf_type is WordFormSelection.Any: + add = True + func = render_form else: assert(wf_type is WordFormSelection.Agreement) other_w, agreements = more - add = check_agreement(w, words[other_w], agreements) + representations_to_check.append((other_w, w, agreements)) + add = True + func = lambda x: None if add: - representations[w_id][0].append(w.text) + representations[w_id][0].append(w) + representations[w_id][1] = func - doprint = matches[0]['1'].text.startswith('evrop') + if doprint: + print(len(matches), len(representations_to_check)) - # just need to set representation to first group... - for w_id, w in matches[0].items(): + # for w1i, w2i, agreements in representations_to_check: + # w1, w2 = words[w1i], words[w2i] + # if doprint: + # print("? ", w1.msd, w2.msd, end="") + + # if w2i not in bad_words: + # + # if check_agreement(w1, w2, agreements): + # representations[w1i][0].append(w1.text) + # if doprint: + # print(" :)") + # elif doprint: + # print(" :(") + # elif doprint: + # print(" :((") + + # just need to set representation to first group, + # but in correct order, agreements last! + representation_sorted_words = [] + for w_id, w in matches.matches[0].items(): + rep = component.representation + if rep.isit(Rendition.WordForm) and rep.more[0] is WordFormSelection.Agreement: + representation_sorted_words.append((w_id, w)) + else: + representation_sorted_words.insert(0, (w_id, w)) + + for w_id, w in representation_sorted_words: data = representations[w_id] if doprint: - print(data) + print([(r.text, r.lemma, r.msd) for r in data[0]]) if type(data[1]) is str: - w.representation_failed = data[0] - w.representation = w.lemma if w.representation_failed else data[1] + matches.representations[w_id] = None if data[0] else data[1] + elif len(data[0]) == 0: + matches.representations[w_id] = None else: - w.representation_failed = len(data[0]) == 0 - w.representation = w.lemma if w.representation_failed else data[1](data[0]) + data[1](data[0]) - if doprint: - print(w.representation_failed, w.representation) - if doprint: + print(matches.representations) print('--') def __str__(self): @@ -781,9 +836,6 @@ class Word: self.text = xml.text self.links = defaultdict(list) - self.representation = None - self.representation_failed = False - last_num = self.id.split('.')[-1] if last_num[0] not in '0123456789': last_num = last_num[1:] @@ -807,6 +859,29 @@ class Word: return self.links[link] +class WordMsdRenderer: + def __init__(self): + self.all_words = [] + self.rendered_words = {} + + def add_word(self, word): + self.all_words.append(word) + + def generate_renders(self): + data = defaultdict(lambda: defaultdict([])) + for w in self.all_words: + data[w.lemma][w.msd].append(w.text) + + for lemma, ld in data.items(): + self.rendered_words[lemma] = {} + for msd, texts in ld.items(): + rep = max(set(texts), key=texts.count) + self.rendered_words[lemma][msd] = rep + + def render(self, lemma, msd): + if lemma in self.rendered_words: + if msd in self.rendered_words[lemma]: + return self.rendered_words[lemma][msd] def is_root_id(id_): return len(id_.split('.')) == 3 @@ -905,15 +980,17 @@ class Writer: def length(self): return 4 if self.all else 3 - def from_word(self, word): + def from_word(self, word, representation): if word is None: return [""] * self.length() elif self.all: return [word.id, word.text, word.lemma, word.msd] else: - assert(word.representation is not None) - failed = "lemma_fallback" if word.representation_failed else "ok" - return [word.lemma, word.representation, failed] + print("1", word) + if representation is None: + return [word.lemma, word.lemma, "lemma_fallback"] + else: + return [word.lemma, representation, "ok"] def sorted_rows(self, rows): if self.sort_by < 0 or len(rows) < 2: @@ -937,14 +1014,16 @@ class Writer: def write_out_worker(self, file_handler, structure_id, components, colocation_ids): rows = [] - for cid, m, freq in colocation_ids.get_matches_for(structure_id, not self.all): + for cid, m, freq, rprsnt in colocation_ids.get_matches_for(structure_id, not self.all): to_write = [] representation = "" for idx, _comp in enumerate(components): idx = str(idx + 1) word = m[idx] if idx in m else None - to_write.extend(self.from_word(word)) + print(rprsnt) + rep = rprsnt[idx] if idx in rprsnt else None + to_write.extend(self.from_word(word, rep)) representation += " " + to_write[-2] # make them equal size @@ -993,6 +1072,19 @@ class Writer: if not self.multiple_output: fp_close(fp) +class StructureMatch: + def __init__(self, match_id, structure_id): + self.match_id = match_id + self.structure_id = structure_id + + self.matches = [] + self.representations = {} + + def append(self, match): + self.matches.append(match) + + def __len__(self): + return len(self.matches) class ColocationIds: def __init__(self): @@ -1000,41 +1092,32 @@ class ColocationIds: self.min_frequency = args.min_freq def _add_match(self, key, sid, match): - if key in self.data: - self.data[key][1].append(match) - else: - self.data[key] = (str(len(self.data) + 1), [match], sid) + if key not in self.data: + self.data[key] = StructureMatch(str(len(self.data) + 1), sid) + self.data[key].append(match) def get(self, key, n): return self.data[key][n] - def num(self, key): - return str(len(self.get(key, 1))) - - def to_id(self, key): - return self.get(key, 0) - def add_matches(self, matches): for sid, nms in matches.items(): for nm in nms: self._add_match(nm[1], sid, nm[0]) def get_matches_for(self, structure_id, group): - for _cid_tup, (cid, cid_matches, sid) in self.data.items(): - if sid != structure_id: + for _cid_tup, sm in self.data.items(): + if sm.structure_id != structure_id: continue - for words in cid_matches: - yield (cid, words, len(cid_matches)) + for words in sm.matches: + yield (sm.match_id, words, len(sm), sm.representations) if group: break def set_representations(self, structures): components_dict = {structure.id: structure for structure in structures} - for _1, (_2, cid_matches, sid) in self.data.items(): - if _2 == '1309': - a = 1 - ComponentRendition.set_representations(cid_matches, components_dict[sid]) + for _1, sm in self.data.items(): + ComponentRendition.set_representations(sm, components_dict[sm.structure_id]) def match_file(words, structures):