diff --git a/wani.py b/wani.py
index a4ccf91..72630f2 100644
--- a/wani.py
+++ b/wani.py
@@ -388,7 +388,6 @@ class Component:
         for feature in representation:
             f = ComponentRepresentation.new(dict(feature.attrib))
-            print(f)
 
             if type(f) is None:
                 logging.warning("Unknown representation in component {}, skipping...".format(self.idx), file=sys.stderr)
 
@@ -875,9 +874,10 @@ class Writer:
     def write_header(self, file_handler):
         file_handler.write(", ".join(self.header()) + "\n")
 
-    def write_out_worker(self, file_handler, matches, structure_id, components, colocation_ids):
+    def write_out_worker(self, file_handler, structure_id, components, colocation_ids):
         rows = []
-        for m, reason, cid in matches:
+
+        for cid, m, reason, freq in colocation_ids.get_matches_for(structure_id, not self.all):
             to_write = []
             representation = ""
 
@@ -889,17 +889,12 @@ class Writer:
             # make them equal size
             to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write)))
 
-            to_write = [structure_id] + to_write + [colocation_ids.to_id(cid)]
+            to_write = [structure_id] + to_write + [cid]
 
             if not self.all:
                 representation = re.sub(' +', ' ', representation)
                 to_write.append(representation.strip())
-
-                if colocation_ids.should_write(cid):
-                    to_write.append(colocation_ids.num(cid))
-                    colocation_ids.set_written(cid)
-                else:
-                    continue
+                to_write.append(str(freq))
 
             rows.append(to_write)
 
@@ -908,7 +903,7 @@ class Writer:
         file_handler.write("\n".join([", ".join(row) for row in rows]) + "\n")
         file_handler.flush()
 
-    def write_out(self, matches, structures, colocation_ids):
+    def write_out(self, structures, colocation_ids):
         def fp_close(fp_):
             if fp_ != sys.stdout:
                 fp_.close()
@@ -930,8 +925,7 @@ class Writer:
                 fp=fp_open(s.id)
                 self.write_header(fp)
 
-            sid_matches = matches[s.id]
-            self.write_out_worker(fp, sid_matches, s.id, s.components, colocation_ids)
+            self.write_out_worker(fp, s.id, s.components, colocation_ids)
 
             if self.multiple_output:
                 fp_close(fp)
@@ -945,34 +939,35 @@ class ColocationIds:
         self.data = {}
         self.min_frequency = args.min_freq
 
-    def add_match(self, key):
+    def _add_match(self, key, sid, match):
         if key in self.data:
-            self.data[key][1] += 1
+            self.data[key][1].append(match)
         else:
-            self.data[key] = [str(len(self.data) + 1), 1, False]
+            self.data[key] = (str(len(self.data) + 1), [match], sid)
 
     def get(self, key, n):
         return self.data[key][n]
 
-    def should_write(self, key):
-        return self.get(key, 1) >= self.min_frequency and not self.get(key, 2)
-
     def num(self, key):
-        return str(self.get(key, 1))
+        return str(len(self.get(key, 1)))
 
     def to_id(self, key):
         return self.get(key, 0)
 
-    def set_written(self, key):
-        self.data[key][2] = True
-
-    def merge_matches(self, matches, new_matches):
-        for _id, nms in new_matches.items():
+    def add_matches(self, matches):
+        for sid, nms in matches.items():
             for nm in nms:
-                matches[_id].append(nm)
-                self.add_match(nm[2])
+                self._add_match(nm[2], sid, (nm[0], nm[1]))
+
+    def get_matches_for(self, structure_id, group):
+        for _cid_tup, (cid, cid_matches, sid) in self.data.items():
+            if sid != structure_id:
+                continue
 
-        return matches
+            for words, reason in cid_matches:
+                yield (cid, words, reason, len(cid_matches))
+                if group:
+                    break
 
 
 def match_file(words, structures):
@@ -999,7 +994,6 @@
         logging.debug(str(s))
 
     colocation_ids = ColocationIds()
-    matches = {s.id: [] for s in structures}
 
     if args.parallel:
        num_parallel = int(args.parallel)
@@ -1026,23 +1020,23 @@
            # fancy interface to wait for threads to finish
            for id_input in executor.map(func, [i for i, _ in enumerate(args.input)]):
                with open("{}/{}.p".format(tmpdirname, id_input), "rb") as fp:
-                    matches = pickle.load(fp)
-                    matches = colocation_ids.merge_matches(matches, new_matches)
+                    matches = pickle.load(fp)
+                    colocation_ids.add_matches(matches)
 
    else:
        for words in load_files(args):
-            new_matches = match_file(words, structures)
+            matches = match_file(words, structures)
 
            # just save to temporary file, used for children of a parallel process
            if args.match_to_file is not None:
                with open(args.match_to_file, "wb") as fp:
-                    pickle.dump(new_matches, fp)
+                    pickle.dump(matches, fp)
                    return
            else:
-                matches = colocation_ids.merge_matches(matches, new_matches)
+                colocation_ids.add_matches(matches)
 
    if args.all:
-        Writer.make_all_writer(args).write_out(matches, structures, colocation_ids)
-        Writer.make_output_writer(args).write_out(matches, structures, colocation_ids)
+        Writer.make_all_writer(args).write_out(structures, colocation_ids)
+        Writer.make_output_writer(args).write_out(structures, colocation_ids)
 
    logging.debug([(k, len(v)) for k, v in matches.items()])
    logging.debug(sum(len(v) for _, v in matches.items()))
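
For context (not part of the patch), a minimal standalone sketch of the data flow the refactored ColocationIds now owns: each colocation key maps to (display id, list of matches, structure id), the frequency is simply the length of that list, and get_matches_for() replaces the old should_write()/set_written() bookkeeping in the writer. The class body mirrors the patch; the sample match tuples and structure id "s1" are invented for illustration, and args/min_frequency handling is omitted.

    class ColocationIds:
        """Simplified sketch; argparse wiring and min_frequency omitted."""

        def __init__(self):
            self.data = {}

        def _add_match(self, key, sid, match):
            # data[key] = (display id, list of (words, reason) matches, structure id)
            if key in self.data:
                self.data[key][1].append(match)
            else:
                self.data[key] = (str(len(self.data) + 1), [match], sid)

        def add_matches(self, matches):
            # matches: {structure_id: [(words, reason, colocation_key), ...]}
            for sid, nms in matches.items():
                for nm in nms:
                    self._add_match(nm[2], sid, (nm[0], nm[1]))

        def get_matches_for(self, structure_id, group):
            # Frequency falls out of the stored match list, so no separate
            # counter or "already written" flag is needed any more.
            for _key, (cid, cid_matches, sid) in self.data.items():
                if sid != structure_id:
                    continue
                for words, reason in cid_matches:
                    yield (cid, words, reason, len(cid_matches))
                    if group:
                        break  # grouped output: one row per colocation

    ids = ColocationIds()
    ids.add_matches({"s1": [("w1 w2", "ok", "k1"),
                            ("w1 w2", "ok", "k1"),
                            ("w3 w4", "ok", "k2")]})

    for cid, words, reason, freq in ids.get_matches_for("s1", group=True):
        print(cid, words, reason, freq)   # -> "1 w1 w2 ok 2" then "2 w3 w4 ok 1"

With group=True (the `not self.all` call site in write_out_worker) each colocation yields a single grouped row carrying its frequency; with group=False every stored match is yielded individually, which matches the all-writer's behaviour.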