Moving matches into colocation ids, now easier for representation

pull/1/head
Ozbolt Menegatti 5 years ago
parent 87712128be
commit 19067e4135

@@ -388,7 +388,6 @@ class Component:
         for feature in representation:
             f = ComponentRepresentation.new(dict(feature.attrib))
-            print(f)

             if type(f) is None:
                 logging.warning("Unknown representation in component {}, skipping...".format(self.idx), file=sys.stderr)
@@ -875,9 +874,10 @@ class Writer:
     def write_header(self, file_handler):
         file_handler.write(", ".join(self.header()) + "\n")

-    def write_out_worker(self, file_handler, matches, structure_id, components, colocation_ids):
+    def write_out_worker(self, file_handler, structure_id, components, colocation_ids):
         rows = []

-        for m, reason, cid in matches:
+        for cid, m, reason, freq in colocation_ids.get_matches_for(structure_id, not self.all):
             to_write = []
             representation = ""
@@ -889,17 +889,12 @@ class Writer:
             # make them equal size
             to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write)))
-            to_write = [structure_id] + to_write + [colocation_ids.to_id(cid)]
+            to_write = [structure_id] + to_write + [cid]

             if not self.all:
                 representation = re.sub(' +', ' ', representation)
                 to_write.append(representation.strip())
-                if colocation_ids.should_write(cid):
-                    to_write.append(colocation_ids.num(cid))
-                    colocation_ids.set_written(cid)
-                else:
-                    continue
+                to_write.append(str(freq))

             rows.append(to_write)
@@ -908,7 +903,7 @@ class Writer:
         file_handler.write("\n".join([", ".join(row) for row in rows]) + "\n")
         file_handler.flush()

-    def write_out(self, matches, structures, colocation_ids):
+    def write_out(self, structures, colocation_ids):
         def fp_close(fp_):
             if fp_ != sys.stdout:
                 fp_.close()
@@ -930,8 +925,7 @@ class Writer:
             fp=fp_open(s.id)
             self.write_header(fp)

-            sid_matches = matches[s.id]
-            self.write_out_worker(fp, sid_matches, s.id, s.components, colocation_ids)
+            self.write_out_worker(fp, s.id, s.components, colocation_ids)

             if self.multiple_output:
                 fp_close(fp)
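
Reviewer note: the net effect on the writer is that it no longer receives a matches dict at all. write_out_worker now pulls each structure's rows straight from ColocationIds, and the frequency arrives precomputed in the yielded tuple, replacing the old should_write/num/set_written bookkeeping. A minimal runnable sketch of the new consumption pattern, with a hypothetical stub in place of the real generator (the real worker also emits the component columns and representation, elided here):

import io

# Hypothetical stub standing in for ColocationIds.get_matches_for; it yields
# (cid, match_words, reason, freq) tuples, as introduced later in this commit.
def get_matches_for(structure_id, group):
    yield ("1", ["beseda", "zveza"], "okay", 3)

def write_out_worker(file_handler, structure_id, write_all=False):
    rows = []
    for cid, m, reason, freq in get_matches_for(structure_id, not write_all):
        to_write = [structure_id, cid]  # id columns; real code adds components too
        if not write_all:
            to_write.append(str(freq))  # frequency is appended directly now
        rows.append(to_write)
    file_handler.write("\n".join(", ".join(row) for row in rows) + "\n")

buf = io.StringIO()
write_out_worker(buf, "s1")
print(buf.getvalue())  # s1, 1, 3
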
@@ -945,34 +939,35 @@ class ColocationIds:
         self.data = {}
         self.min_frequency = args.min_freq

-    def add_match(self, key):
+    def _add_match(self, key, sid, match):
         if key in self.data:
-            self.data[key][1] += 1
+            self.data[key][1].append(match)
         else:
-            self.data[key] = [str(len(self.data) + 1), 1, False]
+            self.data[key] = (str(len(self.data) + 1), [match], sid)

     def get(self, key, n):
         return self.data[key][n]

-    def should_write(self, key):
-        return self.get(key, 1) >= self.min_frequency and not self.get(key, 2)
-
     def num(self, key):
-        return str(self.get(key, 1))
+        return str(len(self.get(key, 1)))

     def to_id(self, key):
         return self.get(key, 0)

-    def set_written(self, key):
-        self.data[key][2] = True
-
-    def merge_matches(self, matches, new_matches):
-        for _id, nms in new_matches.items():
+    def add_matches(self, matches):
+        for sid, nms in matches.items():
             for nm in nms:
-                matches[_id].append(nm)
-                self.add_match(nm[2])
-        return matches
+                self._add_match(nm[2], sid, (nm[0], nm[1]))
+
+    def get_matches_for(self, structure_id, group):
+        for _cid_tup, (cid, cid_matches, sid) in self.data.items():
+            if sid != structure_id:
+                continue
+
+            for words, reason in cid_matches:
+                yield (cid, words, reason, len(cid_matches))
+                if group:
+                    break


 def match_file(words, structures):
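
This hunk is the heart of the change: self.data now maps each colocation key to a (sequential id, list of matches, structure id) tuple instead of the old [id, count, written-flag] list, so the frequency simply falls out as the length of the match list and the stored matches can be replayed per structure. Note that the min_frequency gate (should_write) disappears here and, at least within this diff, nothing re-checks self.min_frequency. A runnable toy version of the layout, with made-up keys and matches:

# Toy version of the new storage layout; sample keys/matches are hypothetical.
data = {}

def _add_match(key, sid, match):
    # key -> (sequential id as string, accumulated (words, reason) pairs, structure id)
    if key in data:
        data[key][1].append(match)
    else:
        data[key] = (str(len(data) + 1), [match], sid)

def get_matches_for(structure_id, group):
    for _key, (cid, cid_matches, sid) in data.items():
        if sid != structure_id:
            continue
        for words, reason in cid_matches:
            yield (cid, words, reason, len(cid_matches))
            if group:  # grouped output wants one representative row per colocation id
                break

_add_match(("biti", "hiter"), "s1", (["biti", "hiter"], "okay"))
_add_match(("biti", "hiter"), "s1", (["biti", "hitra"], "okay"))
_add_match(("imeti", "čas"), "s1", (["imeti", "čas"], "okay"))

print(list(get_matches_for("s1", group=True)))
# [('1', ['biti', 'hiter'], 'okay', 2), ('2', ['imeti', 'čas'], 'okay', 1)]
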
@@ -999,7 +994,6 @@ def main(input_file, structures_file, args):
         logging.debug(str(s))

     colocation_ids = ColocationIds()
-    matches = {s.id: [] for s in structures}

     if args.parallel:
         num_parallel = int(args.parallel)
@@ -1026,23 +1020,23 @@ def main(input_file, structures_file, args):
             # fancy interface to wait for threads to finish
             for id_input in executor.map(func, [i for i, _ in enumerate(args.input)]):
                 with open("{}/{}.p".format(tmpdirname, id_input), "rb") as fp:
-                    new_matches = pickle.load(fp)
-                    matches = colocation_ids.merge_matches(matches, new_matches)
+                    matches = pickle.load(fp)
+                    colocation_ids.add_matches(matches)
     else:
         for words in load_files(args):
-            new_matches = match_file(words, structures)
+            matches = match_file(words, structures)

             # just save to temporary file, used for children of a parallel process
             if args.match_to_file is not None:
                 with open(args.match_to_file, "wb") as fp:
-                    pickle.dump(new_matches, fp)
+                    pickle.dump(matches, fp)
                     return
             else:
-                matches = colocation_ids.merge_matches(matches, new_matches)
+                colocation_ids.add_matches(matches)

     if args.all:
-        Writer.make_all_writer(args).write_out(matches, structures, colocation_ids)
-    Writer.make_output_writer(args).write_out(matches, structures, colocation_ids)
+        Writer.make_all_writer(args).write_out(structures, colocation_ids)
+    Writer.make_output_writer(args).write_out(structures, colocation_ids)

     logging.debug([(k, len(v)) for k, v in matches.items()])
     logging.debug(sum(len(v) for _, v in matches.items()))
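
Taken together, main no longer keeps an accumulating matches dict: each chunk from match_file is handed to colocation_ids.add_matches and then forgotten, and the writers read everything back out of ColocationIds. One side effect worth noting: with several inputs, the trailing logging.debug calls now only see the last chunk rather than the full accumulation. A toy end-to-end run of the new flow, where every name is a hypothetical stub for the script's real load_files/match_file/Writer machinery:

# Toy end-to-end run of the new flow; stubs replace the script's real pieces.
class ColocationIds:
    def __init__(self):
        self.data = {}

    def add_matches(self, matches):
        # matches: {structure_id: [(words, reason, key), ...]}, as match_file returns
        for sid, nms in matches.items():
            for words, reason, key in nms:
                entry = self.data.setdefault(key, (str(len(self.data) + 1), [], sid))
                entry[1].append((words, reason))

def match_file(words):
    # pretend every adjacent word pair is a match for structure "s1"
    return {"s1": [((a, b), "okay", (a, b)) for a, b in zip(words, words[1:])]}

colocation_ids = ColocationIds()
for words in [["a", "b", "c"], ["a", "b"]]:  # two "input files"
    colocation_ids.add_matches(match_file(words))

# ("a", "b") occurred twice, ("b", "c") once
print({key: len(entry[1]) for key, entry in colocation_ids.data.items()})
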
