1070 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1070 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from xml.etree import ElementTree
 | 
						|
import re
 | 
						|
from enum import Enum
 | 
						|
from collections import defaultdict
 | 
						|
import sys
 | 
						|
import logging
 | 
						|
import argparse
 | 
						|
import pickle
 | 
						|
import time
 | 
						|
import subprocess
 | 
						|
import concurrent.futures
 | 
						|
import tempfile
 | 
						|
 | 
						|
from msd_translate import MSD_TRANSLATE
 | 
						|
 | 
						|
 | 
						|
# Presumably the maximum number of components a structure may contain; it is
# not referenced in this part of the file -- TODO confirm where it is used.
MAX_NUM_COMPONENTS = 5

# Human-readable names used in structure definitions -> single-letter codes
# used inside MSD strings. Letters are reused across attributes (e.g. "n" is
# neuter, nominative, "no" and infinitive), so a code is only meaningful
# together with its position in the MSD.
CODES = {
    # Parts of speech (the first character of an MSD).
    "Noun": "N",
    "Verb": "V",
    "Adjective": "A",
    "Adverb": "R",
    "Pronoun": "P",
    "Numeral": "M",
    "Preposition": "S",
    "Conjunction": "C",
    "Particle": "Q",
    "Interjection": "I",
    "Abbreviation": "Y",
    "Residual": "X",

    # Attribute values.
    "common": "c",
    "proper": "p",

    # gender
    "masculine": "m",
    "feminine": "f",
    "neuter": "n",

    # number
    "singular": "s",
    "dual": "d",
    "plural": "p",

    # case
    "nominative": "n",
    "genitive": "g",
    "dative": "d",
    "accusative": "a",
    "locative": "l",
    "instrumental": "i",

    "no": "n",
    "yes": "y",

    "main": "m",
    "auxiliary": "a",

    "perfective": "e",
    "progressive": "p",
    "biaspectual": "b",

    "infinitive": "n",
    "supine": "u",
    "participle": "p",
    "present": "r",
    "future": "f",
    "conditional": "c",
    "imperative": "m",

    "first": "1",
    "second": "2",
    "third": "3",

    "general": "g",
    "possessive": "s",

    "positive": "p",
    "comparative": "c",
    "superlative": "s",

    "personal": "p",
    "demonstrative": "d",
    "relative": "r",
    "reflexive": "x",
    "interrogative": "q",
    "indefinite": "i",
    "negative": "z",
    "bound": "b",

    "digit": "d",
    "roman": "r",
    "letter": "l",

    "cardinal": "c",
    "ordinal": "o",
    "pronominal": "p",
    "special": "s",

    "coordinating": "c",
    "subordinating": "s",

    "foreign": "f",
    "typo": "t",
    "program": "p",
}
 | 
						|
 | 
						|
# MSD attribute positions per part-of-speech code. The attribute at list
# index i is stored at MSD character i + 1; character 0 is the category code
# itself.
TAGSET = {
    "N": ['type', 'gender', 'number', 'case', 'animate'],
    "V": ['type', 'aspect', 'vform', 'person', 'number', 'gender', 'negative'],
    "A": ['type', 'degree', 'gender', 'number', 'case', 'definiteness'],
    "R": ['type', 'degree'],
    "P": ['type', 'person', 'gender', 'number', 'case', 'owner_number', 'owned_gender', 'clitic'],
    "M": ['form', 'type', 'gender', 'number', 'case', 'definiteness'],
    "S": ['case'],
    "C": ['type'],
    "Q": [],
    "I": [],
    "Y": [],
    "X": ['type']
}

# Wildcard template per category: one "." for every attribute position.
# Derived from TAGSET so the two tables can never disagree in length; a
# hand-written table with only 6 slots for pronouns ("P", 8 attributes) made
# restrictions on "owned_gender" or "clitic" fail with IndexError.
CATEGORY_BASES = {
    code: ['.'] * len(attributes) for code, attributes in TAGSET.items()
}
 | 
						|
 | 
						|
 | 
						|
class RestrictionType(Enum):
    """What part of a word a Restriction is checked against."""

    Morphology = 0  # the word's MSD string
    Lexis = 1       # the word's lemma
    MatchAll = 2    # nothing: every word passes
 | 
						|
 | 
						|
 | 
						|
class Rendition(Enum):
    """How a matched word is turned into text by ComponentRendition."""

    Lemma = 0     # output the word's lemma
    WordForm = 1  # output the word's surface text
    Unknown = 2   # no rendition configured; renders as None
 | 
						|
 | 
						|
class Order(Enum):
    """Linear-order constraint between the two words of a dependency link."""

    FromTo = 0  # the "from" word's int_id must be smaller than the "to" word's
    ToFrom = 1  # the "to" word's int_id must be smaller than the "from" word's
    Any = 2     # no constraint

    @staticmethod
    def new(order):
        """Convert a dependency's ``order`` attribute into an Order.

        ``None`` (attribute absent) yields ``Order.Any``.

        Raises:
            NotImplementedError: for anything other than None, "to-from" or
                "from-to".
        """
        if order is None:
            return Order.Any
        if order == "to-from":
            return Order.ToFrom
        if order == "from-to":
            return Order.FromTo
        raise NotImplementedError("What kind of ordering is: {}".format(order))

    def match(self, from_w, to_w):
        """Return whether ``from_w`` and ``to_w`` satisfy this constraint.

        Positions are compared through the words' ``int_id`` attributes.
        """
        if self is Order.Any:
            return True

        from_pos, to_pos = from_w.int_id, to_w.int_id

        if self is Order.FromTo:
            return from_pos < to_pos
        if self is Order.ToFrom:
            return to_pos < from_pos
        raise NotImplementedError("Should not be here: Order match")
 | 
						|
 | 
						|
