diff --git a/wani.py b/wani.py index 5f4e0e2..7d10c1d 100644 --- a/wani.py +++ b/wani.py @@ -1020,11 +1020,6 @@ def load_tei_file(filename, skip_id_check, do_msd_translate, pc_tag, status): return list(words.values()) -class Writer: - @staticmethod - def make_output_writer(args): - return Writer(False, args.output, args.multiple_output, - int(args.sort_by), args.sort_reversed, args.stats) class Formatter: def __init__(self, colocation_ids): @@ -1096,49 +1091,50 @@ class AllFormatter(Formatter): def group(self): return False - def make_all_writer(args): - return Writer(True, args.all, False, -1, False, False) - def __init__(self, all, filename, multiple_output, sort_by, sort_reversed, stats): - self.all = all - self.output_file = filename - self.multiple_output = multiple_output - self.stats = stats - self.sort_by = sort_by - self.sort_order = sort_reversed +class Writer: + @staticmethod + def other_params(args): + return (args.multiple_output, int(args.sort_by), args.sort_reversed) + + @staticmethod + def make_output_writer(args, colocation_ids): + params = Writer.other_params(args) + return Writer(args.output, OutFormatter(colocation_ids), params) + + @staticmethod + def make_all_writer(args, colocation_ids): + return Writer(args.all, AllFormatter(colocation_ids), None) + + @staticmethod + # todo... + def make_stats_writer(args): + params = Writer.other_params(args) + return Writer(args.stats, None, None) + + def __init__(self, file_out, formatter, params): + if params is None: + self.multiple_output = False + self.sort_by = -1 + self.sort_order = None + else: + self.multiple_output = params[0] + self.sort_by = params[1] + self.sort_order = params[2] + + self.output_file = file_out + self.formatter = formatter def header(self): - cols = ["Lemma"] - if self.all: - cols = ["Token_ID", "Word_form"] + cols + ["Msd"] - else: - cols.extend(["Representative_form", "RF_scenario"]) + repeating_cols = self.formatter.header_repeat() + cols = ["C{}_{}".format(i + 1, thd) for i in range(MAX_NUM_COMPONENTS) + for thd in repeating_cols] - assert len(cols) == self.length() - cols = ["C{}_{}".format(i + 1, thd) for i in range(MAX_NUM_COMPONENTS) for thd in cols] cols = ["Structure_ID"] + cols + ["Colocation_ID"] - - if not self.all: - cols += ["Joint_representative_form", "Frequency"] - + cols += self.formatter.header_right() return cols - def length(self): - return 4 if self.all else 3 - - def from_word(self, word, representation, rep_exists): - if word is None: - return [""] * self.length() - elif self.all: - return [word.id, word.text, word.lemma, word.msd] - elif not rep_exists: - return [word.lemma, "", ""] - elif representation is None: - return [word.lemma, word.lemma, "lemma_fallback"] - else: - return [word.lemma, representation, "ok"] - def sorted_rows(self, rows): if self.sort_by < 0 or len(rows) < 2: return rows @@ -1162,28 +1158,26 @@ class AllFormatter(Formatter): def write_out_worker(self, file_handler, structure_id, components, colocation_ids): rows = [] + group = self.formatter.group() - for cid, m, freq, rprsnt in colocation_ids.get_matches_for(structure_id, not self.all): + for cid, m, freq, rprsnt in colocation_ids.get_matches_for(structure_id, group): to_write = [] - representation = "" for idx, _comp in enumerate(components): idx = str(idx + 1) - word = m[idx] if idx in m else None - rep_exists = idx in rprsnt - rep = rprsnt[idx] if rep_exists else None - to_write.extend(self.from_word(word, rep, rep_exists)) - representation += " " + to_write[-2] + if idx not in m: + to_write.extend([""] * self.formatter.length()) + else: + to_write.extend(self.formatter.content_repeat(m, rprsnt, idx)) # make them equal size - to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write))) + to_write.extend([""] * (MAX_NUM_COMPONENTS * self.formatter.length() - len(to_write))) + + # structure_id and colocation_id to_write = [structure_id] + to_write + [cid] - if not self.all: - representation = re.sub(' +', ' ', representation) - to_write.append(representation.strip()) - to_write.append(str(freq)) - + # header_right + to_write.extend(self.formatter.content_right(freq)) rows.append(to_write) if rows != []: