diff --git a/src/loader.py b/src/loader.py index 14d118f..0215a9e 100644 --- a/src/loader.py +++ b/src/loader.py @@ -18,7 +18,8 @@ def load_files(args): skip_id_check = args.skip_id_check do_msd_translate = not args.no_msd_translate - for fname in progress(filenames, "files", outfile=True): + for idx, fname in enumerate(filenames): + print("FILE ", fname, "{}/{}".format(idx, len(filenames))) extension = pathlib.Path(fname).suffix if extension == ".xml": @@ -49,7 +50,7 @@ def load_gz(filename): result.extend(words.values()) with gzip.open(filename, 'r') as fp: - for line in progress(fp, 'load-gz', infile=True): + for line in progress(fp, 'load-gz'): line_str = line.decode('utf8').strip() line_fixed = line_str.replace(',', '\t').replace('\t\t\t', '\t,\t') line_split = line_fixed.split("\t") @@ -85,7 +86,7 @@ def load_xml(filename): def file_sentence_generator(et, skip_id_check, do_msd_translate, pc_tag): words = {} sentences = list(et.iter('s')) - for sentence in progress(sentences, "load-text", infile=True): + for sentence in progress(sentences, "load-text"): for w in sentence.iter("w"): words[w.get('id')] = Word.from_xml(w, do_msd_translate) for pc in sentence.iter(pc_tag): diff --git a/src/progress_bar.py b/src/progress_bar.py index 6f8ad48..cf8254a 100644 --- a/src/progress_bar.py +++ b/src/progress_bar.py @@ -1,41 +1,32 @@ +import time + try: from tqdm import tqdm except ImportError: tqdm = None +REPORT_ON = 0.3 + class Progress: - def __init__(self): - self.hide_inner = False - - - def __call__(self, iterable, description, infile=False, outfile=False): - show_progress = True - if True in (infile, outfile): - assert False in (infile, outfile) - show_progress = outfile == self.hide_inner - - if not show_progress: - yield from iterable - return - + def __call__(self, iterable, description): if tqdm is None: - iterlist = list(iterable) - proc = -1 - for n, el in enumerate(iterlist): - nxt_proc = int(n / len(iterlist) * 100) - if nxt_proc > proc: - print("\r{}: {:02d}% ({}/{})".format(description, nxt_proc, n, len(iterlist)), end="") - proc = nxt_proc + try: + ln = len(iterable) + except TypeError: + ln = -1 + + last_report = time.time() - REPORT_ON + for n, el in enumerate(iterable): + now = time.time() + if now - last_report > REPORT_ON: + print("\r{}: {}/{}".format(description, n, ln), end="") + last_report = now yield el print("") else: yield from tqdm(iterable, desc=description) - def init(self, args): - self.hide_inner = args.hide_inner_progress - - progress = Progress() diff --git a/src/wani.py b/src/wani.py index 08b68a8..7de759d 100644 --- a/src/wani.py +++ b/src/wani.py @@ -22,7 +22,7 @@ from database import Database def match_file(words, structures): matches = {s: [] for s in structures} - for s in progress(structures, "matching", infile=True): + for s in progress(structures, "matching"): for w in words: mhere = s.match(w) for match in mhere: @@ -153,11 +153,9 @@ if __name__ == '__main__': parser.add_argument('--match-to-file', help='Do not use!') parser.add_argument('--pickled-structures', help='Do not use!', action='store_true') - parser.add_argument('--hide-inner-progress', help='Do not use!', action='store_true') args = parser.parse_args() logging.basicConfig(stream=sys.stderr, level=args.verbose.upper()) - progress.init(args) start = time.time() main(args)