diff --git a/src/match_store.py b/src/match_store.py index 52767a3..0297c64 100644 --- a/src/match_store.py +++ b/src/match_store.py @@ -85,7 +85,8 @@ class MatchStore: def set_representations(self, word_renderer, structures): structures_dict = {s.id: s for s in structures} - for cid, sid in progress(self.db.execute("SELECT colocation_id, structure_id FROM Colocations"), "representations"): + num_representations = int(self.db.execute("SELECT Count(*) FROM Colocations").fetchone()[0]) + for cid, sid in progress(self.db.execute("SELECT colocation_id, structure_id FROM Colocations"), "representations", total=num_representations): structure = structures_dict[sid] match = StructureMatch.from_db(self.db, cid, structure) RepresentationAssigner.set_representations(match, word_renderer) diff --git a/src/progress_bar.py b/src/progress_bar.py index cf8254a..4b4bcf6 100644 --- a/src/progress_bar.py +++ b/src/progress_bar.py @@ -9,23 +9,24 @@ except ImportError: REPORT_ON = 0.3 class Progress: - def __call__(self, iterable, description): + def __call__(self, iterable, description, total=None): if tqdm is None: try: - ln = len(iterable) + total = len(iterable) except TypeError: - ln = -1 + total = -1 - last_report = time.time() - REPORT_ON + start_time = time.time() + last_report = start_time - REPORT_ON for n, el in enumerate(iterable): now = time.time() if now - last_report > REPORT_ON: - print("\r{}: {}/{}".format(description, n, ln), end="") + print("\r{}: {}/{}".format(description, n, total), end="") last_report = now yield el - print("") + print(" -> {}".format(time.time() - start_time)) else: - yield from tqdm(iterable, desc=description) + yield from tqdm(iterable, desc=description, total=total) progress = Progress()