Moving matches into colocation ids, now easier for representation

Ozbolt Menegatti 2019-05-13 08:35:55 +02:00
parent 87712128be
commit 19067e4135

wani.py (66 lines changed)

@@ -388,7 +388,6 @@ class Component:
 
         for feature in representation:
             f = ComponentRepresentation.new(dict(feature.attrib))
-            print(f)
 
             if type(f) is None:
                 logging.warning("Unknown representation in component {}, skipping...".format(self.idx), file=sys.stderr)
@@ -875,9 +874,10 @@ class Writer:
     def write_header(self, file_handler):
         file_handler.write(", ".join(self.header()) + "\n")
 
-    def write_out_worker(self, file_handler, matches, structure_id, components, colocation_ids):
+    def write_out_worker(self, file_handler, structure_id, components, colocation_ids):
         rows = []
-        for m, reason, cid in matches:
+
+        for cid, m, reason, freq in colocation_ids.get_matches_for(structure_id, not self.all):
             to_write = []
             representation = ""
@@ -889,17 +889,12 @@ class Writer:
 
             # make them equal size
             to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write)))
-            to_write = [structure_id] + to_write + [colocation_ids.to_id(cid)]
+            to_write = [structure_id] + to_write + [cid]
 
             if not self.all:
                 representation = re.sub(' +', ' ', representation)
                 to_write.append(representation.strip())
-
-                if colocation_ids.should_write(cid):
-                    to_write.append(colocation_ids.num(cid))
-                    colocation_ids.set_written(cid)
-                else:
-                    continue
+                to_write.append(str(freq))
 
             rows.append(to_write)
@@ -908,7 +903,7 @@ class Writer:
         file_handler.write("\n".join([", ".join(row) for row in rows]) + "\n")
         file_handler.flush()
 
-    def write_out(self, matches, structures, colocation_ids):
+    def write_out(self, structures, colocation_ids):
         def fp_close(fp_):
             if fp_ != sys.stdout:
                 fp_.close()
@@ -930,8 +925,7 @@ class Writer:
             fp=fp_open(s.id)
             self.write_header(fp)
 
-            sid_matches = matches[s.id]
-            self.write_out_worker(fp, sid_matches, s.id, s.components, colocation_ids)
+            self.write_out_worker(fp, s.id, s.components, colocation_ids)
 
             if self.multiple_output:
                 fp_close(fp)
@@ -945,34 +939,35 @@ class ColocationIds:
         self.data = {}
         self.min_frequency = args.min_freq
 
-    def add_match(self, key):
+    def _add_match(self, key, sid, match):
         if key in self.data:
-            self.data[key][1] += 1
+            self.data[key][1].append(match)
         else:
-            self.data[key] = [str(len(self.data) + 1), 1, False]
+            self.data[key] = (str(len(self.data) + 1), [match], sid)
 
     def get(self, key, n):
         return self.data[key][n]
 
-    def should_write(self, key):
-        return self.get(key, 1) >= self.min_frequency and not self.get(key, 2)
-
     def num(self, key):
-        return str(self.get(key, 1))
+        return str(len(self.get(key, 1)))
 
     def to_id(self, key):
         return self.get(key, 0)
 
-    def set_written(self, key):
-        self.data[key][2] = True
-
-    def merge_matches(self, matches, new_matches):
-        for _id, nms in new_matches.items():
+    def add_matches(self, matches):
+        for sid, nms in matches.items():
             for nm in nms:
-                matches[_id].append(nm)
-                self.add_match(nm[2])
-
-        return matches
+                self._add_match(nm[2], sid, (nm[0], nm[1]))
+
+    def get_matches_for(self, structure_id, group):
+        for _cid_tup, (cid, cid_matches, sid) in self.data.items():
+            if sid != structure_id:
+                continue
+
+            for words, reason in cid_matches:
+                yield (cid, words, reason, len(cid_matches))
+
+                if group:
+                    break
 
 
 def match_file(words, structures):
@@ -999,7 +994,6 @@ def main(input_file, structures_file, args):
         logging.debug(str(s))
 
     colocation_ids = ColocationIds()
-    matches = {s.id: [] for s in structures}
 
     if args.parallel:
         num_parallel = int(args.parallel)
@@ -1026,23 +1020,23 @@ def main(input_file, structures_file, args):
            # fancy interface to wait for threads to finish
            for id_input in executor.map(func, [i for i, _ in enumerate(args.input)]):
                with open("{}/{}.p".format(tmpdirname, id_input), "rb") as fp:
-                    new_matches = pickle.load(fp)
-                    matches = colocation_ids.merge_matches(matches, new_matches)
+                    matches = pickle.load(fp)
+                    colocation_ids.add_matches(matches)
     else:
         for words in load_files(args):
-            new_matches = match_file(words, structures)
+            matches = match_file(words, structures)
 
             # just save to temporary file, used for children of a parallel process
             if args.match_to_file is not None:
                 with open(args.match_to_file, "wb") as fp:
-                    pickle.dump(new_matches, fp)
+                    pickle.dump(matches, fp)
                     return
             else:
-                matches = colocation_ids.merge_matches(matches, new_matches)
+                colocation_ids.add_matches(matches)
 
     if args.all:
-        Writer.make_all_writer(args).write_out(matches, structures, colocation_ids)
-    Writer.make_output_writer(args).write_out(matches, structures, colocation_ids)
+        Writer.make_all_writer(args).write_out(structures, colocation_ids)
+    Writer.make_output_writer(args).write_out(structures, colocation_ids)
 
     logging.debug([(k, len(v)) for k, v in matches.items()])
     logging.debug(sum(len(v) for _, v in matches.items()))
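
The effect of the refactor is easiest to see outside the diff. Below is a minimal, self-contained sketch of the new ColocationIds flow, for illustration only: it mirrors the committed logic but drops the args/min_freq plumbing and the Writer integration, and the sample keys, words, and structure id are invented.

class ColocationIds:
    def __init__(self):
        # key -> (printable id, [(words, reason), ...], structure_id);
        # the inner match list stays mutable even though the value is a tuple
        self.data = {}

    def _add_match(self, key, sid, match):
        if key in self.data:
            self.data[key][1].append(match)
        else:
            self.data[key] = (str(len(self.data) + 1), [match], sid)

    def add_matches(self, matches):
        # matches: {structure_id: [(words, reason, key), ...]}, the shape
        # produced by match_file
        for sid, nms in matches.items():
            for nm in nms:
                self._add_match(nm[2], sid, (nm[0], nm[1]))

    def get_matches_for(self, structure_id, group):
        for _key, (cid, cid_matches, sid) in self.data.items():
            if sid != structure_id:
                continue
            for words, reason in cid_matches:
                # frequency comes for free: it is the list length
                yield (cid, words, reason, len(cid_matches))
                if group:
                    break  # grouped output: one representative row per id

ids = ColocationIds()
ids.add_matches({"s1": [("w1 w2", "full", "key-a"),
                        ("w1 w2", "full", "key-a"),
                        ("w3 w4", "full", "key-b")]})
print(list(ids.get_matches_for("s1", group=True)))
# -> [('1', 'w1 w2', 'full', 2), ('2', 'w3 w4', 'full', 1)]

Keying storage by collocation rather than by structure means a collocation's frequency is just the length of its match list, so the separate counter and the should_write/set_written bookkeeping disappear, and write_out no longer needs a matches dict threaded through it. (The sketch, like the committed code, relies on dict insertion order, i.e. Python 3.7+.)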