diff --git a/wani.py b/wani.py
index 17db942..8896d2a 100644
--- a/wani.py
+++ b/wani.py
@@ -739,18 +739,15 @@ def is_root_id(id_):
     return len(id_.split('.')) == 3
 
 
-def load_corpus(args):
+def load_files(args):
     filenames = args.input
     skip_id_check = args.skip_id_check
-    result = []
 
     for fname in filenames:
-        load_tei_file(fname, skip_id_check, result)
-
-    return result
+        yield load_tei_file(fname, skip_id_check)
 
 
-def load_tei_file(filename, skip_id_check, previous_words):
+def load_tei_file(filename, skip_id_check):
     logging.info("LOADING FILE: {}".format(filename))
 
     with open(filename, 'r') as fp:
@@ -792,7 +789,7 @@ def load_tei_file(filename, skip_id_check, previous_words):
            # strange errors, just skip...
            pass
 
-    previous_words.extend(words.values())
+    return words.values()
 
 class Writer:
     def __init__(self, args):
@@ -924,29 +921,9 @@ class ColocationIds:
            self.data[key][2] = True
 
 
-
-def main(input_file, structures_file, args):
-    writer = Writer(args)
-    structures = build_structures(structures_file)
-    for s in structures:
-        logging.debug(str(s))
-
-    if args.temporary_load:
-        logging.info("Loading temporary file: {}".format(args.temporary_load))
-        with open(args.temporary_load, "rb") as fp:
-            words = pickle.load(fp)
-    else:
-        words = load_corpus(args)
-
-    if args.temporary_save is not None:
-        logging.info("Saving to temporary file: {}".format(args.temporary_save))
-        with open(args.temporary_save, "wb") as fp:
-            pickle.dump(words, fp)
-        return
-
-    logging.info("MATCHES...")
-    matches = {s.id: [] for s in structures}
-    colocation_ids = ColocationIds()
+def match_file(words, structures, colocation_ids, matches=None):
+    if matches is None:
+        matches = {s.id: [] for s in structures}
 
     for idx, s in enumerate(structures):
         logging.info("{}/{}: {:7s}".format(idx, len(structures), s.id))
@@ -961,6 +938,20 @@ def main(input_file, structures_file, args):
                colocation_ids.add_match(colocation_id)
                matches[s.id].append((match, reason, colocation_id))
 
+    return matches
+
+def main(input_file, structures_file, args):
+    writer = Writer(args)
+    structures = build_structures(structures_file)
+    for s in structures:
+        logging.debug(str(s))
+
+    colocation_ids = ColocationIds()
+    matches = None
+
+    for words in load_files(args):
+        matches = match_file(words, structures, colocation_ids, matches)
+
     writer.write_out(matches, structures, colocation_ids)
 
     logging.debug([(k, len(v)) for k, v in matches.items()])
@@ -979,9 +970,6 @@ if __name__ == '__main__':
     parser.add_argument('--verbose', help='Enable verbose output to stderr', choices=["warning", "info", "debug"], default="info")
     parser.add_argument('--multiple-output', help='Generate one output for each syntactic structure', action='store_true')
 
-    parser.add_argument('--temporary-save', help='Save corpus given as input to a temporary file for faster loading')
-    parser.add_argument('--temporary-load', help='Load corpus from a temporary file')
-
     args = parser.parse_args()
     logging.basicConfig(stream=sys.stderr, level=args.verbose.upper())