NOT WORKING inbetween commit

This commit is contained in:
Ozbolt Menegatti 2019-06-09 22:25:58 +02:00
parent dff9643edf
commit c6440162b8

194
wani.py
View File

@ -891,6 +891,7 @@ class WordMsdRenderer:
common_msd = "*" * 10 common_msd = "*" * 10
for msd, texts in ld.items(): for msd, texts in ld.items():
# TODO: this should be out of generate_renders...
self.num_words[(lemma, msd[0])] += len(texts) self.num_words[(lemma, msd[0])] += len(texts)
rep = max(set(texts), key=texts.count) rep = max(set(texts), key=texts.count)
@ -1022,15 +1023,16 @@ def load_tei_file(filename, skip_id_check, do_msd_translate, pc_tag, status):
class Formatter: class Formatter:
def __init__(self, colocation_ids): def __init__(self, colocation_ids, word_renderer):
self.colocation_ids = colocation_ids self.colocation_ids = colocation_ids
self.word_renderer = word_renderer
self.additional_init() self.additional_init()
def header_repeat(self): def header_repeat(self):
raise NotImplementedError("Header repeat formatter not implemented") raise NotImplementedError("Header repeat formatter not implemented")
def header_right(self): def header_right(self):
raise NotImplementedError("Header right formatter not implemented") raise NotImplementedError("Header right formatter not implemented")
def content_repeat(self, words, representations, idx): def content_repeat(self, words, representations, idx, sidx):
raise NotImplementedError("Content repeat formatter not implemented") raise NotImplementedError("Content repeat formatter not implemented")
def content_right(self, freq): def content_right(self, freq):
raise NotImplementedError("Content right formatter not implemented") raise NotImplementedError("Content right formatter not implemented")
@ -1042,6 +1044,11 @@ class Formatter:
def length(self): def length(self):
return len(self.header_repeat()) return len(self.header_repeat())
def set_structure(self, structure):
pass
def new_match(self, match):
pass
class OutFormatter(Formatter): class OutFormatter(Formatter):
def additional_init(self): def additional_init(self):
@ -1053,7 +1060,7 @@ class OutFormatter(Formatter):
def header_right(self): def header_right(self):
return ["Joint_representative_form", "Frequency"] return ["Joint_representative_form", "Frequency"]
def content_repeat(self, words, representations, idx): def content_repeat(self, words, representations, idx, _sidx):
word = words[idx] word = words[idx]
if idx not in representations: if idx not in representations:
return [word.lemma, "", ""] return [word.lemma, "", ""]
@ -1082,7 +1089,7 @@ class AllFormatter(Formatter):
def header_right(self): def header_right(self):
return [] return []
def content_repeat(self, words, representations, idx): def content_repeat(self, words, _representations, idx, _sidx):
word = words[idx] word = words[idx]
return [word.id, word.text, word.lemma, word.msd] return [word.id, word.text, word.lemma, word.msd]
@ -1092,6 +1099,67 @@ class AllFormatter(Formatter):
def group(self): def group(self):
return False return False
class StatsFormatter(Formatter):
def additional_init(self):
self.stats = None
self.jppb = None
self.corew = None
def set_structure(self, structure):
jppb = []
corew = []
for component in structure.components:
if component.type == ComponentType.Core2w:
jppb.append(component.idx)
if component.type != ComponentType.Other:
corew.append(component.idx)
assert(len(jppb) == 2)
self.jppb = tuple(jppb)
self.corew = tuple(corew)
def new_match(self, match):
jppb_forms = set()
self.stats = {"freq": {}}
for words in match.matches:
cw1 = words[self.jppb[0]]
cw2 = words[self.jppb[1]]
jppb_forms.add((cw1.text, cw2.text))
for cid, word in match.matches[0].items():
if cid in self.corew:
self.stats["freq"][cid] = self.word_renderer.num_words[(
word.text, word.msd[0])]
self.stats['fc'] = match.distinct_forms()
self.stats['fc'] = len(jppb_forms)
self.stats['n'] = self.word_renderer.num_all_words()
def header_repeat(self):
return ["Distribution", "Delta"]
def header_right(self):
return ["LogDice_core", "LogDice_all", "Distinct_forms"]
def content_repeat(self, words, representations, idx, sidx):
word = words[idx]
key = (sidx, idx, word.lemma)
distribution = self.colocation_ids.dispersions[key]
# TODO...
delta = "?"
return [str(distribution), delta]
def content_right(self, freq):
# TODO...
return ["?"] * 3
def group(self):
return True
class Writer: class Writer:
@staticmethod @staticmethod
@ -1099,19 +1167,18 @@ class Writer:
return (args.multiple_output, int(args.sort_by), args.sort_reversed) return (args.multiple_output, int(args.sort_by), args.sort_reversed)
@staticmethod @staticmethod
def make_output_writer(args, colocation_ids): def make_output_writer(args, colocation_ids, word_renderer):
params = Writer.other_params(args) params = Writer.other_params(args)
return Writer(args.output, OutFormatter(colocation_ids), params) return Writer(args.output, OutFormatter(colocation_ids, word_renderer), params)
@staticmethod @staticmethod
def make_all_writer(args, colocation_ids): def make_all_writer(args, colocation_ids, word_renderer):
return Writer(args.all, AllFormatter(colocation_ids), None) return Writer(args.all, AllFormatter(colocation_ids, word_renderer), None)
@staticmethod @staticmethod
# todo... def make_stats_writer(args, colocation_ids, word_renderer):
def make_stats_writer(args):
params = Writer.other_params(args) params = Writer.other_params(args)
return Writer(args.stats, None, None) return Writer(args.stats, StatsFormatter(colocation_ids, word_renderer), params)
def __init__(self, file_out, formatter, params): def __init__(self, file_out, formatter, params):
if params is None: if params is None:
@ -1156,29 +1223,35 @@ class Writer:
def write_header(self, file_handler): def write_header(self, file_handler):
file_handler.write(", ".join(self.header()) + "\n") file_handler.write(", ".join(self.header()) + "\n")
def write_out_worker(self, file_handler, structure_id, components, colocation_ids): def write_out_worker(self, file_handler, structure, colocation_ids):
rows = [] rows = []
group = self.formatter.group() components = structure.components
for cid, m, freq, rprsnt in colocation_ids.get_matches_for(structure_id, group): for match in colocation_ids.get_matches_for(structure):
to_write = [] self.formatter.new_match(match)
for idx, _comp in enumerate(components): for words in match.matches:
idx = str(idx + 1) to_write = []
if idx not in m:
to_write.extend([""] * self.formatter.length())
else:
to_write.extend(self.formatter.content_repeat(m, rprsnt, idx))
# make them equal size for idx, _comp in enumerate(components):
to_write.extend([""] * (MAX_NUM_COMPONENTS * self.formatter.length() - len(to_write))) idx = str(idx + 1)
if idx not in words:
to_write.extend([""] * self.formatter.length())
else:
to_write.extend(self.formatter.content_repeat(words, match.representations, idx, structure.id))
# structure_id and colocation_id # make them equal size
to_write = [structure_id] + to_write + [cid] to_write.extend([""] * (MAX_NUM_COMPONENTS * self.formatter.length() - len(to_write)))
# header_right # structure_id and colocation_id
to_write.extend(self.formatter.content_right(freq)) to_write = [structure.id] + to_write + [match.match_id]
rows.append(to_write)
# header_right
to_write.extend(self.formatter.content_right(len(match)))
rows.append(to_write)
if self.formatter.group():
break
if rows != []: if rows != []:
rows = self.sorted_rows(rows) rows = self.sorted_rows(rows)
@ -1210,7 +1283,8 @@ class Writer:
fp = fp_open(s.id) fp = fp_open(s.id)
self.write_header(fp) self.write_header(fp)
self.write_out_worker(fp, s.id, s.components, colocation_ids) self.formatter.set_structure(s)
self.write_out_worker(fp, s, colocation_ids)
if self.multiple_output: if self.multiple_output:
fp_close(fp) fp_close(fp)
@ -1219,20 +1293,29 @@ class Writer:
fp_close(fp) fp_close(fp)
class StructureMatch: class StructureMatch:
def __init__(self, match_id, structure_id): def __init__(self, match_id, structure):
self.match_id = match_id self.match_id = match_id
self.structure_id = structure_id self.structure = structure
self.matches = [] self.matches = []
self.representations = {} self.representations = {}
def distinct_matches(self): def distinct_forms(self):
dm = set() dm = set()
keys = list(self.matches[0].keys()) keys = list(self.matches[0].keys())
for words in self.matches: for words in self.matches:
dm.add(" ".join(words[k].text for k in keys)) dm.add(" ".join(words[k].text for k in keys))
return len(dm) return len(dm)
def distinct_2w_forms(self):
dm = set()
# TODO
# keys = [key for key in self.matches[0] if self.comp
for words in self.matches:
dm.add(" ".join(words[k].text for k in keys))
return len(dm)
def append(self, match): def append(self, match):
self.matches.append(match) self.matches.append(match)
@ -1245,35 +1328,44 @@ class ColocationIds:
self.min_frequency = args.min_freq self.min_frequency = args.min_freq
self.dispersions = {} self.dispersions = {}
def _add_match(self, key, sid, match): def _add_match(self, key, structure, match):
if key not in self.data: if key not in self.data:
self.data[key] = StructureMatch(str(len(self.data) + 1), sid) self.data[key] = StructureMatch(str(len(self.data) + 1), structure)
self.data[key].append(match) self.data[key].append(match)
def get(self, key, n): def get(self, key, n):
return self.data[key][n] return self.data[key][n]
def add_matches(self, matches): def add_matches(self, matches):
for sid, nms in matches.items(): for structure, nms in matches.items():
for nm in nms: for nm in nms:
self._add_match(nm[1], sid, nm[0]) self._add_match(nm[1], structure, nm[0])
def get_matches_for(self, structure_id, group): def get_matches_for(self, structure):
for _cid_tup, sm in self.data.items(): for _cid_tup, sm in self.data.items():
if sm.structure_id != structure_id: if sm.structure != structure:
continue continue
for words in sm.matches: yield sm
yield (sm.match_id, words, len(sm), sm.representations)
if group: # all_words = []
break # more_data = []
# for words in sm.matches:
# more_data.append((sm.match_id, words, len(sm), sm.representations))
# all_words.append(words)
# if group:
# more_data = more_data[:1]
# yield all_words, more_data
def set_representations(self, structures, word_renderer): def set_representations(self, structures, word_renderer):
components_dict = {structure.id: structure for structure in structures} components_dict = {structure.id: structure for structure in structures}
idx = 1 idx = 1
for _1, sm in tqdm(self.data.items()): for _1, sm in tqdm(self.data.items()):
ComponentRendition.set_representations( ComponentRendition.set_representations(
sm, components_dict[sm.structure_id], word_renderer) sm, components_dict[sm.structure.id], word_renderer)
idx += 1 idx += 1
def determine_colocation_dispersions(self): def determine_colocation_dispersions(self):
@ -1285,7 +1377,7 @@ class ColocationIds:
def match_file(words, structures): def match_file(words, structures):
matches = {s.id: [] for s in structures} matches = {s: [] for s in structures}
for s in tqdm(structures): for s in tqdm(structures):
for w in words: for w in words:
@ -1295,7 +1387,7 @@ def match_file(words, structures):
colocation_id = [s.id] + list(sorted(colocation_id, key=lambda x: x[0])) colocation_id = [s.id] + list(sorted(colocation_id, key=lambda x: x[0]))
colocation_id = tuple(colocation_id) colocation_id = tuple(colocation_id)
matches[s.id].append((match, colocation_id)) matches[s].append((match, colocation_id))
return matches return matches
@ -1364,8 +1456,12 @@ def main(structures_file, args):
if args.output: if args.output:
colocation_ids.set_representations(structures, word_renderer) colocation_ids.set_representations(structures, word_renderer)
Writer.make_output_writer(args, colocation_ids).write_out(structures, colocation_ids) Writer.make_output_writer(args, colocation_ids, word_renderer).write_out(
Writer.make_all_writer(args, colocation_ids).write_out(structures, colocation_ids) structures, colocation_ids)
Writer.make_all_writer(args, colocation_ids, word_renderer).write_out(
structures, colocation_ids)
Writer.make_stats_writer(args, colocation_ids, word_renderer).write_out(
structures, colocation_ids)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -1378,6 +1474,8 @@ if __name__ == '__main__':
help='Output file (if none given, then output to stdout)') help='Output file (if none given, then output to stdout)')
parser.add_argument('--all', parser.add_argument('--all',
help='Additional output file, writes more data') help='Additional output file, writes more data')
parser.add_argument('--stats',
help='Output file for statistics')
parser.add_argument('--no-msd-translate', parser.add_argument('--no-msd-translate',
help='MSDs are translated from slovene to english by default', help='MSDs are translated from slovene to english by default',
@ -1395,8 +1493,6 @@ if __name__ == '__main__':
parser.add_argument('--multiple-output', parser.add_argument('--multiple-output',
help='Generate one output for each syntactic structure', help='Generate one output for each syntactic structure',
action='store_true') action='store_true')
parser.add_argument('--stats',
help='Output file should contain statistics', action='store_true')
parser.add_argument('--sort-by', parser.add_argument('--sort-by',
help="Sort by a this column (index)", type=int, default=-1) help="Sort by a this column (index)", type=int, default=-1)