diff --git a/wani.py b/wani.py index a60b7bf..0b3f26a 100644 --- a/wani.py +++ b/wani.py @@ -203,28 +203,58 @@ class ComponentRendition: else: return None - - def render(self, words): - if self.rendition == Rendition.Lemma: - return words[0].lemma - elif self.rendition == Rendition.Lexis: - return self.more - elif self.rendition == Rendition.Unknown: - return None - - elif self.rendition == Rendition.WordForm: - # check more! - return words[0].text - - else: - raise RuntimeError("Unknown rendition: {}".format(self.rendition)) + + def isit(self, rendition): + return self.rendition is rendition @staticmethod - def set_representations(matches, components): + def set_representations(matches, structure): + representations = {c.idx: [True, ""] for c in structure.components} + + def render_all(lst): + return "/".join(set(lst)) + + def render_form(_lst): + return ":(" + for words, agreement in matches: - for _, w in words.items(): - w.representation = ":(" - + if not agreement: + continue + + for w_id, w in words.items(): + component = structure.get_component(w_id) + rep = component.representation + + if rep.isit(Rendition.Lemma): + representations[w_id][0] = False + representations[w_id][1] = w.lemma + elif rep.isit(Rendition.Lexis): + representations[w_id][0] = False + representations[w_id][1] = rep.more + elif rep.isit(Rendition.Unknown): + representations[w_id][0] = False + representations[w_id][1] = "" + + # it HAS to be word_form now + else: + # set correct type first + if type(representations[w_id][1]) is str: + representations[w_id] = ( + [], render_all if rep.more is StructureSelection.All else render_form + ) + representations[w_id][0].append(w.text) + + # just need to set representation to first group... 
+ for w_id, w in matches[0][0].items(): + data = representations[w_id] + + if type(data[1]) is str: + w.representation_failed = data[0] + w.representation = w.lemma if w.representation_failed else data[1] + else: + w.representation_failed = len(data[0]) == 0 + w.representation = w.lemma if w.representation_failed else data[1](data[0]) + def __str__(self): return str(self.rendition) @@ -377,9 +407,6 @@ class Component: self.iter_ctr = 0 - def render_word(self, word): - return self.representation.render(word) - def add_next(self, next_component, link_label, order): self.next_element.append((next_component, link_label, Order.new(order))) @@ -397,9 +424,8 @@ class Component: raise RuntimeError("Unreachable") def set_representation(self, representation): - if len(representation) > 0: - for feature in representation: - self.representation.add_feature(feature) + for feature in representation: + self.representation.add_feature(feature.attrib) def find_next(self, deps, comps, restrs, reprs): to_ret = [] @@ -720,6 +746,9 @@ class Word: self.text = xml.text self.links = defaultdict(list) + self.representation = None + self.representation_failed = False + last_num = self.id.split('.')[-1] if last_num[0] not in '0123456789': last_num = last_num[1:] @@ -827,7 +856,7 @@ class Writer: if self.all: cols = ["Token_ID", "Word_form"] + cols + ["Msd"] else: - cols.append("Representative_form") + cols.extend(["Representative_form", "RF_scenario"]) assert(len(cols) == self.length()) cols = ["C{}_{}".format(i + 1, thd) for i in range(MAX_NUM_COMPONENTS) for thd in cols] @@ -839,7 +868,7 @@ class Writer: return cols def length(self): - return 4 if self.all else 2 + return 4 if self.all else 3 def from_word(self, word): if word is None: @@ -848,7 +877,8 @@ class Writer: return [word.id, word.text, word.lemma, word.msd] else: assert(word.representation is not None) - return [word.lemma, word.representation] + failed = "lemma_fallback" if word.representation_failed else "ok" + return 
[word.lemma, word.representation, failed] def sorted_rows(self, rows): if self.sort_by < 0 or len(rows) < 2: @@ -880,7 +910,7 @@ class Writer: idx = str(idx + 1) word = m[idx] if idx in m else None to_write.extend(self.from_word(word)) - representation += " " + to_write[-1] + representation += " " + to_write[-2] # make them equal size to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write))) @@ -1071,4 +1101,3 @@ if __name__ == '__main__': start = time.time() main(args.input, args.structures, args) logging.info("TIME: {}".format(time.time() - start)) -