class ComponentRendition:
    """Turns a matched word into output text according to a Rendition."""

    def __init__(self, rendition=Rendition.Unknown):
        # NOTE(review): word_form is not read anywhere in this part of the file.
        self.word_form = {}
        self.rendition = rendition

    def render(self, word):
        """Return the word's lemma, its surface text, or None.

        Raises:
            RuntimeError: if ``self.rendition`` is not a known Rendition.
        """
        mode = self.rendition
        if mode == Rendition.Unknown:
            return None
        if mode == Rendition.Lemma:
            return word.lemma
        if mode == Rendition.WordForm:
            return word.text
        raise RuntimeError("Unknown rendition: {}".format(mode))

    def __str__(self):
        return str(self.rendition)
 | 
						|
 | 
						|
 | 
						|
class StructureSelection(Enum):
    """Selection mode requested by a representation feature.

    ``ComponentRepresentation.new`` returns ``All`` for ``selection="all"`` and
    ``Frequency`` for ``selection="frequency"``; ``Component.find_next``
    reports ``Frequency`` for a structure as soon as any component requests
    it. How the mode is applied is not visible in this part of the file.
    """

    All = 0
    Frequency = 1
 | 
						|
 | 
						|
class ComponentRepresentation:
    """Factory interpreting the attributes of one representation feature."""

    # Declared static so the factory also works when reached through an
    # instance; without it, ``s`` would be bound to the instance.
    @staticmethod
    def new(s):
        """Interpret the attribute dict ``s`` of a representation feature.

        Returns:
            * ``ComponentRendition`` for ``rendition="lemma"`` or
              ``rendition="word_form"``;
            * ``StructureSelection.Frequency`` / ``StructureSelection.All`` for
              ``selection="frequency"`` / ``selection="all"``;
            * ``{selection: value}`` for any other ``selection`` (an attribute
              the matched word's MSD is later checked against);
            * ``None`` when neither ``rendition`` nor ``selection`` is present.

        Raises:
            NotImplementedError: for an unknown ``rendition`` value.
            KeyError: for an attribute ``selection`` without ``value``.
        """
        if 'rendition' in s:
            if s['rendition'] == "lemma":
                return ComponentRendition(Rendition.Lemma)
            elif s['rendition'] == "word_form":
                return ComponentRendition(Rendition.WordForm)
            else:
                raise NotImplementedError("Rendition: {}".format(s))
        elif 'selection' in s:
            if s['selection'] == "frequency":
                return StructureSelection.Frequency
            elif s['selection'] == "all":
                return StructureSelection.All
            else:
                return {s['selection']: s['value']}
        else:
            return None
 | 
						|
 | 
						|
 | 
						|
class ComponentStatus(Enum):
    """Whether a component must, may, or must not be present in a match."""

    Optional = 0
    Required = 1
    Forbidden = 2

    def __str__(self):
        # Compact marker used when printing components; Forbidden -> "X".
        markers = {
            ComponentStatus.Optional: "?",
            ComponentStatus.Required: "!",
        }
        return markers.get(self, "X")
 | 
						|
 | 
						|
 | 
						|
def get_level(restriction):
    """Return the ``level`` attribute of the first feature that defines one.

    Args:
        restriction: iterable of features exposing ``keys()`` and ``get()``
            (e.g. the child elements of an ElementTree element).

    Raises:
        RuntimeError: if no feature carries a ``level`` attribute.
    """
    for feature in restriction:
        if "level" in feature.keys():
            # First feature with a level wins.
            return feature.get("level")

    raise RuntimeError("No 'level' feature in restriction")
 | 
						|
 | 
						|
 | 
						|
def build_morphology_regex(restriction):
    """Build a matcher for a morphology (MSD) restriction.

    Every feature carries exactly one attribute -- ``POS`` (required) or an
    MSD attribute name from TAGSET such as ``case`` -- optionally together
    with ``filter="negative"`` to invert it. A value may list alternatives
    separated by ``|``.

    Returns:
        (present, matcher): ``present`` is a space-separated, human-readable
        list of the per-position patterns; ``matcher(msd)`` is True when the
        MSD string satisfies every constraint.

    Raises:
        KeyError: unknown POS or attribute value (not in CODES).
        ValueError: attribute not listed in TAGSET for the POS.
    """
    restr_dict = {}
    for feature in restriction:
        feature_dict = dict(feature.items())

        match_type = True
        if "filter" in feature_dict:
            assert feature_dict['filter'] == "negative"
            match_type = False
            del feature_dict['filter']

        assert len(feature_dict) == 1
        key, value = next(iter(feature_dict.items()))
        restr_dict[key] = (value, match_type)

    assert 'POS' in restr_dict
    category = restr_dict.pop('POS')[0].capitalize()
    cat_code = CODES[category]

    # One pattern per MSD position: the category code, then a wildcard per
    # TAGSET attribute. Sizing from TAGSET guarantees every attribute index
    # has a slot.
    rgx = [cat_code] + ['.'] * len(TAGSET[cat_code])

    # A candidate MSD must be LONGER than min_msd_length: it always needs the
    # category character, and a positive constraint on attribute i also needs
    # character i + 1. Starting from 0 (not 1) lets one-character MSDs such
    # as a bare particle "Q" match a restriction that only fixes the POS.
    min_msd_length = 0
    for attribute, (value, positive) in restr_dict.items():
        index = TAGSET[cat_code].index(attribute.lower())  # ValueError if unknown
        codes = "".join(CODES[val] for val in value.split('|'))
        rgx[index + 1] = "[{}{}]".format("" if positive else "^", codes)

        if positive:
            min_msd_length = max(index + 1, min_msd_length)

    # Compile the per-position patterns once rather than on every call.
    position_patterns = [re.compile(r) for r in rgx]

    def matcher(text):
        if len(text) <= min_msd_length:
            return False
        # zip stops at the shorter sequence: positions past the end of a
        # short MSD are not checked (so negative constraints pass there).
        return all(p.match(c) for c, p in zip(text, position_patterns))

    return " ".join(rgx), matcher
 | 
						|
 | 
						|
 | 
						|
