diff --git a/wani.py b/wani.py index 150a28e..41f72d1 100644 --- a/wani.py +++ b/wani.py @@ -5,6 +5,7 @@ from collections import defaultdict import sys import logging import argparse +import time from msd_translate import MSD_TRANSLATE @@ -858,8 +859,7 @@ class Writer: if self.output_file is not None: fp.close() -def main(): - import time + class ColocationIds: def __init__(self): self.data = {} @@ -885,7 +885,11 @@ class ColocationIds: def set_written(self, key): self.data[key][2] = True + + +def main(input_file, structures_file, args): t = time.time() + writer = Writer(args) structures = build_structures(structures_file) for s in structures: @@ -893,6 +897,8 @@ class ColocationIds: logging.info("LOADING TEXT...") words = load_corpus(input_file) + + # useful for faster debugging... # import pickle # with open("words.p", "wb") as fp: # pickle.dump(words, fp) @@ -901,6 +907,7 @@ class ColocationIds: logging.info("MATCHES...") matches = {s.id: [] for s in structures} + colocation_ids = ColocationIds() for idx, s in enumerate(structures): logging.info("{}/{}: {:7s}".format(idx, len(structures), s.id)) @@ -908,70 +915,18 @@ class ColocationIds: mhere = s.match(w) logging.debug(" GOT: {}".format(len(mhere))) for match, reason in mhere: - matches[s.id].append((match, reason)) + colocation_id = [(idx, w.lemma) for idx, w in match.items()] + colocation_id = [s.id] + list(sorted(colocation_id, key=lambda x:x[0])) + colocation_id = tuple(colocation_id) - print("") + colocation_ids.add_match(colocation_id) + matches[s.id].append((match, reason, colocation_id)) - header = ["Structure_ID"] - for i in range(MAX_NUM_COMPONENTS): - header.extend("C{}_{}".format(i + 1, thd) for thd in - ["Token_ID", "Word_form", "Lemma", "Msd", "Representative_form"]) - header.extend(["Collocation_ID", "Joint_representative_form"]) - - csv = [", ".join(header)] - colocation_ids = {} - - for s in structures: - ms = matches[s.id] - - for m, reason in ms: - colocation_id = [s.id] - to_print = [] - - m_sorted = defaultdict(lambda: None, m.items()) - for idx, comp in enumerate(s.components): - idx = str(idx + 1) - if idx not in m_sorted: - to_print.extend(["", "", "", "", ""]) - else: - w = m_sorted[idx] - # if comp.render_word(m_sorted[idx]) is not None: - if True: - to_print.extend([w.id, w.text, w.lemma, w.msd, ""]) - colocation_id.append(w.lemma) - - colocation_id = tuple(colocation_id) - if colocation_id in colocation_ids: - cid = colocation_ids[colocation_id] - else: - cid = len(colocation_ids) + 1 - colocation_ids[colocation_id] = cid - - to_print = [s.id] + to_print - length = 1 + MAX_NUM_COMPONENTS * 5 - # make them equal size - to_print.extend([""] * (length - len(to_print))) - to_print.extend([str(cid), ""]) + writer.write_out(matches, structures, colocation_ids) logging.info("TIME: {}".format(time.time() - t)) logging.debug([(k, len(v)) for k, v in matches.items()]) logging.debug(sum(len(v) for _, v in matches.items())) - csv.append(", ".join(to_print)) - - - with open(FILE_OUT, "w") as fp: - print("\n".join(csv), file=fp) - - # groups = defaultdict(int) - # for m, reason in ms: - # if reason != "OK": - # continue - # lemmas = [(n, w.lemma) for n, w in m.items()] - # lemmas = tuple(sorted(lemmas, key=lambda x: x[0])) - # groups[lemmas] += 1 - - # print(s.id) - # print(groups) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Extract structures from a parsed corpus.') @@ -987,4 +942,4 @@ if __name__ == '__main__': args = parser.parse_args() logging.basicConfig(stream=sys.stderr, level=args.verbose.upper()) -# 6, 7 primeri laznih zadetkov? + main(args.input, args.structures, args)