continued work on representation, almost there...

This commit is contained in:
Ozbolt Menegatti 2019-05-16 01:53:38 +02:00
parent 84a184c44d
commit d2f1e95a8f

83
wani.py
View File

@ -204,26 +204,56 @@ class ComponentRendition:
else:
return None
def render(self, words):
if self.rendition == Rendition.Lemma:
return words[0].lemma
elif self.rendition == Rendition.Lexis:
return self.more
elif self.rendition == Rendition.Unknown:
return None
elif self.rendition == Rendition.WordForm:
# check more!
return words[0].text
else:
raise RuntimeError("Unknown rendition: {}".format(self.rendition))
def isit(self, rendition):
return self.rendition is rendition
@staticmethod
def set_representations(matches, components):
def set_representations(matches, structure):
representations = {c.idx: [True, ""] for c in structure.components}
def render_all(lst):
return "/".join(set(lst))
def render_form(_lst):
return ":("
for words, agreement in matches:
for _, w in words.items():
w.representation = ":("
if not agreement:
continue
for w_id, w in words.items():
component = structure.get_component(w_id)
rep = component.representation
if rep.isit(Rendition.Lemma):
representations[w_id][0] = False
representations[w_id][1] = w.lemma
elif rep.isit(Rendition.Lexis):
representations[w_id][0] = False
representations[w_id][1] = rep.more
elif rep.isit(Rendition.Unknown):
representations[w_id][0] = False
representations[w_id][1] = ""
# it HAS to be word_form now
else:
# set correct type first
if type(representations[w_id][1]) is str:
representations[w_id] = (
[], render_all if rep.more is StructureSelection.All else render_form
)
representations[w_id][0].append(w.text)
# just need to set representation to first group...
for w_id, w in matches[0][0].items():
data = representations[w_id]
if type(data[1]) is str:
w.representation_failed = data[0]
w.representation = w.lemma if w.representation_failed else data[1]
else:
w.representation_failed = len(data[0]) > 0
w.representation = w.lemma if w.representation_failed else data[1](data[0])
def __str__(self):
return str(self.rendition)
@ -377,9 +407,6 @@ class Component:
self.iter_ctr = 0
def render_word(self, word):
return self.representation.render(word)
def add_next(self, next_component, link_label, order):
self.next_element.append((next_component, link_label, Order.new(order)))
@ -397,9 +424,8 @@ class Component:
raise RuntimeError("Unreachable")
def set_representation(self, representation):
if len(representation) > 0:
for feature in representation:
self.representation.add_feature(feature)
self.representation.add_feature(feature.attrib)
def find_next(self, deps, comps, restrs, reprs):
to_ret = []
@ -720,6 +746,9 @@ class Word:
self.text = xml.text
self.links = defaultdict(list)
self.representation = None
self.representation_failed = False
last_num = self.id.split('.')[-1]
if last_num[0] not in '0123456789':
last_num = last_num[1:]
@ -827,7 +856,7 @@ class Writer:
if self.all:
cols = ["Token_ID", "Word_form"] + cols + ["Msd"]
else:
cols.append("Representative_form")
cols.extend(["Representative_form", "RF_scenario"])
assert(len(cols) == self.length())
cols = ["C{}_{}".format(i + 1, thd) for i in range(MAX_NUM_COMPONENTS) for thd in cols]
@ -839,7 +868,7 @@ class Writer:
return cols
def length(self):
return 4 if self.all else 2
return 4 if self.all else 3
def from_word(self, word):
if word is None:
@ -848,7 +877,8 @@ class Writer:
return [word.id, word.text, word.lemma, word.msd]
else:
assert(word.representation is not None)
return [word.lemma, word.representation]
failed = "ok" if word.representation_failed else "lemma_fallback"
return [word.lemma, word.representation, failed]
def sorted_rows(self, rows):
if self.sort_by < 0 or len(rows) < 2:
@ -880,7 +910,7 @@ class Writer:
idx = str(idx + 1)
word = m[idx] if idx in m else None
to_write.extend(self.from_word(word))
representation += " " + to_write[-1]
representation += " " + to_write[-2]
# make them equal size
to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write)))
@ -1071,4 +1101,3 @@ if __name__ == '__main__':
start = time.time()
main(args.input, args.structures, args)
logging.info("TIME: {}".format(time.time() - start))