def build_lexis_regex(restriction):
    """Build a matcher for a lexical (lemma) restriction.

    The attributes of all features are merged, later features overriding
    earlier ones; ``lemma`` lists the accepted lemmas separated by ``|``.

    Returns:
        (lemmas, matcher): the list of accepted lemmas and a predicate that is
        True when its argument is one of them.
    """
    merged = dict(item for feature in restriction for item in feature.items())
    assert "lemma" in merged

    lemmas = merged['lemma'].split('|')

    def matcher(text):
        return text in lemmas

    return lemmas, matcher
 | 
						|
 | 
						|
 | 
						|
class Restriction:
    """One restriction on a component's word: morphology, lexis or match-all."""

    def __init__(self, restriction_tag):
        """Build from a ``<restriction>`` element; None means match-all.

        Raises:
            NotImplementedError: for an unsupported ``type`` attribute.
        """
        if restriction_tag is None:
            self.type = RestrictionType.MatchAll
            self.matcher = None
            self.present = None
            return

        builders = {
            "morphology": (RestrictionType.Morphology, build_morphology_regex),
            "lexis": (RestrictionType.Lexis, build_lexis_regex),
        }
        kind = restriction_tag.get('type')
        if kind not in builders:
            raise NotImplementedError()

        self.type, build = builders[kind]
        self.present, self.matcher = build(list(restriction_tag))

    def match(self, word):
        """Return whether ``word`` satisfies this restriction."""
        if self.type == RestrictionType.MatchAll:
            return True
        if self.type == RestrictionType.Morphology:
            return self.matcher(word.msd)
        if self.type == RestrictionType.Lexis:
            return self.matcher(word.lemma)
        raise RuntimeError("Unreachable!")

    def __str__(self):
        return "({:s} {})".format(self.type.name, self.present)

    def __repr__(self):
        return str(self)
 | 
						|
 | 
						|
 | 
						|
class Component:
    """One node of a syntactic structure together with links to its dependents.

    A component holds the restriction its word must satisfy, the rendition
    that turns the matched word into text, extra form selections, and a list
    of outgoing links ``(component, link label, Order)`` that the matcher
    follows through a word's dependency links.
    """

    def __init__(self, info):
        """Create a component from its attribute dict.

        Args:
            info: attributes of the component element; ``cid`` is required,
                ``name`` and ``status`` are optional.

        Raises:
            NotImplementedError: for an unknown ``status`` value.
        """
        idx = info['cid']
        name = info['name'] if 'name' in info else None

        if 'status' not in info:
            status = ComponentStatus.Required
        elif info['status'] == 'forbidden':
            status = ComponentStatus.Forbidden
        elif info['status'] == 'obligatory':
            status = ComponentStatus.Required
        elif info['status'] == 'optional':
            status = ComponentStatus.Optional
        else:
            raise NotImplementedError("strange status: {}".format(info['status']))

        self.status = status
        self.name = name
        self.idx = idx
        self.restriction = None
        # (next component, link label, Order) triples, filled by add_next().
        self.next_element = []
        self.rendition = ComponentRendition()
        # attribute -> value pairs checked against the matched word's MSD.
        self.selection = {}

        # NOTE(review): not read anywhere in this part of the file.
        self.iter_ctr = 0

    def render_word(self, word):
        """Render ``word`` according to this component's rendition."""
        return self.rendition.render(word)

    def add_next(self, next_component, link_label, order):
        """Register a dependent component reached through ``link_label``.

        ``order`` is the raw ``order`` attribute of the dependency, converted
        with ``Order.new``.
        """
        self.next_element.append((next_component, link_label, Order.new(order)))

    def set_restriction(self, restrictions_tag):
        """Set the restriction from a definition element.

        None gives a match-all restriction, ``<restriction>`` a single
        Restriction and ``<restriction_or>`` a list of alternatives, any one
        of which may match.
        """
        if restrictions_tag is None:
            self.restriction = Restriction(None)
        elif restrictions_tag.tag == "restriction":
            self.restriction = Restriction(restrictions_tag)
        elif restrictions_tag.tag == "restriction_or":
            self.restriction = [Restriction(el) for el in restrictions_tag]
        else:
            raise RuntimeError("Unreachable")

    def set_representation(self, representation):
        """Apply the features of a representation element to this component.

        Renditions replace ``self.rendition``; attribute selections are merged
        into ``self.selection``; features ComponentRepresentation does not
        recognise are skipped with a warning.

        Returns:
            The StructureSelection requested by the element, or None.
        """
        cr = None
        if representation is not None:
            self.representation = []

            for feature in representation:
                f = ComponentRepresentation.new(dict(feature.attrib))

                # new() returns None for a feature with neither a rendition
                # nor a selection attribute.
                if f is None:
                    logging.warning("Unknown representation in component {}, skipping...".format(self.idx))
                    continue

                if isinstance(f, StructureSelection):
                    assert cr is None
                    cr = f
                elif isinstance(f, ComponentRendition):
                    self.rendition = f
                elif isinstance(f, dict):
                    self.selection.update(f)
                else:
                    raise RuntimeError("Unreachable: {}".format(f))

        return cr

    def find_next(self, deps, comps, restrs, reprs):
        """Recursively build the components that depend on this one.

        Args:
            deps: (from cid, to cid, label, order) tuples.
            comps: cid -> component attribute dict.
            restrs: cid -> restriction element (or None).
            reprs: cid -> representation element (or None).

        Returns:
            (components, selection): every component reachable from this one,
            depth-first and excluding this component, and
            ``StructureSelection.Frequency`` if any of them requests it,
            otherwise ``StructureSelection.All``.
        """
        representation = StructureSelection.All
        to_ret = []

        for d in deps:
            if d[0] != self.idx:
                continue
            _, idx, dep_label, order = d

            next_component = Component(comps[idx])
            next_component.set_restriction(restrs[idx])
            r1 = next_component.set_representation(reprs[idx])
            to_ret.append(next_component)

            self.add_next(next_component, dep_label, order)
            others, r2 = next_component.find_next(deps, comps, restrs, reprs)
            to_ret.extend(others)

            if StructureSelection.Frequency in (r1, r2):
                representation = StructureSelection.Frequency

        return to_ret, representation

    def name_str(self):
        """Return the component name, or "_" when it has none."""
        return "_" if self.name is None else self.name

    def __str__(self):
        n = self.name_str()
        return "{:s}) {:7s}:{} [{}] :{}".format(
            self.idx, n, self.status, self.restriction, self.rendition)

    def tree(self):
        """Return one text line per link in the subtree below this component."""
        lines = []
        for child, link, order in self.next_element:
            s = "{:3} -- {:5} --> {:3}".format(self.idx, link, child.idx)
            if order != Order.Any:
                s += " " + str(order)[6:]  # drop the "Order." prefix
            lines.append(s)
            lines.extend(child.tree())
        return lines

    def __repr__(self):
        return str(self)

    def match(self, word):
        """Match ``word`` against this component and, recursively, its links.

        Returns:
            A list of dicts mapping component idx -> matched word, or None if
            this component or one of its required/forbidden links fails.
        """
        m1 = self._match_self(word)
        if m1 is None:
            return None

        mn = self._match_next(word)
        if mn is None:
            return None

        to_ret = [m1]
        for cmatch in mn:
            # Link satisfied but nothing to add (e.g. optional, unmatched).
            if len(cmatch) == 0:
                continue

            # Several matches for this link.
            elif len(cmatch) > 1:
                logging.debug("MULTIPLE: {}, {}".format(self.idx, cmatch))
                # Results were already fanned out by an earlier link: do not
                # build the cross product, just merge in the first match.
                if len(to_ret) > 1:
                    logging.warning("Strange multiple match: {}".format(
                        str([w.id for w in cmatch[0].values()])))

                    for tr in to_ret:
                        tr.update(cmatch[0])
                    continue

                # Fan out: one result per match of this link.
                to_ret = [{**dict(to_ret[0]), **m} for m in cmatch]

            else:
                for tr in to_ret:
                    tr.update(cmatch[0])

        logging.debug("MA: {}".format(str(to_ret)))
        return to_ret

    def _match_self(self, word):
        """Return ``{self.idx: word}`` if the restriction accepts ``word``, else None."""
        matched = None

        if type(self.restriction) is list:
            # restriction_or: any alternative may match.
            for restr in self.restriction:
                matched = restr.match(word)
                if matched:
                    break
        else:
            matched = self.restriction.match(word)

        logging.debug("SELF MATCH({}: {} -> {}".format(self.idx, word.text, matched))

        if not matched:
            return None
        else:
            return {self.idx: word}

    def _match_next(self, word):
        """Match every outgoing link of this component starting from ``word``.

        Returns:
            A list with one entry per link, each a (possibly empty) list of
            match dicts, or None when a required link has no match or a
            forbidden component does match.
        """
        to_ret = []

        for child, link, order in self.next_element:
            next_links = word.get_links(link)
            logging.debug("FIND LINKS FOR: {} -> {}: #{}".format(self.idx, child.idx, len(next_links)))
            to_ret.append([])

            # Optional/forbidden links start out satisfied; required ones
            # need at least one match.
            good = child.status != ComponentStatus.Required
            for next_word in next_links:
                logging.debug("link: {}: {} -> {}".format(link, word.id, next_word.id))
                if not order.match(word, next_word):
                    continue

                match = child.match(next_word)

                if match is not None:
                    # A forbidden component that matches rejects the link.
                    if child.status == ComponentStatus.Forbidden:
                        good = False
                        break
                    else:
                        assert type(match) is list
                        to_ret[-1].extend(match)
                        good = True

            if not good:
                logging.debug("BAD")
                return None

        return to_ret
 | 
						|
 | 
						|
 | 
						|
