diff --git a/wani.py b/wani.py index cb6ece8..432f417 100644 --- a/wani.py +++ b/wani.py @@ -142,9 +142,9 @@ class Order(Enum): def new(order): if order is not None: if order == "to-from": - return Order.ToFrom + return Order.ToFrom elif order == "from-to": - return Order.FromTo + return Order.FromTo else: raise NotImplementedError("What kind of ordering is: {}".format(order)) else: @@ -174,7 +174,7 @@ class ComponentRepresentation: self.words = [] self.rendition_text = None self.agreement = [] - + def get_agreement(self): return [] @@ -195,7 +195,7 @@ class LemmaCR(ComponentRepresentation): class LexisCR(ComponentRepresentation): def _render(self): return self.data['lexis'] - + class WordFormAllCR(ComponentRepresentation): def _render(self): if len(self.words) == 0: @@ -228,7 +228,7 @@ class WordFormAnyCR(ComponentRepresentation): return None else: return text_forms[(word_msd, word_lemma)] - + class WordFormMsdCR(WordFormAnyCR): def __init__(self, *args): super().__init__(*args) @@ -251,7 +251,7 @@ class WordFormMsdCR(WordFormAnyCR): return True pass - + def add_word(self, word): if self.lemma is None: self.lemma = word.lemma @@ -259,7 +259,7 @@ class WordFormMsdCR(WordFormAnyCR): if self.check_msd(word.msd): super().add_word(word) - + def _render(self): msd = self.word_renderer.get_lemma_msd(self.lemma, self.msd) WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma text') @@ -272,10 +272,10 @@ class WordFormAgreementCR(WordFormMsdCR): def __init__(self, data, word_renderer): super().__init__(data, word_renderer) self.rendition_candidate = None - + def get_agreement(self): return self.data['other'] - + def match(self, word_msd): existing = [(w.msd, w.text) for w in self.words] @@ -289,7 +289,7 @@ class WordFormAgreementCR(WordFormMsdCR): return True return False - + def confirm_match(self): self.rendition_text = self.rendition_candidate @@ -304,8 +304,8 @@ class WordFormAgreementCR(WordFormMsdCR): v1 = TAGSET[t1].index(agr_case) # if none specified: nedolocnik, always agrees - if v1 + 1 >= len(msd1): - continue + if v1 + 1 >= len(msd1): + continue # first is uppercase, not in TAGSET m1 = msd1[v1 + 1] @@ -315,8 +315,8 @@ class WordFormAgreementCR(WordFormMsdCR): logging.warning("Cannot do agreement: {} for msd {} not found!".format(agr_case, msd2)) return False v2 = TAGSET[t2].index(agr_case) - if v2 + 1 >= len(msd2): - continue + if v2 + 1 >= len(msd2): + continue m2 = msd2[v2 + 1] # match! @@ -324,7 +324,7 @@ class WordFormAgreementCR(WordFormMsdCR): return False return True - + def render(self): pass @@ -333,7 +333,7 @@ class ComponentRendition: def __init__(self): self.more = {} self.representation_factory = ComponentRepresentation - + def add_feature(self, feature): if 'rendition' in feature: if feature['rendition'] == "lemma": @@ -366,10 +366,10 @@ class ComponentRendition: else: return None - + def cr_instance(self, word_renderer): return self.representation_factory(self.more, word_renderer) - + @staticmethod def set_representations(matches, structure, word_renderer): representations = {} @@ -377,7 +377,7 @@ class ComponentRendition: representations[c.idx] = [] for rep in c.representation: representations[c.idx].append(rep.cr_instance(word_renderer)) - + for cid, reps in representations.items(): for rep in reps: for agr in rep.get_agreement(): @@ -514,7 +514,7 @@ class Restriction: self.matcher = None self.present = None return - + restriction_type = restriction_tag.get('type') if restriction_type == "morphology": self.type = RestrictionType.Morphology @@ -620,7 +620,7 @@ class Component: mn = self._match_next(word) if mn is None: return None - + to_ret = [m1] for cmatch in mn: # if good match but nothing to add, just continue @@ -660,7 +660,7 @@ class Component: # need to get all links that match for next, link, order in self.next_element: - next_links = word.get_links(link) + next_links = word.get_links(link) to_ret.append([]) # good flag @@ -700,7 +700,7 @@ class SyntacticStructure: st = SyntacticStructure() st.id = xml.get('id') st.lbs = xml.get('LBS') - + assert(len(list(xml)) == 1) system = next(iter(xml)) @@ -731,7 +731,7 @@ class SyntacticStructure: st.determine_core2w() return st - + def determine_core2w(self): ppb_components = [] for c in self.components: @@ -777,7 +777,7 @@ class SyntacticStructure: def load_structures(filename): with open(filename, 'r') as fp: et = ElementTree.XML(fp.read()) - + return build_structures(et), get_lemma_features(et) def build_structures(et): @@ -807,7 +807,7 @@ def get_lemma_features(et): rgx_str += position[1] else: raise RuntimeError("Strange rgx for lemma_feature...") - + assert(rgx_str[0].isupper()) result[rgx_str[0]] = rgx_str.strip().replace(' ', '-') @@ -860,7 +860,7 @@ class Word: self.links[link].extend(self.links[l]) return self.links[link] - + def most_frequent_text(self, word_renderer): return word_renderer.render(self.lemma, self.msd) @@ -873,13 +873,13 @@ class WordMsdRenderer: self.lemma_msd = {} self.lemma_features = lemma_features self.memoized_msd_merges = {} - + def add_words(self, words): self.all_words.extend(words) - + def num_all_words(self): return len(self.all_words) - + def generate_renders(self): data = defaultdict(lambda: defaultdict(list)) for w in self.all_words: @@ -898,15 +898,15 @@ class WordMsdRenderer: for txt in texts: freq_words[(msd, txt)] += 1 - + common_msd = self.merge_msd(common_msd, msd) - + self.lemma_msd[lemma] = common_msd - + self.frequent_words[lemma] = [] for (msd, txt), n in sorted(freq_words.items(), key=lambda x: -x[1]): self.frequent_words[lemma].append((msd, txt, n)) - + lf = self.lemma_features for lemma in self.lemma_msd.keys(): cmsd = self.lemma_msd[lemma] @@ -914,7 +914,7 @@ class WordMsdRenderer: self.lemma_msd[lemma] = "".join( l1 if l1 != "-" else l2 for l1, l2 in zip(lf[cmsd[0]], cmsd) ) - + def merge_msd(self, common_msd, new_msd): key = (common_msd, new_msd) if key in self.memoized_msd_merges: @@ -931,12 +931,12 @@ class WordMsdRenderer: value = "".join(merge_letter(l1, l2) for l1, l2 in zip(common_msd, new_msd)) self.memoized_msd_merges[key] = value return value - + def render(self, lemma, msd): if lemma in self.rendered_words: if msd in self.rendered_words[lemma]: return self.rendered_words[lemma][msd][0] - + def available_words(self, lemma, existing_texts): counted_texts = Counter(existing_texts) for (msd, text), _n in counted_texts.most_common(): @@ -946,7 +946,7 @@ class WordMsdRenderer: for msd, text, _ in self.frequent_words[lemma]: if (msd, text) not in counted_texts: yield (msd, text) - + def get_lemma_msd(self, lemma, word_msd): # should be here, since we collect every lemmas lemma_msd = self.lemma_msd[lemma] @@ -1024,7 +1024,7 @@ class Writer: @staticmethod def make_output_writer(args): return Writer(False, args.output, args.multiple_output, int(args.sort_by), args.sort_reversed) - + @staticmethod def make_all_writer(args): return Writer(True, args.all, False, -1, False) @@ -1047,7 +1047,7 @@ class Writer: assert(len(cols) == self.length()) cols = ["C{}_{}".format(i + 1, thd) for i in range(MAX_NUM_COMPONENTS) for thd in cols] cols = ["Structure_ID"] + cols + ["Colocation_ID"] - + if not self.all: cols += ["Joint_representative_form", "Frequency"] @@ -1067,7 +1067,7 @@ class Writer: return [word.lemma, word.lemma, "lemma_fallback"] else: return [word.lemma, representation, "ok"] - + def sorted_rows(self, rows): if self.sort_by < 0 or len(rows) < 2: return rows @@ -1075,7 +1075,7 @@ class Writer: if len(rows[0]) <= self.sort_by: logging.warning("Cannot sort by column #{}: Not enough columns!".format(len(rows[0]))) return rows - + try: int(rows[0][self.sort_by]) key=lambda row: int(row[self.sort_by]) @@ -1103,7 +1103,7 @@ class Writer: representation += " " + to_write[-2] # make them equal size - to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write))) + to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write))) to_write = [structure_id] + to_write + [cid] if not self.all: @@ -1144,7 +1144,7 @@ class Writer: if self.multiple_output: fp_close(fp) - + if not self.multiple_output: fp_close(fp) @@ -1155,7 +1155,7 @@ class StructureMatch: self.matches = [] self.representations = {} - + def distinct_matches(self): dm = set() keys = list(self.matches[0].keys()) @@ -1179,7 +1179,7 @@ class ColocationIds: if key not in self.data: self.data[key] = StructureMatch(str(len(self.data) + 1), sid) self.data[key].append(match) - + def get(self, key, n): return self.data[key][n] @@ -1187,7 +1187,7 @@ class ColocationIds: for sid, nms in matches.items(): for nm in nms: self._add_match(nm[1], sid, nm[0]) - + def get_matches_for(self, structure_id, group): for _cid_tup, sm in self.data.items(): if sm.structure_id != structure_id: @@ -1204,7 +1204,7 @@ class ColocationIds: for _1, sm in tqdm(self.data.items()): ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer) idx += 1 - + def determine_colocation_dispersions(self): dispersions = defaultdict(int) for (structure_id, *word_tups) in self.data.keys(): @@ -1219,7 +1219,7 @@ def match_file(words, structures): for s in tqdm(structures): for w in words: mhere = s.match(w) - for match in mhere: + for match in mhere: colocation_id = [(idx, w.lemma) for idx, w in match.items()] colocation_id = [s.id] + list(sorted(colocation_id, key=lambda x:x[0])) colocation_id = tuple(colocation_id) @@ -1246,7 +1246,7 @@ def main(input_file, structures_file, args): # make temporary directory to hold temporary files with tempfile.TemporaryDirectory() as tmpdirname: - cmd = sys.argv + cmd = sys.argv for inpt in args.input: if inpt in cmd: cmd.remove(inpt) @@ -1256,7 +1256,7 @@ def main(input_file, structures_file, args): del cmd[pidx] del cmd[pidx] - def func(n): + def func(n): cmdn = [sys.executable] + cmd + [args.input[n], "--match-to-file", "{}/{}.p".format(tmpdirname, n)] subprocess.check_call(cmdn) return n @@ -1322,3 +1322,5 @@ if __name__ == '__main__': start = time.time() main(args.input, args.structures, args) logging.info("TIME: {}".format(time.time() - start)) + +# 2876, 2945 type