diff --git a/wani.py b/wani.py
index 8573be5..f924a59 100644
--- a/wani.py
+++ b/wani.py
@@ -6,7 +6,7 @@ from collections import defaultdict
 from msd_translate import MSD_TRANSLATE
 
 
-STRUKTURE = "Kolokacije_strukture_08_new-system.xml"
+STRUKTURE = "Kolokacije_strukture_09_new-system.xml"
 STAVKI = "k2.xml"
 
 CODES = {
@@ -171,10 +171,38 @@ def build_lexis_regex(restriction):
     return re.compile(restr_dict['lemma'])
 
 
+class Restriction:
+    def __init__(self, restriction_tag):
+        restriction_type = restriction_tag.get('type')
+        if restriction_type == "morphology":
+            self.type = RestrictionType.Morphology
+            self.matcher = build_morphology_regex(restriction_tag.getchildren())
+        elif restriction_type == "lexis":
+            self.type = RestrictionType.Lexis
+            self.matcher = build_lexis_regex(restriction_tag.getchildren())
+        else:
+            raise NotImplementedError()
+
+    def match(self, word):
+        if self.type == RestrictionType.Morphology:
+            match_to = word.msd
+        elif self.type == RestrictionType.Lexis:
+            match_to = word.lemma
+        else:
+            raise RuntimeError("Unreachable!")
+
+        return self.matcher.match(match_to)
+
+    def __str__(self):
+        return "({:s} {})".format(str(self.type).split('.')[1], self.matcher)
+
+    def __repr__(self):
+        return str(self)
+
+
 class Component:
     def __init__(self, name):
         self.name = name if name is not None else ""
-        self.restriction_type = None
         self.restriction = None
         self.next_element = None
         self.level = None
@@ -199,21 +227,24 @@ class Component:
     def set_next(self, next_component, link_label):
         self.next_element = (next_component, link_label)
 
-    def set_restriction(self, restriction_tag):
-        restriction_type = restriction_tag.get('type')
-        if restriction_type == "morphology":
-            self.restriction_type = RestrictionType.Morphology
-            self.restriction = build_morphology_regex(restriction_tag.getchildren())
-        elif restriction_type == "lexis":
-            self.restriction_type = RestrictionType.Lexis
-            self.restriction = build_lexis_regex(restriction_tag.getchildren())
-        else:
-            raise NotImplementedError()
+    def set_restriction(self, restrictions_tag):
+        if restrictions_tag.tag == "restriction":
+            self.restriction = Restriction(restrictions_tag)
+            self.level = get_level(restrictions_tag)
 
-        self.level = get_level(restriction_tag.getchildren())
+        elif restrictions_tag.tag == "restriction_or":
+            self.restriction = [Restriction(el) for el in restrictions_tag]
+            self.level = get_level(restrictions_tag[0])
+
+            # same level for every restriction for now and only or available
+            levels = [get_level(el) for el in restrictions_tag]
+            assert(len(set(levels)) == 1)
+
+        else:
+            raise RuntimeError("Unreachable")
 
     def __str__(self):
-        el = "(N.{:7s} {:12s} {})".format(self.name, str(self.restriction_type).split('.')[1], self.restriction)
+        el = "(N.{:7s} {})".format(self.name, str(self.restriction))
         if self.has_next():
             el += " -- {} -->\n{}".format(self.link_label(), str(self.get_next()))
         return el
@@ -222,14 +253,19 @@ class Component:
         return str(self)
 
     def match(self, word):
-        if self.restriction_type == RestrictionType.Morphology:
-            match_to = word.msd
-        elif self.restriction_type == RestrictionType.Lexis:
-            match_to = word.lemma
+        matched = None
+
+        # matching
+        if type(self.restriction) is list:
+            for restr in self.restriction:
+                matched = restr.match(word)
+                if matched is not None:
+                    break
         else:
-            raise RuntimeError("Unreachable!")
+            matched = self.restriction.match(word)
 
-        if self.restriction.match(match_to):
+        # recurse to next
+        if matched:
             to_ret = [self.word_to_str(word)]
 
             # already matched everything!
@@ -256,7 +292,7 @@ class SyntacticStructure:
     @staticmethod
     def from_xml(xml):
         st = SyntacticStructure()
-        st.id = int(xml.get('id'))
+        st.id = xml.get('id')
         st.lbs = xml.get('LBS')
 
         components, system = xml.getchildren()
@@ -294,7 +330,7 @@ def build_structures(filename):
     structures = []
     with open(filename, 'r') as fp:
         et = ElementTree.XML(fp.read())
-        for structure in et.iterfind('syntactic_structure'):
+        for structure in et.iter('syntactic_structure'):
             structures.append(SyntacticStructure.from_xml(structure))
 
     return structures
@@ -348,9 +384,6 @@ def main():
     structures = build_structures(STRUKTURE)
     for s in structures:
         print(s)
-    exit(0)
-
-    print(STAVKI)
 
     num_matches = 0
     for w in words:
@@ -361,7 +394,7 @@ def main():
                 print(s.id, m)
 
     print("TIME", time.time() - t)
-    # print(num_matches)
+    print(num_matches)
 
 
 if __name__ == '__main__':