class SyntacticStructure:
    """A syntactic structure: its components, links and agreements."""

    def __init__(self):
        self.id = None
        self.lbs = None
        # dicts {'n1': cid, 'n2': cid, 'match': [attribute, ...]}
        self.agreements = []
        # components reachable from the artificial root, depth-first
        self.components = []
        self.selection = StructureSelection.All

    @staticmethod
    def from_xml(xml):
        """Parse a ``<syntactic_structure>`` element.

        The element must contain exactly one ``system`` child of type "JOS"
        whose three children are, in order, the components, the dependencies
        and the definitions.
        """
        st = SyntacticStructure()
        st.id = xml.get('id')
        st.lbs = xml.get('LBS')

        assert len(list(xml)) == 1
        system = next(iter(xml))

        assert system.get('type') == 'JOS'
        components, dependencies, definitions = list(system)

        deps = [(dep.get('from'), dep.get('to'), dep.get('label'), dep.get('order'))
                for dep in dependencies]
        comps = {comp.get('cid'): dict(comp.items()) for comp in components}

        restrs, forms = {}, {}
        for comp in definitions:
            n = comp.get('cid')
            restrs[n] = None
            forms[n] = None

            for el in comp:
                if el.tag.startswith("restriction"):
                    assert restrs[n] is None
                    restrs[n] = el
                elif el.tag.startswith("representation"):
                    st.add_representation(n, el, forms)
                else:
                    raise NotImplementedError("definition??")

        # Dependencies whose "from" is "#" hang off this artificial root.
        fake_root_component = Component({'cid': '#', 'type': 'other'})
        st.components, st.selection = fake_root_component.find_next(deps, comps, restrs, forms)
        return st

    def add_representation(self, n, el, forms):
        """Record the representation entries defined for component ``n``.

        A ``<representation>`` element is one entry; ``<representation_and>``
        contributes each of its children. Entries with ``basic="form"`` are
        stored in ``forms[n]``, ``basic="agreement"`` ones become agreements,
        anything else is skipped with a warning.
        """
        if el.tag == "representation":
            entries = [el]
        elif el.tag == "representation_and":
            entries = list(el)
        else:
            raise NotImplementedError("Unknown representation tag: {}".format(el.tag))

        for entry in entries:
            if entry.get('basic') == 'form':
                assert forms[n] is None
                forms[n] = entry
            elif entry.get('basic') == "agreement":
                self.add_agreement(n, entry)
            else:
                logging.warning("Strange representation (basic={}) in structure {}. Skipping"
                                .format(entry.get('basic'), self.id))

    def add_agreement(self, n, el):
        """Add an agreement between component ``n`` and the ``head`` component.

        ``head`` has the form "cid_<id>"; the first child's ``agreement``
        attribute lists the attributes that must agree, separated by "|".
        """
        assert el.get('head')[:4] == 'cid_'

        n1 = n
        n2 = el.get('head')[4:]
        agreement_str = next(iter(el)).get('agreement')

        self.agreements.append({
            'n1': n1,
            'n2': n2,
            'match': agreement_str.split('|')})

    def __str__(self):
        comp_str = "\n".join(str(comp) for comp in self.components)

        agrs = "\n".join("({} -[{}]- {}) ".format(
            a['n1'], "|".join(a['match']), a['n2']) for a in self.agreements)

        # A structure without components has no link tree to show.
        links_str = "\n".join(self.components[0].tree()) if self.components else ""

        return "{} LBS {}\nCOMPONENTS\n{}\nAGREEMENTS\n{}\nLINKS\n{}\n{}".format(
            self.id, self.lbs, comp_str, agrs, links_str, "-" * 40)

    def get_component(self, idx):
        """Return the component with id ``idx``; RuntimeError if unknown."""
        for c in self.components:
            if c.idx == idx:
                return c
        raise RuntimeError("Unknown component id: {}".format(idx))

    @staticmethod
    def _msd_value(word, attribute):
        """Return ``word``'s MSD character for ``attribute``, or None.

        The position comes from TAGSET for the word's category (MSD
        character 0); None means the MSD is too short to contain it.

        Raises:
            KeyError: unknown category code.
            ValueError: attribute not defined for the category.
        """
        position = TAGSET[word.msd[0]].index(attribute) + 1
        if position >= len(word.msd):
            return None
        return word.msd[position]

    def check_agreements(self, match):
        """Return False if any agreement is violated by ``match``.

        A '-' on either side, an MSD too short to hold the attribute (e.g. an
        infinitive, "nedolocnik"), or a component absent from ``match`` (an
        optional component that did not match) counts as agreeing.
        """
        for agr in self.agreements:
            if agr['n1'] not in match or agr['n2'] not in match:
                continue
            w1 = match[agr['n1']]
            w2 = match[agr['n2']]

            for agr_case in agr['match']:
                m1 = self._msd_value(w1, agr_case)
                if m1 is None:
                    continue
                m2 = self._msd_value(w2, agr_case)
                if m2 is None:
                    continue

                if '-' not in (m1, m2) and m1 != m2:
                    return False

        return True

    def check_form(self, match):
        """Return False if a matched word contradicts its component's selection.

        As in check_agreements, '-' or an MSD too short to hold the attribute
        counts as agreeing (the latter previously raised IndexError).
        """
        for midx, w in match.items():
            c = self.get_component(midx)
            for key, value in c.selection.items():
                f1 = self._msd_value(w, key.lower())
                if f1 is None:
                    continue
                f2 = CODES[value]

                if '-' not in (f1, f2) and f1 != f2:
                    return False

        return True

    def match(self, word):
        """Match this structure with ``word`` as its first component.

        Returns:
            A list of ``(match, status)`` pairs where ``match`` maps component
            ids to words and ``status`` is "Agreement", "Form" (the first check
            that failed) or "OK"; empty when the components do not match.
        """
        matches = self.components[0].match(word)
        if matches is None:
            return []

        to_ret = []
        for m in matches:
            if not self.check_agreements(m):
                bad = "Agreement"
            elif not self.check_form(m):
                bad = "Form"
            else:
                bad = "OK"
            to_ret.append((m, bad))

        return to_ret
 | 
						|
 | 
						|
 | 
						|
