From c6440162b8d0bdd3befbecade42e5a801ce52946 Mon Sep 17 00:00:00 2001 From: Ozbolt Menegatti Date: Sun, 9 Jun 2019 22:25:58 +0200 Subject: [PATCH] NOT WORKING inbetween commit --- wani.py | 194 ++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 145 insertions(+), 49 deletions(-) diff --git a/wani.py b/wani.py index 43f50f6..7724183 100644 --- a/wani.py +++ b/wani.py @@ -891,6 +891,7 @@ class WordMsdRenderer: common_msd = "*" * 10 for msd, texts in ld.items(): + # TODO: this should be out of generate_renders... self.num_words[(lemma, msd[0])] += len(texts) rep = max(set(texts), key=texts.count) @@ -1022,15 +1023,16 @@ def load_tei_file(filename, skip_id_check, do_msd_translate, pc_tag, status): class Formatter: - def __init__(self, colocation_ids): + def __init__(self, colocation_ids, word_renderer): self.colocation_ids = colocation_ids + self.word_renderer = word_renderer self.additional_init() def header_repeat(self): raise NotImplementedError("Header repeat formatter not implemented") def header_right(self): raise NotImplementedError("Header right formatter not implemented") - def content_repeat(self, words, representations, idx): + def content_repeat(self, words, representations, idx, sidx): raise NotImplementedError("Content repeat formatter not implemented") def content_right(self, freq): raise NotImplementedError("Content right formatter not implemented") @@ -1042,6 +1044,11 @@ class Formatter: def length(self): return len(self.header_repeat()) + def set_structure(self, structure): + pass + def new_match(self, match): + pass + class OutFormatter(Formatter): def additional_init(self): @@ -1053,7 +1060,7 @@ class OutFormatter(Formatter): def header_right(self): return ["Joint_representative_form", "Frequency"] - def content_repeat(self, words, representations, idx): + def content_repeat(self, words, representations, idx, _sidx): word = words[idx] if idx not in representations: return [word.lemma, "", ""] @@ -1082,7 +1089,7 @@ class AllFormatter(Formatter): def header_right(self): return [] - def content_repeat(self, words, representations, idx): + def content_repeat(self, words, _representations, idx, _sidx): word = words[idx] return [word.id, word.text, word.lemma, word.msd] @@ -1092,6 +1099,67 @@ class AllFormatter(Formatter): def group(self): return False +class StatsFormatter(Formatter): + def additional_init(self): + self.stats = None + self.jppb = None + self.corew = None + + def set_structure(self, structure): + jppb = [] + corew = [] + + for component in structure.components: + if component.type == ComponentType.Core2w: + jppb.append(component.idx) + if component.type != ComponentType.Other: + corew.append(component.idx) + + assert(len(jppb) == 2) + self.jppb = tuple(jppb) + self.corew = tuple(corew) + + def new_match(self, match): + jppb_forms = set() + self.stats = {"freq": {}} + + for words in match.matches: + cw1 = words[self.jppb[0]] + cw2 = words[self.jppb[1]] + jppb_forms.add((cw1.text, cw2.text)) + + for cid, word in match.matches[0].items(): + if cid in self.corew: + self.stats["freq"][cid] = self.word_renderer.num_words[( + word.text, word.msd[0])] + + self.stats['fc'] = match.distinct_forms() + self.stats['fc'] = len(jppb_forms) + self.stats['n'] = self.word_renderer.num_all_words() + + def header_repeat(self): + return ["Distribution", "Delta"] + + def header_right(self): + return ["LogDice_core", "LogDice_all", "Distinct_forms"] + + def content_repeat(self, words, representations, idx, sidx): + word = words[idx] + key = (sidx, idx, word.lemma) + distribution = self.colocation_ids.dispersions[key] + + # TODO... + delta = "?" + + return [str(distribution), delta] + + def content_right(self, freq): + # TODO... + return ["?"] * 3 + + def group(self): + return True + class Writer: @staticmethod @@ -1099,19 +1167,18 @@ class Writer: return (args.multiple_output, int(args.sort_by), args.sort_reversed) @staticmethod - def make_output_writer(args, colocation_ids): + def make_output_writer(args, colocation_ids, word_renderer): params = Writer.other_params(args) - return Writer(args.output, OutFormatter(colocation_ids), params) + return Writer(args.output, OutFormatter(colocation_ids, word_renderer), params) @staticmethod - def make_all_writer(args, colocation_ids): - return Writer(args.all, AllFormatter(colocation_ids), None) + def make_all_writer(args, colocation_ids, word_renderer): + return Writer(args.all, AllFormatter(colocation_ids, word_renderer), None) @staticmethod - # todo... - def make_stats_writer(args): + def make_stats_writer(args, colocation_ids, word_renderer): params = Writer.other_params(args) - return Writer(args.stats, None, None) + return Writer(args.stats, StatsFormatter(colocation_ids, word_renderer), params) def __init__(self, file_out, formatter, params): if params is None: @@ -1156,29 +1223,35 @@ class Writer: def write_header(self, file_handler): file_handler.write(", ".join(self.header()) + "\n") - def write_out_worker(self, file_handler, structure_id, components, colocation_ids): + def write_out_worker(self, file_handler, structure, colocation_ids): rows = [] - group = self.formatter.group() + components = structure.components - for cid, m, freq, rprsnt in colocation_ids.get_matches_for(structure_id, group): - to_write = [] + for match in colocation_ids.get_matches_for(structure): + self.formatter.new_match(match) - for idx, _comp in enumerate(components): - idx = str(idx + 1) - if idx not in m: - to_write.extend([""] * self.formatter.length()) - else: - to_write.extend(self.formatter.content_repeat(m, rprsnt, idx)) + for words in match.matches: + to_write = [] - # make them equal size - to_write.extend([""] * (MAX_NUM_COMPONENTS * self.formatter.length() - len(to_write))) + for idx, _comp in enumerate(components): + idx = str(idx + 1) + if idx not in words: + to_write.extend([""] * self.formatter.length()) + else: + to_write.extend(self.formatter.content_repeat(words, match.representations, idx, structure.id)) - # structure_id and colocation_id - to_write = [structure_id] + to_write + [cid] + # make them equal size + to_write.extend([""] * (MAX_NUM_COMPONENTS * self.formatter.length() - len(to_write))) - # header_right - to_write.extend(self.formatter.content_right(freq)) - rows.append(to_write) + # structure_id and colocation_id + to_write = [structure.id] + to_write + [match.match_id] + + # header_right + to_write.extend(self.formatter.content_right(len(match))) + rows.append(to_write) + + if self.formatter.group(): + break if rows != []: rows = self.sorted_rows(rows) @@ -1210,7 +1283,8 @@ class Writer: fp = fp_open(s.id) self.write_header(fp) - self.write_out_worker(fp, s.id, s.components, colocation_ids) + self.formatter.set_structure(s) + self.write_out_worker(fp, s, colocation_ids) if self.multiple_output: fp_close(fp) @@ -1219,20 +1293,29 @@ class Writer: fp_close(fp) class StructureMatch: - def __init__(self, match_id, structure_id): + def __init__(self, match_id, structure): self.match_id = match_id - self.structure_id = structure_id + self.structure = structure self.matches = [] self.representations = {} - def distinct_matches(self): + def distinct_forms(self): dm = set() keys = list(self.matches[0].keys()) for words in self.matches: dm.add(" ".join(words[k].text for k in keys)) return len(dm) + def distinct_2w_forms(self): + dm = set() + # TODO + # keys = [key for key in self.matches[0] if self.comp + + for words in self.matches: + dm.add(" ".join(words[k].text for k in keys)) + return len(dm) + def append(self, match): self.matches.append(match) @@ -1245,35 +1328,44 @@ class ColocationIds: self.min_frequency = args.min_freq self.dispersions = {} - def _add_match(self, key, sid, match): + def _add_match(self, key, structure, match): if key not in self.data: - self.data[key] = StructureMatch(str(len(self.data) + 1), sid) + self.data[key] = StructureMatch(str(len(self.data) + 1), structure) self.data[key].append(match) def get(self, key, n): return self.data[key][n] def add_matches(self, matches): - for sid, nms in matches.items(): + for structure, nms in matches.items(): for nm in nms: - self._add_match(nm[1], sid, nm[0]) + self._add_match(nm[1], structure, nm[0]) - def get_matches_for(self, structure_id, group): + def get_matches_for(self, structure): for _cid_tup, sm in self.data.items(): - if sm.structure_id != structure_id: + if sm.structure != structure: continue - for words in sm.matches: - yield (sm.match_id, words, len(sm), sm.representations) - if group: - break + yield sm + + # all_words = [] + # more_data = [] + + # for words in sm.matches: + # more_data.append((sm.match_id, words, len(sm), sm.representations)) + # all_words.append(words) + + # if group: + # more_data = more_data[:1] + + # yield all_words, more_data def set_representations(self, structures, word_renderer): components_dict = {structure.id: structure for structure in structures} idx = 1 for _1, sm in tqdm(self.data.items()): ComponentRendition.set_representations( - sm, components_dict[sm.structure_id], word_renderer) + sm, components_dict[sm.structure.id], word_renderer) idx += 1 def determine_colocation_dispersions(self): @@ -1285,7 +1377,7 @@ class ColocationIds: def match_file(words, structures): - matches = {s.id: [] for s in structures} + matches = {s: [] for s in structures} for s in tqdm(structures): for w in words: @@ -1295,7 +1387,7 @@ def match_file(words, structures): colocation_id = [s.id] + list(sorted(colocation_id, key=lambda x: x[0])) colocation_id = tuple(colocation_id) - matches[s.id].append((match, colocation_id)) + matches[s].append((match, colocation_id)) return matches @@ -1364,8 +1456,12 @@ def main(structures_file, args): if args.output: colocation_ids.set_representations(structures, word_renderer) - Writer.make_output_writer(args, colocation_ids).write_out(structures, colocation_ids) - Writer.make_all_writer(args, colocation_ids).write_out(structures, colocation_ids) + Writer.make_output_writer(args, colocation_ids, word_renderer).write_out( + structures, colocation_ids) + Writer.make_all_writer(args, colocation_ids, word_renderer).write_out( + structures, colocation_ids) + Writer.make_stats_writer(args, colocation_ids, word_renderer).write_out( + structures, colocation_ids) if __name__ == '__main__': parser = argparse.ArgumentParser( @@ -1378,6 +1474,8 @@ if __name__ == '__main__': help='Output file (if none given, then output to stdout)') parser.add_argument('--all', help='Additional output file, writes more data') + parser.add_argument('--stats', + help='Output file for statistics') parser.add_argument('--no-msd-translate', help='MSDs are translated from slovene to english by default', @@ -1395,8 +1493,6 @@ if __name__ == '__main__': parser.add_argument('--multiple-output', help='Generate one output for each syntactic structure', action='store_true') - parser.add_argument('--stats', - help='Output file should contain statistics', action='store_true') parser.add_argument('--sort-by', help="Sort by a this column (index)", type=int, default=-1)