Intermediate UGLY CODE commit. Working more on representations

This commit is contained in:
Ozbolt Menegatti 2019-05-22 11:22:07 +02:00
parent dce55d04a3
commit d14efff709

203
wani.py
View File

@ -136,6 +136,7 @@ class WordFormSelection(Enum):
All = 0
Msd = 1
Agreement = 2
Any = 3
class Order(Enum):
FromTo = 0
@ -179,7 +180,6 @@ class ComponentRendition:
self.rendition = r
def _set_more(self, m):
assert(self.more is None and m is not None)
self.more = m
def add_feature(self, feature):
@ -188,6 +188,7 @@ class ComponentRendition:
self._set_rendition(Rendition.Lemma)
elif feature['rendition'] == "word_form":
self._set_rendition(Rendition.WordForm)
self._set_more((WordFormSelection.Any, None))
elif feature['rendition'] == "lexis":
self._set_rendition(Rendition.Lexis)
self._set_more(feature['string'])
@ -217,14 +218,37 @@ class ComponentRendition:
@staticmethod
def set_representations(matches, structure):
representations = {c.idx: [True, ""] for c in structure.components}
representations = {
c.idx: [[], None] if c.representation.isit(Rendition.WordForm) else [True, ""]
for c in structure.components
}
representations_to_check = []
word_component_id = {}
# doprint = structure.id == '1' and matches[0]['1'].text.startswith('evrop') and matches[0]['2'].text.startswith('prv')
doprint = False
def render_all(lst):
return "/".join(set(lst))
return "/".join([w.text for w in set(lst)])
def render_form(lst):
# find most frequent
return max(set(lst), key=lst.count)
sorted_lst = sorted(set(lst), key=lst.count)
for word in sorted_lst:
othw = are_agreements_ok(word, representations_to_check)
if othw is not None:
if doprint:
print("AOK", othw.text, othw)
matches.representations[word_component_id[othw.id]] = othw.text
matches.representations[word_component_id[word.id]] = word.text
return
def are_agreements_ok(word, words_to_try):
for w_id, other_word, agreements in words_to_try:
if check_agreement(word, other_word, agreements):
if doprint:
print("GOOD :)")
return other_word
def check_msd(word, selectors):
for key, value in selectors.items():
@ -239,6 +263,9 @@ class ComponentRendition:
return True
def check_agreement(w1, w2, agreements):
if doprint:
print("CHECK", w1.text, w1, w2.text, w2)
for agr_case in agreements:
t1 = w1.msd[0]
v1 = TAGSET[t1].index(agr_case)
@ -263,11 +290,12 @@ class ComponentRendition:
return True
for words in matches:
for words in matches.matches:
# first pass, check everything but agreements
for w_id, w in words.items():
component = structure.get_component(w_id)
rep = component.representation
word_component_id[w.id] = w_id
if rep.isit(Rendition.Lemma):
representations[w_id][0] = False
@ -281,45 +309,72 @@ class ComponentRendition:
# it HAS to be word_form now
else:
assert(rep.isit(Rendition.WordForm))
wf_type, more = rep.more
# set correct type first
if type(representations[w_id][1]) is str:
representations[w_id] = (
[], render_all if wf_type is WordFormSelection.All else render_form
)
if wf_type is WordFormSelection.All:
add = True
func = render_all
elif wf_type is WordFormSelection.Msd:
add = check_msd(w, more)
func = render_form
elif wf_type is WordFormSelection.Any:
add = True
func = render_form
else:
assert(wf_type is WordFormSelection.Agreement)
other_w, agreements = more
add = check_agreement(w, words[other_w], agreements)
representations_to_check.append((other_w, w, agreements))
add = True
func = lambda x: None
if add:
representations[w_id][0].append(w.text)
doprint = matches[0]['1'].text.startswith('evrop')
# just need to set representation to first group...
for w_id, w in matches[0].items():
data = representations[w_id]
if doprint:
print(data)
if type(data[1]) is str:
w.representation_failed = data[0]
w.representation = w.lemma if w.representation_failed else data[1]
else:
w.representation_failed = len(data[0]) == 0
w.representation = w.lemma if w.representation_failed else data[1](data[0])
if doprint:
print(w.representation_failed, w.representation)
representations[w_id][0].append(w)
representations[w_id][1] = func
if doprint:
print(len(matches), len(representations_to_check))
# for w1i, w2i, agreements in representations_to_check:
# w1, w2 = words[w1i], words[w2i]
# if doprint:
# print("? ", w1.msd, w2.msd, end="")
# if w2i not in bad_words:
#
# if check_agreement(w1, w2, agreements):
# representations[w1i][0].append(w1.text)
# if doprint:
# print(" :)")
# elif doprint:
# print(" :(")
# elif doprint:
# print(" :((")
# just need to set representation to first group,
# but in correct order, agreements last!
representation_sorted_words = []
for w_id, w in matches.matches[0].items():
rep = component.representation
if rep.isit(Rendition.WordForm) and rep.more[0] is WordFormSelection.Agreement:
representation_sorted_words.append((w_id, w))
else:
representation_sorted_words.insert(0, (w_id, w))
for w_id, w in representation_sorted_words:
data = representations[w_id]
if doprint:
print([(r.text, r.lemma, r.msd) for r in data[0]])
if type(data[1]) is str:
matches.representations[w_id] = None if data[0] else data[1]
elif len(data[0]) == 0:
matches.representations[w_id] = None
else:
data[1](data[0])
if doprint:
print(matches.representations)
print('--')
def __str__(self):
@ -781,9 +836,6 @@ class Word:
self.text = xml.text
self.links = defaultdict(list)
self.representation = None
self.representation_failed = False
last_num = self.id.split('.')[-1]
if last_num[0] not in '0123456789':
last_num = last_num[1:]
@ -807,6 +859,29 @@ class Word:
return self.links[link]
class WordMsdRenderer:
    """Pick the most frequent surface form for each (lemma, msd) pair.

    Words are collected with ``add_word``; ``generate_renders`` then builds
    the lookup table that ``render`` reads from.
    """

    def __init__(self):
        # Every word handed to add_word, in insertion order.
        self.all_words = []
        # lemma -> {msd -> chosen text}; filled by generate_renders.
        self.rendered_words = {}

    def add_word(self, word):
        """Remember ``word``; its ``lemma``, ``msd`` and ``text`` are read later."""
        self.all_words.append(word)

    def generate_renders(self):
        """Fill ``self.rendered_words`` from all words added so far.

        For each lemma and msd the text occurring most often is stored.
        Ties are resolved in favour of the text that was added first.
        """
        # The inner factory must be a callable (``list``), not a list
        # instance: ``defaultdict([])`` raises TypeError.
        data = defaultdict(lambda: defaultdict(list))
        for w in self.all_words:
            data[w.lemma][w.msd].append(w.text)

        for lemma, by_msd in data.items():
            self.rendered_words[lemma] = {
                # dict.fromkeys de-duplicates while keeping first-seen order,
                # so max() is deterministic when counts are equal.
                msd: max(dict.fromkeys(texts), key=texts.count)
                for msd, texts in by_msd.items()
            }

    def render(self, lemma, msd):
        """Return the stored text for ``lemma``/``msd``, or None if unknown."""
        return self.rendered_words.get(lemma, {}).get(msd)