def build_structures(filename):
    """Parse every ``<syntactic_structure>`` element of the XML file ``filename``.

    The file is passed to the parser as bytes so the encoding declared by the
    document (UTF-8 by default) is honoured instead of the platform's locale
    encoding.

    Returns:
        A list of SyntacticStructure objects, in document order.
    """
    with open(filename, 'rb') as fp:
        et = ElementTree.XML(fp.read())

    structures = []
    for structure in et.iter('syntactic_structure'):
        to_append = SyntacticStructure.from_xml(structure)
        if to_append is None:
            continue
        structures.append(to_append)
    return structures
 | 
						|
 | 
						|
 | 
						|
def get_msd(comp):
    """Return the MSD of a word element.

    Uses the ``msd`` attribute if present, otherwise the ``ana`` attribute
    with its first four characters removed (presumably a prefix such as
    "mte:" -- TODO confirm against the corpus).

    Raises:
        NotImplementedError: if neither attribute is present.
    """
    d = dict(comp.items())
    if 'msd' in d:
        return d['msd']
    if 'ana' in d:
        return d['ana'][4:]

    # logging.error() takes no `file` argument; passing one raised TypeError
    # and hid the intended error below.
    logging.error("No 'msd' or 'ana' attribute: %s", d)
    raise NotImplementedError("MSD?")
 | 
						|
 | 
						|
class Word:
    """A token of a TEI document together with its outgoing dependency links."""

    def __init__(self, xml, do_msd_translate):
        """Create a word from an element carrying ``lemma``, ``id`` and an MSD.

        Args:
            xml: the word (or punctuation) element.
            do_msd_translate: when True the MSD is looked up in MSD_TRANSLATE.
        """
        self.lemma = xml.get('lemma')
        raw_msd = get_msd(xml)
        self.msd = MSD_TRANSLATE[raw_msd] if do_msd_translate else raw_msd
        self.id = xml.get('id')
        self.text = xml.text
        # link label -> list of linked Words
        self.links = defaultdict(list)

        # Numeric position from the last dot-separated part of the id, with
        # one leading non-digit character dropped (e.g. "t12" -> 12);
        # Order.match compares these values.
        tail = self.id.split('.')[-1]
        digits = tail if tail[0] in '0123456789' else tail[1:]
        self.int_id = int(digits)

        assert None not in (self.id, self.lemma, self.msd)

    @staticmethod
    def pcWord(pc, do_msd_translate):
        """Create a Word from a punctuation element.

        The element is modified in place: its text becomes its lemma and it
        receives a placeholder MSD ("N" when translating, "U" otherwise).
        """
        placeholder = "N" if do_msd_translate else "U"
        pc.set('lemma', pc.text)
        pc.set('msd', placeholder)
        return Word(pc, do_msd_translate)

    def add_link(self, link, to):
        """Append ``to`` to the words reached through label ``link``."""
        self.links[link].append(to)

    def get_links(self, link):
        """Return the words reached through ``link``.

        A label "a|b" collects the words of "a" and "b" (in that order) the
        first time it is requested and caches the result under "a|b".
        """
        if "|" in link and link not in self.links:
            combined = self.links[link]
            for part in link.split('|'):
                combined.extend(self.links[part])

        return self.links[link]
 | 
						|
 | 
						|
 | 
						|
def is_root_id(id_):
    """Return True when ``id_`` has exactly three dot-separated parts.

    load_tei_file treats links starting at such ids as an error (presumably
    these are sentence-level ids, e.g. "doc.p.s" -- TODO confirm).
    """
    return id_.count('.') == 2
 | 
						|
 | 
						|
 | 
						|
def load_files(args):
    """Lazily parse the input files, yielding one list of Words per file.

    Args:
        args: parsed command-line namespace; ``input``, ``skip_id_check``,
            ``no_msd_translate``, ``count_files`` and ``pc_tag`` are read.
    """
    filenames = args.input
    skip_id_check = args.skip_id_check
    do_msd_translate = not args.no_msd_translate

    for position, fname in enumerate(filenames):
        status = (" :: {} / {}".format(position, len(filenames))
                  if args.count_files else "")
        yield load_tei_file(fname, skip_id_check, do_msd_translate, args.pc_tag, status)
 | 
						|
 | 
						|
 | 
						|
def load_tei_file(filename, skip_id_check, do_msd_translate, pc_tag, status):
    """Parse one TEI XML file into Word objects connected by their links.

    Args:
        filename: path of the TEI XML file.
        skip_id_check: when false, a link whose source id satisfies
            is_root_id aborts the program.
        do_msd_translate: forwarded to Word / Word.pcWord.
        pc_tag: element name used for punctuation tokens (e.g. "pc" or "c").
        status: suffix appended to the "LOADING FILE" log message.

    Returns:
        list of Word objects built from ``<w>`` and ``pc_tag`` elements.

    Exits the process (status 1) on a root-id link source (unless
    skip_id_check) or on a link to an unknown token id.
    """
    logging.info("LOADING FILE: {}{}".format(filename, status))

    # Read explicitly as UTF-8 rather than the platform's locale encoding.
    with open(filename, 'r', encoding='utf-8') as fp:
        # Strip the first default namespace declaration and the "xml:" prefix
        # (e.g. xml:id -> id) so elements/attributes can be looked up by bare name.
        xmlstring = re.sub(' xmlns="[^"]+"', '', fp.read(), count=1)
        xmlstring = xmlstring.replace(' xml:', ' ')

    et = ElementTree.XML(xmlstring)

    words = {}
    for w in et.iter("w"):
        words[w.get('id')] = Word(w, do_msd_translate)
    for pc in et.iter(pc_tag):
        words[pc.get('id')] = Word.pcWord(pc, do_msd_translate)

    for link in et.iter("link"):
        if 'dep' in link.keys():
            ana = link.get('afun')
            lfrom = link.get('from')
            dest = link.get('dep')
        else:
            ana = link.get('ana')
            # Only "syn:<label>" links are used; links without an "ana"
            # attribute used to crash here with a TypeError.
            if ana is None or not ana.startswith('syn:'):
                continue
            ana = ana[4:]
            lfrom, dest = link.get('target').replace('#', '').split()

        if lfrom not in words:
            # Links from unknown tokens are deliberately skipped.
            continue

        if not skip_id_check and is_root_id(lfrom):
            # Pass the id as a %-style argument: the old call had no
            # placeholder, so logging failed to format it and lost the id.
            logging.error("NOO: %s", lfrom)
            sys.exit(1)

        if dest not in words:
            logging.error("Unknown id: {}".format(dest))
            sys.exit(1)

        words[lfrom].add_link(ana, words[dest])

    return list(words.values())
 | 
						|
 | 
						|