def is_root_id(id_):
    """Return True when ``id_`` consists of exactly three dot-separated parts."""
    # Three parts means exactly two separators.
    return id_.count('.') == 2
@ -905,15 +980,17 @@ class Writer:
def length(self):
    """Return the number of columns written per word.

    Four when ``self.all`` is truthy, otherwise three.
    """
    if self.all:
        return 4
    return 3
def from_word(self, word):
def from_word(self, word, representation):
if word is None:
return [""] * self.length()
elif self.all:
return [word.id, word.text, word.lemma, word.msd]
else:
assert(word.representation is not None)
failed = "lemma_fallback" if word.representation_failed else "ok"
return [word.lemma, word.representation, failed]
print("1", word)
if representation is None:
return [word.lemma, word.lemma, "lemma_fallback"]
else:
return [word.lemma, representation, "ok"]
def sorted_rows(self, rows):
if self.sort_by < 0 or len(rows) < 2:
@ -937,14 +1014,16 @@ class Writer:
def write_out_worker(self, file_handler, structure_id, components, colocation_ids):
rows = []
for cid, m, freq in colocation_ids.get_matches_for(structure_id, not self.all):
for cid, m, freq, rprsnt in colocation_ids.get_matches_for(structure_id, not self.all):
to_write = []
representation = ""
for idx, _comp in enumerate(components):
idx = str(idx + 1)
word = m[idx] if idx in m else None
to_write.extend(self.from_word(word))
print(rprsnt)
rep = rprsnt[idx] if idx in rprsnt else None
to_write.extend(self.from_word(word, rep))
representation += " " + to_write[-2]
# make them equal size
@ -993,6 +1072,19 @@ class Writer:
if not self.multiple_output:
fp_close(fp)
class StructureMatch:
    """Container for the matches collected under one id for one structure.

    The constructor stores ``match_id`` and ``structure_id`` as given and
    starts with an empty ``matches`` list and an empty ``representations``
    dict.
    """

    def __init__(self, match_id, structure_id):
        self.match_id, self.structure_id = match_id, structure_id
        self.matches = list()
        self.representations = dict()

    def append(self, match):
        """Add ``match`` to the end of ``self.matches``."""
        self.matches.append(match)

    def __len__(self):
        """Return how many matches have been appended."""
        return len(self.matches)
class ColocationIds:
def __init__(self):
@ -1000,41 +1092,32 @@ class ColocationIds:
self.min_frequency = args.min_freq
def _add_match(self, key, sid, match):
if key in self.data:
self.data[key][1].append(match)
else:
self.data[key] = (str(len(self.data) + 1), [match], sid)
if key not in self.data:
self.data[key] = StructureMatch(str(len(self.data) + 1), sid)
self.data[key].append(match)
def get(self, key, n):
return self.data[key][n]
def num(self, key):
return str(len(self.get(key, 1)))
def to_id(self, key):
return self.get(key, 0)
def add_matches(self, matches):
    """Register every entry of ``matches`` through ``self._add_match``.

    ``matches`` maps a structure id to an iterable of entries. For each
    entry, element 1 is passed as the key and element 0 as the match.
    """
    for structure_id, entries in matches.items():
        for entry in entries:
            match, key = entry[0], entry[1]
            self._add_match(key, structure_id, match)
def get_matches_for(self, structure_id, group):
for _cid_tup, (cid, cid_matches, sid) in self.data.items():
if sid != structure_id:
for _cid_tup, sm in self.data.items():
if sm.structure_id != structure_id:
continue
for words in cid_matches:
yield (cid, words, len(cid_matches))
for words in sm.matches:
yield (sm.match_id, words, len(sm), sm.representations)
if group:
break
def set_representations(self, structures):
components_dict = {structure.id: structure for structure in structures}
for _1, (_2, cid_matches, sid) in self.data.items():
if _2 == '1309':
a = 1
ComponentRendition.set_representations(cid_matches, components_dict[sid])
for _1, sm in self.data.items():
ComponentRendition.set_representations(sm, components_dict[sm.structure_id])
def match_file(words, structures):