class Writer:
    """Format collocation matches as ", "-separated rows and write them out.

    All settings are taken from the parsed command-line ``args``.
    """

    def __init__(self, args):
        self.group = args.group
        self.lemma_only = args.lemma_only
        self.without_rep = args.without_rep
        self.output_file = args.output
        self.multiple_output = args.multiple_output

        # Column index used by sorted_rows(); negative disables sorting.
        self.sort_by = int(args.sort_by)
        # Passed as ``reverse=`` to sorted(): True means descending order.
        self.sort_order = args.sort_reversed

    def header(self):
        """Return the list of output column names."""
        cols = ["Lemma"]
        if not self.lemma_only:
            cols = ["Token_ID", "Word_form"] + cols + ["Msd"]
        if not self.without_rep:
            cols.append("Representative_form")

        assert len(cols) == self.length()

        # One group of per-component columns (C1_..., C2_..., ...) for each of
        # the MAX_NUM_COMPONENTS possible components.
        cols = ["C{}_{}".format(i + 1, thd) for i in range(MAX_NUM_COMPONENTS) for thd in cols]
        cols = ["Structure_ID"] + cols + ["Collocation_ID"]

        if not self.without_rep:
            cols.append("Joint_representative_form")
        if self.group:
            cols.append("Frequency")

        return cols

    def length(self):
        """Return the number of columns written per component."""
        return 1 + 3 * int(not self.lemma_only) + int(not self.without_rep)

    def from_word(self, word):
        """Return the per-component columns for ``word``.

        A missing component (``word is None``) yields ``length()`` empty
        strings, so the following components stay aligned with the header.
        """
        if word is None:
            # Was `"" * self.length()`, i.e. the empty string: it contributed
            # no columns and shifted later components into the wrong columns.
            return [""] * self.length()

        cols = [word.lemma]
        if not self.lemma_only:
            cols = [word.id, word.text] + cols + [word.msd]
        if not self.without_rep:
            cols += [""]  # representative form: not yet implemented
        return cols

    def sorted_rows(self, rows):
        """Return ``rows`` sorted by column ``self.sort_by``.

        Sorting is numeric when every value in that column parses as int,
        otherwise case-insensitive string order. Rows are returned unchanged
        when sorting is disabled, there are fewer than two rows, or the
        column does not exist.
        """
        if self.sort_by < 0 or len(rows) < 2:
            return rows

        column = self.sort_by
        if len(rows[0]) <= column:
            logging.warning("Cannot sort by column #{}: Not enough columns!".format(column))
            return rows

        def is_int(value):
            try:
                int(value)
            except ValueError:
                return False
            return True

        # Check every row, not just the first: a single non-numeric value
        # further down used to raise ValueError inside sorted().
        if all(is_int(row[column]) for row in rows):
            key = lambda row: int(row[column])
        else:
            key = lambda row: row[column].lower()

        return sorted(rows, key=key, reverse=self.sort_order)

    def write_header(self, file_handler):
        """Write the header line to ``file_handler``."""
        file_handler.write(", ".join(self.header()) + "\n")

    def write_out_worker(self, file_handler, matches, structure_id, components, colocation_ids):
        """Write the rows for one structure's matches to ``file_handler``.

        ``matches`` holds ``(match, reason, collocation_key)`` tuples. ``match``
        is looked up by 1-based component index as a string ("1", "2", ...).
        With grouping enabled, each collocation is written at most once, and
        only if ``colocation_ids.should_write`` allows it.
        """
        rows = []
        for m, _reason, cid in matches:
            to_write = []
            for position, _component in enumerate(components, start=1):
                idx = str(position)
                word = m[idx] if idx in m else None
                to_write.extend(self.from_word(word))

            # Pad so every row has MAX_NUM_COMPONENTS component groups.
            to_write.extend([""] * (MAX_NUM_COMPONENTS * self.length() - len(to_write)))
            to_write = [structure_id] + to_write + [colocation_ids.to_id(cid)]

            if not self.without_rep:
                to_write.append("")  # joint representative form: not yet implemented

            if self.group:
                if not colocation_ids.should_write(cid):
                    continue
                to_write.append(colocation_ids.num(cid))
                colocation_ids.set_written(cid)

            rows.append(to_write)

        if rows:
            rows = self.sorted_rows(rows)
            file_handler.write("\n".join(", ".join(row) for row in rows) + "\n")
            file_handler.flush()

    def write_out(self, matches, structures, colocation_ids):
        """Write all matches to stdout, ``self.output_file``, or one file per structure.

        With ``multiple_output`` each structure goes to "<output>.<structure id>"
        with its own header. Files are closed even if writing fails; stdout
        is never closed.
        """
        def fp_close(fp_):
            if fp_ != sys.stdout:
                fp_.close()

        def fp_open(snum=None):
            if self.output_file is None:
                return sys.stdout
            elif snum is None:
                return open(self.output_file, "w")
            else:
                return open("{}.{}".format(self.output_file, snum), "w")

        if self.multiple_output:
            for s in structures:
                fp = fp_open(s.id)
                try:
                    self.write_header(fp)
                    self.write_out_worker(fp, matches[s.id], s.id, s.components, colocation_ids)
                finally:
                    fp_close(fp)
        else:
            fp = fp_open()
            try:
                self.write_header(fp)
                for s in structures:
                    self.write_out_worker(fp, matches[s.id], s.id, s.components, colocation_ids)
            finally:
                fp_close(fp)
 | 
						|
 | 
						|
 | 
						|
class ColocationIds:
    """Assign sequential ids to collocation keys and count their occurrences.

    ``self.data`` maps each key to ``[id_string, frequency, written_flag]``.
    Ids are "1", "2", ... in order of first appearance.
    """

    def __init__(self, min_frequency=None):
        """Create an empty registry.

        Args:
            min_frequency: minimum frequency a collocation needs before
                should_write() allows it. When omitted, falls back to the
                module-level ``args.group`` parsed in ``__main__`` (the
                previous, implicit behaviour), so existing callers keep working.
        """
        self.data = {}
        if min_frequency is None:
            # Only available when this file runs as a script; pass
            # min_frequency explicitly when using the class elsewhere.
            min_frequency = args.group
        self.min_frequency = min_frequency

    def add_match(self, key):
        """Count one occurrence of ``key``, assigning it a new id if unseen."""
        if key in self.data:
            self.data[key][1] += 1
        else:
            self.data[key] = [str(len(self.data) + 1), 1, False]

    def get(self, key, n):
        """Return field ``n`` (0 = id, 1 = frequency, 2 = written) for ``key``."""
        return self.data[key][n]

    def should_write(self, key):
        """True if ``key`` reached the minimum frequency and was not written yet."""
        return self.get(key, 1) >= self.min_frequency and not self.get(key, 2)

    def num(self, key):
        """Return the frequency of ``key`` as a string."""
        return str(self.get(key, 1))

    def to_id(self, key):
        """Return the id string assigned to ``key``."""
        return self.get(key, 0)

    def set_written(self, key):
        """Mark ``key`` as already written."""
        self.data[key][2] = True

    def merge_matches(self, matches, new_matches):
        """Append ``new_matches`` into ``matches`` (in place) and count them.

        Both map a structure id to a list of ``(match, reason, key)`` tuples.
        Every structure id in ``new_matches`` must already be present in
        ``matches``. Returns ``matches``.
        """
        for _id, nms in new_matches.items():
            for nm in nms:
                matches[_id].append(nm)
                self.add_match(nm[2])

        return matches
 | 
						|
 | 
						|
 | 
						|
def match_file(words, structures):
    """Match every structure against every word.

    Returns a dict mapping each structure id to a list of
    ``(match, reason, collocation_key)`` tuples. ``collocation_key`` is a
    tuple of the structure id followed by ``(component, lemma)`` pairs
    ordered by component key.
    """
    matches = {structure.id: [] for structure in structures}
    total = len(structures)

    for position, structure in enumerate(structures):
        logging.info("{}/{}: {:7s}".format(position, total, structure.id))
        bucket = matches[structure.id]

        for word in words:
            found = structure.match(word)
            logging.debug("  GOT: {}".format(len(found)))

            for match, reason in found:
                pairs = sorted(
                    ((component, member.lemma) for component, member in match.items()),
                    key=lambda pair: pair[0],
                )
                key = tuple([structure.id] + pairs)
                bucket.append((match, reason, key))

    return matches
 | 
						|
 | 
						|
 | 
						|
def _child_command(argv, inputs):
    """Return a copy of ``argv`` without the input files and the --parallel option."""
    # Work on a copy: the old code aliased and mutated sys.argv itself.
    cmd = list(argv)
    for inpt in inputs:
        if inpt in cmd:
            cmd.remove(inpt)

    # Remove "--parallel X" or "--parallel=X". Fail loudly if it is missing:
    # passing it on to the children would make them spawn children too.
    for pidx, arg in enumerate(cmd):
        if arg == '--parallel':
            del cmd[pidx:pidx + 2]
            return cmd
        if arg.startswith('--parallel='):
            del cmd[pidx]
            return cmd
    raise ValueError("Could not find the --parallel option in the command line: {}".format(argv))


def main(input_file, structures_file, args):
    """Match the structures against all inputs and write the results.

    Args:
        input_file: not used by this function; inputs are read from
            ``args.input``.
        structures_file: structure definitions, passed to build_structures.
        args: parsed command-line arguments.

    With ``args.parallel``, each input file is processed by a child copy of
    this script (run with ``--match-to-file``), and the pickled results are
    merged here. With ``args.match_to_file`` (child mode), the matches of the
    first input file are pickled to that path and the function returns
    without writing output.
    """
    writer = Writer(args)
    structures = build_structures(structures_file)
    for s in structures:
        logging.debug(str(s))

    colocation_ids = ColocationIds()
    matches = {s.id: [] for s in structures}

    if args.parallel:
        num_parallel = int(args.parallel)
        cmd = _child_command(sys.argv, args.input)

        # temporary directory holding each child's pickled matches
        with tempfile.TemporaryDirectory() as tmpdirname:
            def func(n):
                cmdn = [sys.executable] + cmd + [args.input[n], "--match-to-file", "{}/{}.p".format(tmpdirname, n)]
                subprocess.check_call(cmdn)
                return n

            # Threads only wait on the child processes; map() yields results
            # in input order, so merging is deterministic.
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_parallel) as executor:
                for id_input in executor.map(func, range(len(args.input))):
                    # Trusted data: written by our own child process above.
                    with open("{}/{}.p".format(tmpdirname, id_input), "rb") as fp:
                        new_matches = pickle.load(fp)
                    matches = colocation_ids.merge_matches(matches, new_matches)
    else:
        for words in load_files(args):
            new_matches = match_file(words, structures)

            if args.match_to_file is not None:
                # Child of a parallel run: save the matches and stop.
                with open(args.match_to_file, "wb") as fp:
                    pickle.dump(new_matches, fp)
                return

            matches = colocation_ids.merge_matches(matches, new_matches)

    writer.write_out(matches, structures, colocation_ids)

    logging.debug([(k, len(v)) for k, v in matches.items()])
    logging.debug(sum(len(v) for _, v in matches.items()))
 | 
						|
 | 
						|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract structures from a parsed corpus.')
    parser.add_argument('structures', help='Structures definitions in xml file')
    parser.add_argument('input', help='input xml file in `ssj500k form`, can list more than one', nargs='+')
    parser.add_argument('--output', help='Output file (if none given, then output to stdout)')

    parser.add_argument('--no-msd-translate', help='MSDs are translated from Slovene to English by default', action='store_true')
    parser.add_argument('--skip-id-check', help='Skips checks for ids of <w> and <pc>, if they are in correct format', action='store_true')
    parser.add_argument('--lemma-only', help='Will not write word ids, forms and msds in output', action='store_true')
    parser.add_argument('--without-rep', help='Will not write representations in output', action='store_true')
    parser.add_argument('--group', help='Group collocations with same collocation ID; the optional value is the minimum frequency (1 if omitted)', type=int, default=0, const=1, nargs='?')
    # NOTE(review): const equals the default, so a bare `--verbose` does not change the level.
    parser.add_argument('--verbose', help='Enable verbose output to stderr', choices=["warning", "info", "debug"], default="info", const="info", nargs='?')
    parser.add_argument('--count-files', help="Count files: more verbose output", action='store_true')
    parser.add_argument('--multiple-output', help='Generate one output for each syntactic structure', action='store_true')

    parser.add_argument('--sort-by', help="Sort by this column (index)", type=int, default=-1)
    parser.add_argument('--sort-reversed', help="Sort in reversed order", action='store_true')

    parser.add_argument('--pc-tag', help='Tag for separators, usually pc or c', default="pc")
    parser.add_argument('--parallel', help='Run in multiple processes, should speed things up', type=int)
    parser.add_argument('--match-to-file', help='Internal, used by --parallel child processes. Do not use!')

    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose.upper())

    # perf_counter is monotonic, so clock adjustments cannot skew the timing.
    start = time.perf_counter()
    main(args.input, args.structures, args)
    logging.info("TIME: {}".format(time.perf_counter() - start))
 |