luscenje_struktur/wani.py
Ozbolt Menegatti c0939fbbd4 fixed performance bug for representations
No more creating millions of namedtuple classes. Works about 15x faster
2019-06-11 10:26:10 +02:00

1551 lines
48 KiB
Python

from xml.etree import ElementTree
import re
from enum import Enum
from collections import defaultdict, namedtuple, Counter
import sys
import logging
import argparse
import pickle
import time
import subprocess
import concurrent.futures
import tempfile
from math import log2
from msd_translate import MSD_TRANSLATE
try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x: x
# Largest number of components found in any loaded structure. Starts at -1
# and is raised by build_structures(); Writer uses it to pad every output
# row to the same number of per-component columns.
MAX_NUM_COMPONENTS = -1

# Translates the attribute values written in the structure definitions
# (e.g. "Noun", "genitive") into the single-character codes used inside MSD
# tags. Several values share a letter (e.g. 'proper', 'plural' and
# 'participle' are all "p"); which one is meant depends on the position of
# the letter within the tag, see TAGSET.
CODES = {
    "Noun": "N",
    "Verb": "V",
    "Adjective": "A",
    "Adverb": "R",
    "Pronoun": "P",
    "Numeral": "M",
    "Preposition": "S",
    "Conjunction": "C",
    "Particle": "Q",
    "Interjection": "I",
    "Abbreviation": "Y",
    "Residual": "X",
    'common': 'c',
    'proper': 'p',
    'masculine': 'm',
    'feminine': 'f',
    'neuter': 'n',
    "singular": "s",
    "dual": "d",
    "plural": "p",
    "nominative": "n",
    "genitive": "g",
    "dative": "d",
    "accusative": "a",
    "locative": "l",
    "instrumental": "i",
    "no": "n",
    "yes": "y",
    "main": "m",
    "auxiliary": "a",
    "perfective": "e",
    "progressive": "p",
    "biaspectual": "b",
    "infinitive": "n",
    "supine": "u",
    "participle": "p",
    "present": "r",
    "future": "f",
    "conditional": "c",
    "imperative": "m",
    "first": "1",
    "second": "2",
    "third": "3",
    "general": "g",
    "possessive": "s",
    "positive": "p",
    "comparative": "c",
    "superlative": "s",
    "personal": "p",
    "demonstrative": "d",
    "relative": "r",
    "reflexive": "x",
    "interrogative": "q",
    "indefinite": "i",
    "negative": "z",
    "bound": "b",
    "digit": "d",
    "roman": "r",
    "letter": "l",
    "cardinal": "c",
    "ordinal": "o",
    "pronominal": "p",
    "special": "s",
    "coordinating": "c",
    "subordinating": "s",
    "foreign": "f",
    "typo": "t",
    "program": "p",
}
# For every part-of-speech code, the attribute names of an MSD tag in
# positional order: character 0 of a tag is the part-of-speech code itself
# and character i + 1 holds the value of attribute i of this list.
TAGSET = {
    "N": ['type', 'gender', 'number', 'case', 'animate'],
    "V": ['type', 'aspect', 'vform', 'person', 'number', 'gender', 'negative'],
    "A": ['type', 'degree', 'gender', 'number', 'case', 'definiteness'],
    "R": ['type', 'degree'],
    "P": ['type', 'person', 'gender', 'number', 'case', 'owner_number', 'owned_gender', 'clitic'],
    "M": ['form', 'type', 'gender', 'number', 'case', 'definiteness'],
    "S": ['case'],
    "C": ['type'],
    "Q": [],
    "I": [],
    "Y": [],
    "X": ['type']
}

# Per-category regex template holding one match-anything slot ('.') for every
# attribute of the category. It is derived from TAGSET so the two tables
# cannot disagree: the former hand-written table had only 6 slots for "P"
# although TAGSET lists 8 pronoun attributes, so a restriction on
# 'owned_gender' or 'clitic' indexed past the end of the template.
CATEGORY_BASES = {
    category: ['.'] * len(attributes)
    for category, attributes in TAGSET.items()
}
class ComponentType(Enum):
    """Role a component plays inside a syntactic structure."""
    # Any component whose definition does not say type "core".
    Other = 0
    # Component declared with type "core" in the structure definition.
    Core = 2
    # Core component picked by SyntacticStructure.determine_core2w() as one
    # of the (at most two) components with the lowest ppb value.
    Core2w = 3
class RestrictionType(Enum):
    """Kind of test a Restriction applies to a word."""
    # The word's MSD tag is tested with a MorphologyRegex.
    Morphology = 0
    # The word's lemma is tested with a LexisRegex.
    Lexis = 1
    # No restriction was given; every word matches.
    MatchAll = 2
class Order(Enum):
    """Required relative position of the two words joined by a dependency."""

    FromTo = 0
    ToFrom = 1
    Any = 2

    @staticmethod
    def new(order):
        """Build an Order from the ``order`` attribute of a dependency.

        ``None`` means no constraint (``Order.Any``); the strings
        ``"to-from"`` and ``"from-to"`` select the two directed orders.

        Raises:
            NotImplementedError: for any other value.
        """
        if order is None:
            return Order.Any
        for text, member in (("to-from", Order.ToFrom), ("from-to", Order.FromTo)):
            if order == text:
                return member
        raise NotImplementedError("What kind of ordering is: {}".format(order))

    def match(self, from_w, to_w):
        """Tell whether the two words appear in the order this member demands.

        Positions are compared through the words' ``int_id`` attributes;
        ``Order.Any`` accepts every pair without looking at them.
        """
        if self is Order.Any:
            return True

        source_pos, target_pos = from_w.int_id, to_w.int_id
        if self is Order.FromTo:
            return source_pos < target_pos
        if self is Order.ToFrom:
            return target_pos < source_pos
        raise NotImplementedError("Should not be here: Order match")
class ComponentRepresentation:
    """Base class for one textual representation of a structure component.

    Words matched for the component are collected with add_word(); render()
    then stores the text produced by the subclass hook _render() in
    ``rendition_text``.
    """

    def __init__(self, data, word_renderer):
        # Representations of other components that must agree with this one.
        self.agreement = []
        # Rendered text; stays None until render() produced something.
        self.rendition_text = None
        self.words = []
        self.word_renderer = word_renderer
        self.data = data

    def get_agreement(self):
        """Return the ids of components this one agrees with (none here)."""
        return []

    def add_word(self, word):
        """Remember a word matched for this component."""
        self.words.append(word)

    def render(self):
        """Compute ``rendition_text`` unless a text is already present."""
        if self.rendition_text is not None:
            return
        self.rendition_text = self._render()

    def _render(self):
        """Produce the representation text; subclasses must override."""
        message = "Not implemented for class: {}".format(type(self))
        raise NotImplementedError(message)
class LemmaCR(ComponentRepresentation):
    """Represents a component by the lemma of the first collected word."""

    def _render(self):
        # Without any collected word there is nothing to show.
        if not self.words:
            return None
        return self.words[0].lemma
class LexisCR(ComponentRepresentation):
    """Represents a component by the fixed string stored under 'lexis'."""

    def _render(self):
        fixed_text = self.data['lexis']
        return fixed_text
class WordFormAllCR(ComponentRepresentation):
    """Represents a component by all distinct lower-cased word forms seen."""

    def _render(self):
        """Return the distinct forms joined with "/", or None without words.

        The forms keep the order in which they were first collected. Joining
        a plain ``set`` (as was done before) made the order depend on string
        hashing, so the same input could produce differently ordered output
        from one run to the next.
        """
        if not self.words:
            return None

        # dict.fromkeys() removes duplicates while keeping first-seen order.
        distinct_forms = dict.fromkeys(w.text.lower() for w in self.words)
        return "/".join(distinct_forms)
class WordFormAnyCR(ComponentRepresentation):
    """Represents a component by its most frequent agreeing word form."""

    def _render(self):
        """Return the text of the most frequent (msd, lemma) that agrees.

        Candidates are the distinct (msd, lemma) pairs of the collected
        words, tried from the most to the least frequent. The first pair
        whose msd is accepted by every representation in ``self.agreement``
        wins: those representations are told to confirm their match and the
        most frequent text recorded for the pair is returned. None is
        returned when the winning pair has no lemma or when no pair is
        accepted.

        Counting is done with ``Counter`` (one pass) instead of calling
        ``list.count`` for every distinct pair, and equally frequent pairs
        are now tried in first-seen order rather than in set order.
        """
        # Most frequent text for every (msd, lemma); most_common() is sorted
        # by falling frequency, so the first text seen for a key is kept.
        text_forms = {}
        triplets = Counter((w.msd, w.lemma, w.text) for w in self.words)
        for (msd, lemma, text), _n in triplets.most_common():
            text_forms.setdefault((msd, lemma), text)

        pair_counts = Counter((w.msd, w.lemma) for w in self.words)
        for (word_msd, word_lemma), _n in pair_counts.most_common():
            # all() stops at the first representation that rejects the msd.
            if not all(agr.match(word_msd) for agr in self.agreement):
                continue

            for agr in self.agreement:
                agr.confirm_match()

            if word_lemma is None:
                return None
            return text_forms[(word_msd, word_lemma)]

        return None
class WordFormMsdCR(WordFormAnyCR):
    """Word-form representation limited to words with selected MSD values."""

    def __init__(self, *args):
        super().__init__(*args)
        # Lemma and msd of the first word offered to add_word().
        self.lemma = None
        self.msd = None

    def check_msd(self, word_msd):
        """Tell whether ``word_msd`` satisfies the selectors in data['msd'].

        Every selector names an attribute and the value it must have. A "-"
        on either side counts as unspecified and always passes. A tag that
        is too short to contain the attribute is treated the same way, which
        mirrors WordFormAgreementCR.check_agreement(); previously such a tag
        raised IndexError.

        Returns True when no selectors are configured.
        """
        if 'msd' not in self.data:
            return True

        for key, value in self.data['msd'].items():
            category = word_msd[0]
            position = TAGSET[category].index(key.lower()) + 1
            actual = word_msd[position] if position < len(word_msd) else '-'
            expected = CODES[value]
            if '-' not in (actual, expected) and actual != expected:
                return False

        return True

    def add_word(self, word):
        """Collect ``word`` if its msd passes check_msd().

        The lemma and msd of the very first word offered are remembered even
        when that word itself is rejected.
        """
        if self.lemma is None:
            self.lemma = word.lemma
            self.msd = word.msd

        if self.check_msd(word.msd):
            super().add_word(word)

    def _render(self):
        """Render like WordFormAnyCR after adding a lemma-level placeholder.

        The placeholder is a WordMsdOnly built from the msd the renderer
        reports for the remembered lemma; it has neither lemma nor text.
        """
        msd = self.word_renderer.get_lemma_msd(self.lemma, self.msd)
        self.words.append(WordMsdOnly(msd))
        return super()._render()
class WordFormAgreementCR(WordFormMsdCR):
    """Word-form representation that must agree with another component."""

    def __init__(self, data, word_renderer):
        super().__init__(data, word_renderer)
        # Text found by the latest successful match(), not yet confirmed.
        self.rendition_candidate = None

    def get_agreement(self):
        """Return the value stored under data['other']."""
        return self.data['other']

    def match(self, word_msd):
        """Look for a form of this lemma that agrees with ``word_msd``.

        Candidates come from the renderer's available_words(); only those
        with the same first msd character as this component are considered.
        The first candidate that passes both the agreement check and
        check_msd() is stored as ``rendition_candidate``.

        Returns True when a candidate was found, False otherwise.
        """
        already_seen = [(w.msd, w.text) for w in self.words]
        candidates = self.word_renderer.available_words(self.lemma, already_seen)

        for candidate_msd, candidate_text in candidates:
            if self.msd[0] != candidate_msd[0]:
                continue

            agrees = WordFormAgreementCR.check_agreement(
                word_msd, candidate_msd, self.data['agreement'])
            if agrees and self.check_msd(candidate_msd):
                self.rendition_candidate = candidate_text
                return True

        return False

    def confirm_match(self):
        """Promote the latest candidate to the final rendition text."""
        self.rendition_text = self.rendition_candidate

    @staticmethod
    def _agreement_value(msd, agr_case):
        """Read attribute ``agr_case`` from ``msd``.

        Returns a pair ``(known, value)``. ``known`` is False (after logging
        a warning) when the category of ``msd`` has no such attribute.
        ``value`` is None when the tag is too short to contain it.
        """
        attributes = TAGSET[msd[0]]
        if agr_case not in attributes:
            logging.warning("Cannot do agreement: {} for msd {} not found!"
                            .format(agr_case, msd))
            return False, None

        # The first character of a tag is the category, hence the + 1.
        position = attributes.index(agr_case) + 1
        if position >= len(msd):
            return True, None
        return True, msd[position]

    @staticmethod
    def check_agreement(msd1, msd2, agreements):
        """Tell whether two tags agree in every attribute of ``agreements``.

        An attribute unknown to the category of either tag gives False. An
        attribute that a tag is too short to contain, or whose value is "-"
        in either tag, always agrees.
        """
        lookup = WordFormAgreementCR._agreement_value

        for agr_case in agreements:
            known, m1 = lookup(msd1, agr_case)
            if not known:
                return False
            # Not specified in the first tag (e.g. an infinitive): agrees.
            if m1 is None:
                continue

            known, m2 = lookup(msd2, agr_case)
            if not known:
                return False
            if m2 is None:
                continue

            if '-' not in [m1, m2] and m1 != m2:
                return False

        return True

    def render(self):
        """Do nothing; the text is set through confirm_match() instead."""
        pass
class ComponentRendition:
    """Describes how one representation of a component is to be produced.

    Features are fed in with add_feature(); they select the representation
    class and fill ``more``, the data handed to every instance created by
    cr_instance().
    """

    def __init__(self):
        self.more = {}
        self.representation_factory = ComponentRepresentation

    def add_feature(self, feature):
        """Apply one feature (a mapping of attributes) to this rendition.

        A feature with 'rendition' picks the basic representation, one with
        'selection' refines a word-form representation. Features with
        neither key are ignored.

        Raises:
            NotImplementedError: for an unknown rendition or selection.
        """
        if 'rendition' in feature:
            rendition = feature['rendition']
            if rendition == "lemma":
                self.representation_factory = LemmaCR
            elif rendition == "word_form":
                # Only the default; a later selection feature may change it.
                self.representation_factory = WordFormAnyCR
            elif rendition == "lexis":
                self.representation_factory = LexisCR
                self.more['lexis'] = feature['string']
            else:
                raise NotImplementedError("Representation rendition: {}".format(feature))
            return

        if 'selection' not in feature:
            return

        selection = feature['selection']
        if selection == "msd":
            # An agreement selected earlier must not be downgraded.
            if self.representation_factory != WordFormAgreementCR:
                self.representation_factory = WordFormMsdCR
            self.more['msd'] = {k: v for k, v in feature.items() if k != 'selection'}
        elif selection == "all":
            self.representation_factory = WordFormAllCR
        elif selection == 'agreement':
            assert feature['head'][:4] == 'cid_'
            assert feature['msd'] is not None
            self.representation_factory = WordFormAgreementCR
            self.more['agreement'] = feature['msd'].split('+')
            # Id of the head component, without the 'cid_' prefix.
            self.more['other'] = feature['head'][4:]
        else:
            raise NotImplementedError("Representation selection: {}".format(feature))

    def cr_instance(self, word_renderer):
        """Create a fresh representation object for this rendition."""
        return self.representation_factory(self.more, word_renderer)

    @staticmethod
    def set_representations(match, word_renderer):
        """Compute ``match.representations`` for every component.

        One representation object is created per rendition of every
        component of the structure, agreement links are wired up, all
        matched words are handed to the representations of their component
        and finally everything is rendered. A component's entry is the
        rendered texts joined by spaces (None counting as ""), None when all
        of its texts are None, and absent when it has no renditions.

        Raises:
            NotImplementedError: when a component agrees with a component
                that does not have exactly one representation.
        """
        structure = match.structure
        representations = {
            c.idx: [rep.cr_instance(word_renderer) for rep in c.representation]
            for c in structure.components
        }

        # Register every agreeing representation with its head component.
        for cid, reps in representations.items():
            for rep in reps:
                for agr in rep.get_agreement():
                    n = len(representations[agr])
                    if n != 1:
                        raise NotImplementedError(
                            "Structure {}: ".format(match.structure.id) +
                            "component {} has agreement".format(cid) +
                            " with component {}".format(agr) +
                            ", however there are {} (!= 1) representations".format(n) +
                            " of component {}!".format(agr))
                    representations[agr][0].agreement.append(rep)

        # First pass: collect the words, agreements are resolved on render.
        for words in match.matches:
            for w_id, w in words.items():
                component = structure.get_component(w_id)
                for representation in representations[component.idx]:
                    representation.add_word(w)

        # Everything must be rendered before any text is read, because
        # rendering one representation may set the text of another.
        for reps in representations.values():
            for rep in reps:
                rep.render()

        for cid, reps in representations.items():
            texts = [rep.rendition_text for rep in reps]
            if not texts:
                continue
            if all(text is None for text in texts):
                match.representations[cid] = None
            else:
                match.representations[cid] = " ".join(
                    ("" if text is None else text) for text in texts)
class ComponentStatus(Enum):
    """Whether a component has to be present for a structure to match."""
    # Selected by status "optional".
    Optional = 0
    # Selected by status "obligatory" and when no status is given.
    Required = 1
    # Selected by status "forbidden": a matching word rejects the match.
    Forbidden = 2
def get_level(restriction):
    """Return the "level" value of the first feature that carries one.

    ``restriction`` is an iterable of features, each offering ``keys()`` and
    ``get()``. The previous version stored the value in a local variable and
    never returned it, so it raised for every input.

    Raises:
        RuntimeError: when no feature has a "level" entry.
    """
    for feature in restriction:
        if "level" in feature.keys():
            return feature.get("level")

    raise RuntimeError("Unreachable!")
def determine_ppb(rgx):
    """Rank a morphology regex list for the choice of core words (0-4).

    ``rgx`` is the per-position list built by MorphologyRegex; position 0 is
    the category code. Adjectives, nouns and adverbs rank 0. Verbs rank 3
    when the pattern at position 1 contains "a", else 1 when it contains
    "m", else 2. Every other category ranks 4.
    """
    category = rgx[0]
    if category in ("A", "N", "R"):
        return 0
    if category != "V":
        return 4

    verb_type = rgx[1]
    if 'a' in verb_type:
        return 3
    if 'm' in verb_type:
        return 1
    return 2
class MorphologyRegex:
    """Matches MSD tags against a morphological restriction.

    The restriction is turned into ``rgx``, a list with one small pattern
    per tag position: the category code at position 0 and, for each
    attribute, either "." (anything) or a character class of allowed
    ("[...]") or disallowed ("[^...]") codes.
    """

    def __init__(self, restriction):
        """Build the patterns from an iterable of features.

        Every feature must offer ``items()`` and hold exactly one attribute,
        optionally accompanied by ``filter="negative"`` which inverts it.
        A 'POS' feature is mandatory.
        """
        self.min_msd_length = 1

        # attribute name -> (value, True for positive / False for negative)
        wanted = {}
        for feature in restriction:
            attrs = dict(feature.items())
            positive = True
            if "filter" in attrs:
                assert attrs['filter'] == "negative"
                positive = False
                del attrs['filter']

            assert len(attrs) == 1
            key, value = next(iter(attrs.items()))
            wanted[key] = (value, positive)

        assert 'POS' in wanted
        pos_value, _ = wanted.pop('POS')
        cat_code = CODES[pos_value.capitalize()]
        rgx = [cat_code] + CATEGORY_BASES[cat_code]

        for attribute, (value, positive) in wanted.items():
            index = TAGSET[cat_code].index(attribute.lower())
            assert index >= 0

            # Alternatives are separated by "|"; a single value is simply
            # the one-element case of the same join.
            letters = "".join(CODES[val] for val in value.split('|'))
            rgx[index + 1] = "[{}{}]".format("" if positive else "^", letters)

            # A positively required attribute must exist in the tag.
            if positive:
                self.min_msd_length = max(index + 1, self.min_msd_length)

        self.re_objects = [re.compile(r) for r in rgx]
        self.rgx = rgx

    def __call__(self, text):
        """Return True when the tag ``text`` satisfies the restriction.

        Tags not longer than ``min_msd_length`` are rejected; otherwise
        every character is tested against the pattern of its position.
        """
        if len(text) <= self.min_msd_length:
            return False
        return all(r.match(c) for c, r in zip(text, self.re_objects))
class LexisRegex:
    """Matches a lemma against a "|"-separated list of accepted lemmas."""

    def __init__(self, restriction):
        """Collect the accepted lemmas from an iterable of features.

        The attributes of all features are merged (later ones win); the
        merged result must contain 'lemma'.
        """
        merged = {
            key: value
            for feature in restriction
            for key, value in feature.items()
        }
        assert "lemma" in merged
        self.match_list = merged['lemma'].split('|')

    def __call__(self, text):
        """Return True when ``text`` is one of the accepted lemmas."""
        return text in self.match_list
class Restriction:
    """A single test a word has to pass to fill a component."""

    def __init__(self, restriction_tag):
        """Build the restriction from its XML element.

        ``None`` creates a restriction that matches every word. Otherwise
        the element's 'type' attribute selects a morphology or a lexis
        matcher, built from the element's children.

        Raises:
            NotImplementedError: for any other 'type'.
        """
        self.ppb = 4  # "polnopomenska beseda" rank (0-4)

        if restriction_tag is None:
            self.type = RestrictionType.MatchAll
            self.matcher = None
            self.present = None
            return

        kind = restriction_tag.get('type')
        if kind == "morphology":
            self.type = RestrictionType.Morphology
            self.matcher = MorphologyRegex(list(restriction_tag))
            self.ppb = determine_ppb(self.matcher.rgx)
        elif kind == "lexis":
            self.type = RestrictionType.Lexis
            self.matcher = LexisRegex(list(restriction_tag))
        else:
            raise NotImplementedError()

    def match(self, word):
        """Return whether ``word`` passes: msd for morphology, lemma for lexis."""
        if self.type == RestrictionType.MatchAll:
            return True
        if self.type == RestrictionType.Morphology:
            return self.matcher(word.msd)
        if self.type == RestrictionType.Lexis:
            return self.matcher(word.lemma)
        raise RuntimeError("Unreachable!")
class Component:
    """One node of a syntactic structure together with its outgoing links."""

    def __init__(self, info):
        """Create the component from its attribute mapping.

        ``info`` needs 'cid' and 'type'; 'name' and 'status' are optional.

        Raises:
            NotImplementedError: for an unknown 'status'.
        """
        idx = info['cid']
        name = info.get('name')
        typ = ComponentType.Core if info['type'] == "core" else ComponentType.Other

        if 'status' not in info:
            status = ComponentStatus.Required
        else:
            known_statuses = {
                'forbidden': ComponentStatus.Forbidden,
                'obligatory': ComponentStatus.Required,
                'optional': ComponentStatus.Optional,
            }
            status = known_statuses.get(info['status'])
            if status is None:
                raise NotImplementedError("strange status: {}".format(info['status']))

        self.idx = idx
        self.name = name
        self.type = typ
        self.status = status
        self.restrictions = []
        self.representation = []
        # Triples (component, link label, Order) of the outgoing links.
        self.next_element = []
        self.selection = {}
        self.iter_ctr = 0

    def add_next(self, next_component, link_label, order):
        """Register an outgoing link to ``next_component``."""
        self.next_element.append((next_component, link_label, Order.new(order)))

    def set_restriction(self, restrictions_tag):
        """Set the restrictions from a 'restriction(_or)' element or None.

        'restriction_or' yields one Restriction per child; a word then has
        to pass only one of them.
        """
        if restrictions_tag is None:
            self.restrictions = [Restriction(None)]
            return

        tag = restrictions_tag.tag
        if tag == "restriction":
            self.restrictions = [Restriction(restrictions_tag)]
        elif tag == "restriction_or":
            self.restrictions = [Restriction(el) for el in restrictions_tag]
        else:
            raise RuntimeError("Unreachable")

    def set_representation(self, representation):
        """Add one ComponentRendition per entry of ``representation``.

        Each entry is an iterable of elements whose ``attrib`` mappings are
        fed to the rendition as features.
        """
        for rep in representation:
            crend = ComponentRendition()
            for feature in rep:
                crend.add_feature(feature.attrib)
            self.representation.append(crend)

    def find_next(self, deps, comps, restrs, reprs):
        """Build, link and return all components reachable from this one.

        ``deps`` holds (from, to, label, order) tuples; the three mappings
        give attributes, restriction and representations per component id.
        The result lists every descendant, each parent before its children.
        """
        found = []
        for dep in deps:
            if dep[0] != self.idx:
                continue

            _, idx, dep_label, order = dep
            child = Component(comps[idx])
            child.set_restriction(restrs[idx])
            child.set_representation(reprs[idx])
            found.append(child)

            self.add_next(child, dep_label, order)
            found.extend(child.find_next(deps, comps, restrs, reprs))

        return found

    def name_str(self):
        """Return the component name, or "_" when it has none."""
        if self.name is None:
            return "_"
        return self.name

    def match(self, word):
        """Match the sub-structure rooted here against ``word``.

        Returns None when this component or one of its links fails,
        otherwise a list of mappings from component id to matched word.
        """
        own = self._match_self(word)
        if own is None:
            return None

        linked = self._match_next(word)
        if linked is None:
            return None

        results = [own]
        for candidates in linked:
            # Link is fine but contributes no words.
            if len(candidates) == 0:
                continue

            # Several matches for this link while there is still a single
            # result: the result fans out into one entry per match.
            if len(candidates) > 1 and len(results) == 1:
                results = [{**dict(results[0]), **m} for m in candidates]
                continue

            # Several matches on more than one link are not combined; only
            # the first match of this link is used.
            if len(candidates) > 1:
                logging.warning("Strange multiple match: {}".format(
                    str([w.id for w in candidates[0].values()])))

            for partial in results:
                partial.update(candidates[0])

        return results

    def _match_self(self, word):
        """Return {idx: word} when any restriction accepts ``word``."""
        if any(restr.match(word) for restr in self.restrictions):
            return {self.idx: word}
        return None

    def _match_next(self, word):
        """Match every outgoing link of this component starting at ``word``.

        Returns one list of matches per link (possibly empty), or None as
        soon as one link fails: a required component found no match, or a
        forbidden component found one.
        """
        per_link = []

        for next_component, link, order in self.next_element:
            collected = []
            per_link.append(collected)

            # Optional and forbidden components are fine with no match.
            good = next_component.status != ComponentStatus.Required

            for next_word in word.get_links(link):
                if not order.match(word, next_word):
                    continue

                match = next_component.match(next_word)
                if match is None:
                    continue

                # A forbidden component matched: the whole link fails.
                if next_component.status == ComponentStatus.Forbidden:
                    good = False
                    break

                assert type(match) is list
                collected.extend(match)
                good = True

            if not good:
                return None

        return per_link
class SyntacticStructure:
    """A syntactic structure: a tree of components read from XML."""

    def __init__(self):
        self.id = None
        self.lbs = None
        self.components = []

    @staticmethod
    def from_xml(xml):
        """Build a structure from a 'syntactic_structure' element.

        The element must have exactly one child of type 'JOS' with three
        children: components, dependencies and definitions.

        Raises:
            NotImplementedError: for a definition child that is neither a
                restriction nor a representation.
        """
        st = SyntacticStructure()
        st.id = xml.get('id')
        st.lbs = xml.get('LBS')

        children = list(xml)
        assert len(children) == 1
        system = children[0]
        assert system.get('type') == 'JOS'

        components, dependencies, definitions = list(system)

        deps = []
        for dep in dependencies:
            deps.append((dep.get('from'), dep.get('to'), dep.get('label'), dep.get('order')))
        comps = {comp.get('cid'): dict(comp.items()) for comp in components}

        restrs = {}
        forms = {}
        for comp in definitions:
            n = comp.get('cid')
            restrs[n] = None
            forms[n] = []

            for el in comp:
                if el.tag.startswith("restriction"):
                    # Only one restriction element per component.
                    assert restrs[n] is None
                    restrs[n] = el
                elif el.tag.startswith("representation"):
                    st.add_representation(n, el, forms)
                else:
                    raise NotImplementedError("Unknown definition: {} in structure {}"
                                              .format(el.tag, st.id))

        # The dependencies start at the pseudo component '#'.
        fake_root_component = Component({'cid': '#', 'type': 'other'})
        st.components = fake_root_component.find_next(deps, comps, restrs, forms)

        st.determine_core2w()
        return st

    def determine_core2w(self):
        """Mark the two core components with the lowest ppb as Core2w.

        The ppb of a component is the smallest ppb of its restrictions
        (at most 4).

        Raises:
            RuntimeError: when the second and third lowest values are equal.
        """
        ranked = [
            (c, min([4] + [r.ppb for r in c.restrictions]))
            for c in self.components
            if c.type == ComponentType.Core
        ]
        ranked.sort(key=lambda pair: pair[1])

        if len(ranked) > 2 and ranked[1][1] == ranked[2][1]:
            raise RuntimeError("Cannot determine 2 'jedrna polnopomenska beseda' for", self.id)

        for c, _ in ranked[:2]:
            c.type = ComponentType.Core2w

    def add_representation(self, n, rep_el, forms):
        """Append the usable features of ``rep_el`` to ``forms[n]``.

        Features with neither 'rendition' nor 'selection' are skipped with
        a warning. The features are appended as one list, even if empty.
        """
        assert rep_el.tag == "representation"

        to_add = []
        for el in rep_el:
            assert el.tag == "feature"
            if 'rendition' not in el.attrib and 'selection' not in el.attrib:
                logging.warning("Strange representation feature in structure {}. Skipping"
                                .format(self.id))
                continue
            to_add.append(el)

        forms[n].append(to_add)

    def get_component(self, idx):
        """Return the component with id ``idx``.

        Raises:
            RuntimeError: when there is no such component.
        """
        found = next((c for c in self.components if c.idx == idx), None)
        if found is None:
            raise RuntimeError("Unknown component id: {}".format(idx))
        return found

    def match(self, word):
        """Match the structure with ``word`` as its first component.

        Returns the list of matches, empty when there are none.
        """
        matches = self.components[0].match(word)
        if matches is None:
            return []
        return matches
def load_structures(filename):
    """Read the structure definitions from the XML file ``filename``.

    The file is parsed directly by ElementTree, which reads it as bytes and
    decodes it according to the XML declaration (UTF-8 when there is none).
    Reading it in text mode first decoded it with the platform's default
    encoding, which corrupts non-ASCII content on non-UTF-8 systems.

    Returns:
        A pair: the list built by build_structures() and the mapping built
        by get_lemma_features().
    """
    et = ElementTree.parse(filename).getroot()
    return build_structures(et), get_lemma_features(et)
def build_structures(et):
    """Build a SyntacticStructure for every 'syntactic_structure' in ``et``.

    As a side effect the module-level MAX_NUM_COMPONENTS is raised to the
    largest component count seen.
    """
    global MAX_NUM_COMPONENTS

    structures = []
    for node in et.iter('syntactic_structure'):
        structure = SyntacticStructure.from_xml(node)
        if structure is not None:
            structures.append(structure)
            MAX_NUM_COMPONENTS = max(MAX_NUM_COMPONENTS, len(structure.components))

    return structures
def get_lemma_features(et):
    """Read the 'lemma_features' section of the structures file.

    Every 'POS' element below it is turned into a MorphologyRegex whose
    position patterns are flattened into a tag-like string: "." becomes
    "-", a single character or a one-letter class "[x]" gives that letter,
    and trailing unspecified positions are dropped.

    Returns:
        A mapping from category code to that string; empty when the file
        has no 'lemma_features' element.

    Raises:
        RuntimeError: for a pattern that cannot be flattened.
    """
    lf = et.find('lemma_features')
    if lf is None:
        return {}

    result = {}
    for pos in lf.iter('POS'):
        letters = []
        for position in MorphologyRegex(pos).rgx:
            if position == ".":
                letters.append(" ")
            elif len(position) == 1:
                letters.append(position)
            elif len(position) == 3 and position[0] == "[" and position[2] == "]":
                letters.append(position[1])
            else:
                raise RuntimeError("Strange rgx for lemma_feature...")

        rgx_str = "".join(letters)
        assert rgx_str[0].isupper()
        # Blanks stand for unspecified positions: cut the trailing ones and
        # show the inner ones as "-".
        result[rgx_str[0]] = rgx_str.strip().replace(' ', '-')

    return result
def get_msd(comp):
    """Return the MSD tag of a token element.

    The 'msd' attribute is used when present; otherwise the 'ana'
    attribute with its first four characters removed.

    The attributes are logged before raising. The earlier call passed a
    ``file=`` keyword that the logging functions do not accept, so a
    TypeError was raised instead of the intended exception.

    Raises:
        NotImplementedError: when the element has neither attribute.
    """
    d = dict(comp.items())
    if 'msd' in d:
        return d['msd']
    if 'ana' in d:
        return d['ana'][4:]

    logging.error(d)
    raise NotImplementedError("MSD?")
class WordMsdOnly:
    """Stand-in for a word that carries an msd but no lemma and no text."""

    def __init__(self, msd):
        self.lemma = self.text = None
        self.msd = msd

    def most_frequent_text(self, _):
        """Return None; a placeholder has no text of its own."""
        return None
class Word:
    """A corpus token with its lemma, msd, text and outgoing links."""

    def __init__(self, xml, do_msd_translate):
        """Create the word from its XML element.

        With ``do_msd_translate`` the msd found on the element is mapped
        through MSD_TRANSLATE. ``int_id`` is the number in the last
        dot-separated part of the id; one leading non-digit character of
        that part is dropped.
        """
        self.lemma = xml.get('lemma')

        raw_msd = get_msd(xml)
        self.msd = MSD_TRANSLATE[raw_msd] if do_msd_translate else raw_msd

        self.id = xml.get('id')
        self.text = xml.text
        self.links = defaultdict(list)

        tail = self.id.split('.')[-1]
        if tail[0] not in '0123456789':
            tail = tail[1:]
        self.int_id = int(tail)

        assert None not in (self.id, self.lemma, self.msd)

    @staticmethod
    def pc_word(pc, do_msd_translate):
        """Create a Word from a punctuation element.

        The element is modified: its text becomes the lemma and a
        placeholder msd is set before the Word is built.
        """
        pc.set('lemma', pc.text)
        placeholder_msd = "N" if do_msd_translate else "U"
        pc.set('msd', placeholder_msd)
        return Word(pc, do_msd_translate)

    def add_link(self, link, to):
        """Record a link labelled ``link`` from this word to ``to``."""
        self.links[link].append(to)

    def get_links(self, link):
        """Return the words linked under ``link``.

        A label containing "|" names alternatives; on first use the targets
        of all alternatives are gathered and stored under the combined
        label.
        """
        if link not in self.links and "|" in link:
            combined = self.links[link]
            for label in link.split('|'):
                combined.extend(self.links[label])

        return self.links[link]

    def most_frequent_text(self, word_renderer):
        """Return what the renderer gives for this word's lemma and msd."""
        return word_renderer.render(self.lemma, self.msd)
class WordMsdRenderer:
    """Collects all corpus words and answers questions about their forms."""

    def __init__(self, lemma_features):
        self.all_words = []
        # lemma -> msd -> (most frequent text, number of occurrences)
        self.rendered_words = {}
        # lemma -> [(msd, text, count), ...] sorted by falling count
        self.frequent_words = {}
        # (lemma, first msd character) -> number of occurrences
        self.num_words = {}
        # lemma -> msd shared by all forms of the lemma ("-" where they differ)
        self.lemma_msd = {}
        self.lemma_features = lemma_features
        self.memoized_msd_merges = {}

    def add_words(self, words):
        """Add ``words`` to the collection."""
        self.all_words.extend(words)

    def num_all_words(self):
        """Return the number of collected words."""
        return len(self.all_words)

    def generate_renders(self):
        """Fill all lookup tables from the collected words.

        The texts of every (lemma, msd) are counted once with a Counter.
        This replaces ``max(set(texts), key=texts.count)``, which rescanned
        the list for every distinct text and picked among equally frequent
        texts in set order; the first seen text now wins such ties.
        """
        num_words = defaultdict(int)
        data = defaultdict(lambda: defaultdict(list))
        for w in self.all_words:
            data[w.lemma][w.msd].append(w.text)

        for lemma, by_msd in data.items():
            self.rendered_words[lemma] = {}
            freq_words = {}
            common_msd = "*" * 10

            for msd, texts in by_msd.items():
                # TODO: this should be out of generate_renders...
                num_words[(lemma, msd[0])] += len(texts)

                text_counts = Counter(texts)
                rep = text_counts.most_common(1)[0][0]
                self.rendered_words[lemma][msd] = (rep, len(texts))

                # Every msd is visited once, so each key is written once.
                for txt, n in text_counts.items():
                    freq_words[(msd, txt)] = n

                common_msd = self.merge_msd(common_msd, msd)

            self.lemma_msd[lemma] = common_msd
            self.frequent_words[lemma] = [
                (msd, txt, n)
                for (msd, txt), n in sorted(freq_words.items(), key=lambda x: -x[1])
            ]

        # Overlay the configured lemma features of the category; where they
        # say "-" the value found in the corpus is kept.
        lf = self.lemma_features
        for lemma in self.lemma_msd:
            cmsd = self.lemma_msd[lemma]
            if cmsd[0] in lf:
                self.lemma_msd[lemma] = "".join(
                    l1 if l1 != "-" else l2 for l1, l2 in zip(lf[cmsd[0]], cmsd)
                )

        self.num_words = dict(num_words)

    @staticmethod
    def _merge_letter(l1, l2):
        """Merge one position: "*" adopts l2, a difference becomes "-"."""
        if l1 == "*":
            return l2
        if l1 != l2:
            return "-"
        return l1

    def merge_msd(self, common_msd, new_msd):
        """Merge ``new_msd`` into ``common_msd`` position by position.

        The result is as long as the shorter of the two; results are
        memoized per argument pair.
        """
        key = (common_msd, new_msd)
        if key in self.memoized_msd_merges:
            return self.memoized_msd_merges[key]

        value = "".join(
            self._merge_letter(l1, l2) for l1, l2 in zip(common_msd, new_msd))
        self.memoized_msd_merges[key] = value
        return value

    def render(self, lemma, msd):
        """Return the most frequent text of (lemma, msd), None if unseen."""
        if lemma in self.rendered_words:
            if msd in self.rendered_words[lemma]:
                return self.rendered_words[lemma][msd][0]
        return None

    def available_words(self, lemma, existing_texts):
        """Yield (msd, text) candidates for ``lemma``.

        First the pairs in ``existing_texts`` by falling frequency, then all
        other known forms of the lemma by falling corpus frequency.
        """
        counted_texts = Counter(existing_texts)
        for (msd, text), _n in counted_texts.most_common():
            yield (msd, text)

        if lemma in self.frequent_words:
            for msd, text, _ in self.frequent_words[lemma]:
                if (msd, text) not in counted_texts:
                    yield (msd, text)

    def get_lemma_msd(self, lemma, word_msd):
        """Return the msd describing ``lemma`` as a whole.

        When the forms of the lemma do not even share the category, the
        configured lemma features of the category of ``word_msd`` are used,
        or "-" when none are configured.
        """
        # should be here, since we collect every lemmas
        lemma_msd = self.lemma_msd[lemma]

        if lemma_msd[0] != '-':
            return lemma_msd
        if word_msd[0] in self.lemma_features:
            return self.lemma_features[word_msd[0]]
        return '-'
def is_root_id(id_):
    """Return True when ``id_`` consists of exactly three dot-separated parts."""
    return id_.count('.') == 2
def load_files(args):
    """Lazily load every input file named in ``args.input``.

    Yields the result of load_tei_file() for each file, in order. With
    ``args.count_files`` the log line of each file gets a
    " :: index / total" suffix (index counted from 0).
    """
    filenames = args.input
    skip_id_check = args.skip_id_check
    do_msd_translate = not args.no_msd_translate

    for n, fname in enumerate(filenames):
        status = ""
        if args.count_files:
            status = " :: {} / {}".format(n, len(filenames))
        yield load_tei_file(fname, skip_id_check, do_msd_translate, args.pc_tag, status)
def load_tei_file(filename, skip_id_check, do_msd_translate, pc_tag, status):
    """Load one TEI file and return its words with dependency links set.

    The first default namespace declaration and all ' xml:' attribute
    prefixes are removed from the text before parsing, so elements and
    attributes can be addressed by their plain names. Words come from 'w'
    elements and from ``pc_tag`` elements (punctuation).

    Two kinds of 'link' elements are understood: those with a 'dep'
    attribute (label in 'afun', ends in 'from' and 'dep') and those whose
    'ana' starts with 'syn:' (ends given as two '#'-prefixed ids in
    'target'). Links starting at an unknown id are ignored. The process
    exits with status 1 on a link ending at an unknown id or, unless
    ``skip_id_check`` is set, on a link starting at a root id.
    """
    logging.info("LOADING FILE: {}{}".format(filename, status))

    with open(filename, 'r') as fp:
        content = fp.read()
    xmlstring = re.sub(' xmlns="[^"]+"', '', content, count=1)
    xmlstring = xmlstring.replace(' xml:', ' ')
    et = ElementTree.XML(xmlstring)

    words = {w.get('id'): Word(w, do_msd_translate) for w in et.iter("w")}
    for pc in et.iter(pc_tag):
        words[pc.get('id')] = Word.pc_word(pc, do_msd_translate)

    for link in et.iter("link"):
        if 'dep' in link.keys():
            ana = link.get('afun')
            lfrom = link.get('from')
            dest = link.get('dep')
        else:
            ana = link.get('ana')
            # Only syntactic links are of interest.
            if ana[:4] != 'syn:':
                continue
            ana = ana[4:]
            lfrom, dest = link.get('target').replace('#', '').split()

        # Links from something that is not a word are silently skipped.
        if lfrom not in words:
            continue

        if not skip_id_check and is_root_id(lfrom):
            logging.error("NOO: {}".format(lfrom))
            sys.exit(1)

        if dest not in words:
            logging.error("Unknown id: {}".format(dest))
            sys.exit(1)

        words[lfrom].add_link(ana, words[dest])

    return list(words.values())
class Formatter:
    """Base class of the objects that turn matches into output columns.

    Each output row has a group of "repeat" columns for every component,
    followed by "right" columns written once per row. Subclasses supply
    the headers and contents of both.
    """

    def __init__(self, colocation_ids, word_renderer):
        self.colocation_ids = colocation_ids
        self.word_renderer = word_renderer
        self.additional_init()

    # --- hooks with a usable default -------------------------------------

    def additional_init(self):
        """Hook called at the end of __init__; does nothing by default."""
        pass

    def set_structure(self, structure):
        """Hook called before the matches of ``structure`` are written."""
        pass

    def new_match(self, match):
        """Hook called before the rows of ``match`` are written."""
        pass

    def length(self):
        """Return the number of repeat columns per component."""
        return len(self.header_repeat())

    # --- methods every subclass has to provide ---------------------------

    def header_repeat(self):
        """Return the names of the per-component columns."""
        raise NotImplementedError("Header repeat formatter not implemented")

    def header_right(self):
        """Return the names of the once-per-row columns."""
        raise NotImplementedError("Header right formatter not implemented")

    def content_repeat(self, words, representations, idx, sidx):
        """Return the per-component cells for component ``idx``."""
        raise NotImplementedError("Content repeat formatter not implemented")

    def content_right(self, freq):
        """Return the once-per-row cells."""
        raise NotImplementedError("Content right formatter not implemented")

    def group(self):
        """Return whether a match is written as a single row."""
        raise NotImplementedError("Group for formatter not implemented")
class OutNoStatFormatter(Formatter):
    """Writes lemma and representative form per component, no statistics."""

    def additional_init(self):
        # Joint representation of the row being built, filled piece by
        # piece in content_repeat() and emptied in content_right().
        self.representation = ""

    def header_repeat(self):
        return ["Lemma", "Representative_form", "RF_scenario"]

    def header_right(self):
        return ["Joint_representative_form", "Frequency"]

    def content_repeat(self, words, representations, idx, _sidx):
        """Return [lemma, representative form, scenario] for one component.

        A component without an entry in ``representations`` gives two empty
        cells. A None entry falls back to the lemma. The form used is also
        appended to the joint representation of the row.
        """
        lemma = words[idx].lemma
        if idx not in representations:
            return [lemma, "", ""]

        rep = representations[idx]
        if rep is None:
            shown, scenario = lemma, "lemma_fallback"
        else:
            shown, scenario = rep, "ok"

        self.representation += " " + shown
        return [lemma, shown, scenario]

    def content_right(self, freq):
        """Return [joint representation, frequency] and start a new row."""
        joint = re.sub(' +', ' ', self.representation.strip())
        self.representation = ""
        return [joint, str(freq)]

    def group(self):
        return True
class AllFormatter(Formatter):
    """Writes id, form, lemma and msd of every word of every match."""

    def header_repeat(self):
        return ["Token_ID", "Word_form", "Lemma", "Msd"]

    def header_right(self):
        return []

    def content_repeat(self, words, _representations, idx, _sidx):
        """Return the four cells describing the word of component ``idx``."""
        matched = words[idx]
        return [matched.id, matched.text, matched.lemma, matched.msd]

    def content_right(self, _freq):
        return []

    def group(self):
        # Every occurrence of a match gets its own row.
        return False
class StatsFormatter(Formatter):
    """Writes association statistics of a match."""

    def additional_init(self):
        # Statistics of the current match, computed in new_match().
        self.stats = None
        # Ids of the two Core2w components of the current structure.
        self.jppb = None
        # Ids of all components of the current structure that are not Other.
        self.corew = None

    @staticmethod
    def stat_str(num):
        """Format a number: floats with five decimals, the rest via str()."""
        if isinstance(num, float):
            return "{:.5f}".format(num)
        return str(num)

    def set_structure(self, structure):
        """Remember the core components of ``structure``.

        Exactly two components must be of type Core2w.
        """
        components = structure.components
        jppb = [c.idx for c in components if c.type == ComponentType.Core2w]
        corew = [c.idx for c in components if c.type != ComponentType.Other]

        assert(len(jppb) == 2)

        self.jppb = tuple(jppb)
        self.corew = tuple(corew)

    def new_match(self, match):
        """Compute the statistics of ``match``.

        The frequency of every core component is the corpus count of the
        (lemma, category) of its word in the first occurrence of the match,
        or 0 when the component is absent there.
        """
        first = match.matches[0]
        self.stats = {"freq": {}}

        for cid in self.corew:
            if cid in first:
                word = first[cid]
                freq = self.word_renderer.num_words[(word.lemma, word.msd[0])]
            else:
                freq = 0
            self.stats["freq"][cid] = freq

        fx = self.stats["freq"][self.jppb[0]]
        fy = self.stats["freq"][self.jppb[1]]
        freq = len(match)
        N = self.word_renderer.num_all_words()

        self.stats['d12'] = freq / fx - (fy - freq) / (N - fx)
        self.stats['d21'] = freq / fy - (fx - freq) / (N - fy)

        self.stats['df'] = match.distinct_forms()
        self.stats['freq_all'] = freq

    def header_repeat(self):
        return ["Distribution"]

    def header_right(self):
        return ["Delta_p12", "Delta_p21", "LogDice_core", "LogDice_all", "Distinct_forms"]

    def content_repeat(self, words, representations, idx, sidx):
        """Return the dispersion of the component's lemma, or empty cells.

        Only core components get a value; it is looked up in the
        dispersions of ``colocation_ids`` under (structure id, component
        id, lemma).
        """
        if idx not in self.corew:
            return [""] * self.length()

        key = (sidx, idx, words[idx].lemma)
        distribution = self.colocation_ids.dispersions[key]
        return [self.stat_str(distribution)]

    def content_right(self, freq):
        """Return the formatted statistics of the current match.

        The ``freq`` argument is not used; the frequency stored by
        new_match() is taken instead.
        """
        frequencies = self.stats["freq"]
        fx = frequencies[self.jppb[0]]
        fy = frequencies[self.jppb[1]]
        freq = self.stats['freq_all']
        logdice_core = 14 + log2(2 * freq / (fx + fy))

        # Components absent from the match (frequency 0) are left out.
        fi = [frequencies[idx] for idx in self.corew if frequencies[idx] > 0]
        logdice_all = 14 + log2(len(fi) * freq / sum(fi))

        values = (
            self.stats["d12"], self.stats["d21"], logdice_core, logdice_all, self.stats['df']
        )
        return [self.stat_str(x) for x in values]

    def group(self):
        return True
class OutFormatter(Formatter):
    """Combines OutNoStatFormatter and StatsFormatter into one output.

    All columns of the first formatter come before those of the second.
    """

    def additional_init(self):
        self.f1 = OutNoStatFormatter(self.colocation_ids, self.word_renderer)
        self.f2 = StatsFormatter(self.colocation_ids, self.word_renderer)

    def header_repeat(self):
        return self.f1.header_repeat() + self.f2.header_repeat()

    def header_right(self):
        return self.f1.header_right() + self.f2.header_right()

    def content_repeat(self, words, representations, idx, sidx):
        forms = self.f1.content_repeat(words, representations, idx, sidx)
        stats = self.f2.content_repeat(words, representations, idx, sidx)
        return forms + stats

    def content_right(self, freq):
        return self.f1.content_right(freq) + self.f2.content_right(freq)

    def group(self):
        return self.f1.group() and self.f2.group()

    def set_structure(self, structure):
        # Only the statistics formatter depends on the structure.
        self.f2.set_structure(structure)

    def new_match(self, match):
        # Only the statistics formatter keeps per-match state.
        self.f2.new_match(match)
class Writer:
    """Writes the matches of all structures to comma-separated files.

    Changes to the previous version: output files are opened in ``with``
    blocks, so they are closed even when writing fails, and the "cannot
    sort" warning names the requested column instead of the row width.
    """

    @staticmethod
    def other_params(args):
        """Extract (multiple_output, sort_by, sort_reversed) from ``args``."""
        return (args.multiple_output, int(args.sort_by), args.sort_reversed)

    @staticmethod
    def make_output_writer(args, colocation_ids, word_renderer):
        """Create the writer of ``args.out`` (forms and statistics)."""
        params = Writer.other_params(args)
        return Writer(args.out, OutFormatter(colocation_ids, word_renderer), params)

    @staticmethod
    def make_output_no_stat_writer(args, colocation_ids, word_renderer):
        """Create the writer of ``args.out_no_stat`` (forms only)."""
        params = Writer.other_params(args)
        return Writer(args.out_no_stat, OutNoStatFormatter(colocation_ids, word_renderer), params)

    @staticmethod
    def make_all_writer(args, colocation_ids, word_renderer):
        """Create the writer of ``args.all``: one file, unsorted."""
        return Writer(args.all, AllFormatter(colocation_ids, word_renderer), None)

    @staticmethod
    def make_stats_writer(args, colocation_ids, word_renderer):
        """Create the writer of ``args.stats`` (statistics only)."""
        params = Writer.other_params(args)
        return Writer(args.stats, StatsFormatter(colocation_ids, word_renderer), params)

    def __init__(self, file_out, formatter, params):
        """Set up the writer.

        ``params`` is None or (multiple_output, sort_by, sort_order); None
        selects a single unsorted output file. ``file_out`` may be None, in
        which case write_out() does nothing.
        """
        if params is None:
            self.multiple_output = False
            self.sort_by = -1
            self.sort_order = None
        else:
            self.multiple_output, self.sort_by, self.sort_order = (
                params[0], params[1], params[2])

        self.output_file = file_out
        self.formatter = formatter

    def header(self):
        """Return the column names of the output."""
        repeating_cols = self.formatter.header_repeat()
        component_cols = [
            "C{}_{}".format(i + 1, thd)
            for i in range(MAX_NUM_COMPONENTS)
            for thd in repeating_cols
        ]
        cols = ["Structure_ID"] + component_cols + ["Colocation_ID"]
        cols += self.formatter.header_right()
        return cols

    def sorted_rows(self, rows):
        """Return ``rows`` sorted by column ``sort_by``.

        The rows are returned unchanged when sorting is off (negative
        column), there are fewer than two rows, or the first row is too
        short (logged as a warning). The column is compared numerically
        when the cell of the first row is an integer, otherwise as
        lower-cased text. ``sort_order`` is used as the ``reverse`` flag.
        """
        if self.sort_by < 0 or len(rows) < 2:
            return rows

        if len(rows[0]) <= self.sort_by:
            logging.warning("Cannot sort by column #{}: Not enough columns!".format(self.sort_by))
            return rows

        try:
            int(rows[0][self.sort_by])
        except ValueError:
            def key(row):
                return row[self.sort_by].lower()
        else:
            def key(row):
                return int(row[self.sort_by])

        return sorted(rows, key=key, reverse=self.sort_order)

    def write_header(self, file_handler):
        """Write the header line to ``file_handler``."""
        file_handler.write(", ".join(self.header()) + "\n")

    def write_out_worker(self, file_handler, structure, colocation_ids):
        """Write the rows of all matches of ``structure``.

        A grouping formatter gets one row per match (built from its first
        occurrence), any other formatter one row per occurrence. Components
        are addressed by their position in the structure, counted from 1.
        Every row is padded to MAX_NUM_COMPONENTS component groups.
        """
        rows = []
        components = structure.components

        for match in colocation_ids.get_matches_for(structure):
            self.formatter.new_match(match)

            for words in match.matches:
                to_write = []

                for position in range(1, len(components) + 1):
                    idx = str(position)
                    if idx not in words:
                        to_write.extend([""] * self.formatter.length())
                    else:
                        to_write.extend(self.formatter.content_repeat(
                            words, match.representations, idx, structure.id))

                # make them equal size
                to_write.extend([""] * (MAX_NUM_COMPONENTS * self.formatter.length() - len(to_write)))

                # structure_id and colocation_id
                to_write = [structure.id] + to_write + [match.match_id]

                # header_right
                to_write.extend(self.formatter.content_right(len(match)))
                rows.append(to_write)

                if self.formatter.group():
                    break

        if rows != []:
            rows = self.sorted_rows(rows)
            file_handler.write("\n".join([", ".join(row) for row in rows]) + "\n")
            file_handler.flush()

    def write_out(self, structures, colocation_ids):
        """Write all ``structures`` to the configured output.

        With ``multiple_output`` every structure goes to its own file named
        "<output_file>.<structure id>"; otherwise everything goes to
        ``output_file``. Every file starts with the header line. Nothing is
        done when no output file is configured.
        """
        if self.output_file is None:
            return

        if self.multiple_output:
            for s in structures:
                with open("{}.{}".format(self.output_file, s.id), "w") as fp:
                    self.write_header(fp)
                    self.formatter.set_structure(s)
                    self.write_out_worker(fp, s, colocation_ids)
            return

        with open(self.output_file, "w") as fp:
            self.write_header(fp)
            for s in structures:
                self.formatter.set_structure(s)
                self.write_out_worker(fp, s, colocation_ids)
class StructureMatch:
    """All occurrences of one collocation of one structure."""

    def __init__(self, match_id, structure):
        self.match_id = match_id
        self.structure = structure
        # One mapping (component id -> word) per occurrence.
        self.matches = []
        # component id -> rendered representation
        self.representations = {}

    def distinct_forms(self):
        """Count the distinct word-form sequences among the occurrences.

        The components compared are those of the first occurrence; their
        word texts are joined with a space.
        """
        keys = list(self.matches[0].keys())
        sequences = {
            " ".join(words[k].text for k in keys)
            for words in self.matches
        }
        return len(sequences)

    def append(self, match):
        """Add one occurrence."""
        self.matches.append(match)

    def __len__(self):
        """Return the number of occurrences."""
        return len(self.matches)
class ColocationIds:
    """Registry of every collocation found, keyed by its collocation id tuple.

    ``data`` maps a key of the form
    ``(structure_id, (component_id, lemma), ...)`` to the ``StructureMatch``
    holding all occurrences of that collocation.  ``dispersions`` is filled
    by :meth:`determine_colocation_dispersions`.
    """

    def __init__(self, min_frequency=None):
        """Create an empty registry.

        Args:
            min_frequency: Minimal frequency to store on the instance.  When
                omitted, the value is read from the module-level command-line
                namespace ``args`` (``args.min_freq``), which exists only
                when the file is run as a script.
        """
        self.data = {}
        self.min_frequency = args.min_freq if min_frequency is None else min_frequency
        self.dispersions = {}

    def _add_match(self, key, structure, match):
        """Record one occurrence under ``key``, creating its entry if needed."""
        if key not in self.data:
            # Ids are running numbers (as strings) in order of first appearance.
            self.data[key] = StructureMatch(str(len(self.data) + 1), structure)
        self.data[key].append(match)

    def get(self, key, n):
        # NOTE(review): ``self.data[key]`` is a StructureMatch, which defines
        # no ``__getitem__`` in this file, so this indexing raises TypeError.
        # Left unchanged because the intended meaning of ``n`` is not visible
        # here -- confirm whether this method is still used.
        return self.data[key][n]

    def add_matches(self, matches):
        """Merge the result of matching one file into the registry.

        Args:
            matches: Mapping from structure to an iterable of
                ``(match, colocation_id)`` pairs.
        """
        for structure, nms in matches.items():
            for nm in nms:
                self._add_match(nm[1], structure, nm[0])

    def get_matches_for(self, structure):
        """Lazily yield every stored StructureMatch belonging to ``structure``."""
        for sm in self.data.values():
            if sm.structure != structure:
                continue
            yield sm

    def set_representations(self, word_renderer):
        """Let ComponentRendition fill in representations for every entry."""
        for sm in tqdm(self.data.values()):
            ComponentRendition.set_representations(sm, word_renderer)

    def determine_colocation_dispersions(self):
        """Count, per (structure id, component id, lemma), the collocations using it.

        Every stored collocation key contributes one count for each of its
        ``(component_id, lemma)`` pairs.  The result replaces
        ``self.dispersions`` as a plain dict.
        """
        dispersions = defaultdict(int)
        for structure_id, *word_tups in self.data:
            for component_id, lemma in word_tups:
                dispersions[(structure_id, component_id, lemma)] += 1
        self.dispersions = dict(dispersions)
def match_file(words, structures):
    """Match every structure against every word and collect the results.

    Args:
        words: Iterable of words to try as match starting points; it is
            iterated once per structure.
        structures: Iterable of structures exposing ``match(word)`` and ``id``.

    Returns:
        Dict mapping each structure to a list of ``(match, colocation_id)``
        pairs, where ``colocation_id`` is the tuple
        ``(structure.id, (idx, lemma), ...)`` with the ``(idx, lemma)`` pairs
        sorted by ``idx``.
    """
    matches = {structure: [] for structure in structures}

    for structure in tqdm(structures):
        found = matches[structure]
        for word in words:
            for match in structure.match(word):
                # Canonical key: component index paired with the lemma matched
                # there, ordered by component index.
                lemma_pairs = sorted(
                    [(idx, matched.lemma) for idx, matched in match.items()],
                    key=lambda pair: pair[0])
                colocation_id = tuple([structure.id] + lemma_pairs)
                found.append((match, colocation_id))

    return matches
def main(structures_file, args):
    """Run the extraction pipeline: load, match, aggregate and write outputs.

    Two modes are supported:

    * ``args.parallel`` set: one child interpreter is started per input file
      (at most ``int(args.parallel)`` at a time).  Each child is this same
      command line without the ``--parallel`` option, restricted to a single
      input file and extended with ``--match-to-file``; the parent then loads
      the pickled ``(words, matches)`` each child produced.
    * otherwise: files are matched in-process.  When ``args.match_to_file`` is
      set (i.e. this process is such a child) the first file's
      ``(words, matches)`` is pickled to that path and the function returns.

    Args:
        structures_file: Path handed to ``load_structures``.
        args: Parsed command-line namespace.

    Raises:
        ValueError: If ``args.parallel`` is set but no ``--parallel`` /
            ``--parallel=N`` token can be found in ``sys.argv``.
        subprocess.CalledProcessError: If a child process fails.
    """
    structures, lemma_msds = load_structures(structures_file)

    colocation_ids = ColocationIds()
    word_renderer = WordMsdRenderer(lemma_msds)

    if args.parallel:
        num_parallel = int(args.parallel)

        # Work on a copy: the child command line must not alter sys.argv.
        cmd = list(sys.argv)

        # Strip the parallel option first, so that an input file whose name
        # happens to equal the option's value cannot be confused with it.
        # Both "--parallel N" and "--parallel=N" are accepted by argparse.
        pidx = next((i for i, arg in enumerate(cmd)
                     if arg == '--parallel' or arg.startswith('--parallel=')),
                    None)
        if pidx is None:
            # Refuse to continue: children that still carry the option would
            # spawn children of their own.
            raise ValueError(
                "cannot find '--parallel' in the command line; "
                "spell the option out in full")
        if cmd[pidx] == '--parallel':
            del cmd[pidx:pidx + 2]
        else:
            del cmd[pidx]

        # Each child gets exactly one input file appended below.
        for inpt in args.input:
            if inpt in cmd:
                cmd.remove(inpt)

        # make temporary directory to hold temporary files
        with tempfile.TemporaryDirectory() as tmpdirname:

            def func(n):
                """Match input number ``n`` in a child process; return ``n``."""
                cmdn = [sys.executable] + cmd + [
                    args.input[n],
                    "--match-to-file", "{}/{}.p".format(tmpdirname, n)]
                subprocess.check_call(cmdn)
                return n

            # Threads only wait on the child processes, which do the work.
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_parallel) as executor:
                # map() hands results back in input order as they complete.
                for id_input in executor.map(func, range(len(args.input))):
                    # The pickle was written by our own child into our own
                    # temporary directory; never point this at foreign files.
                    with open("{}/{}.p".format(tmpdirname, id_input), "rb") as fp:
                        words, matches = pickle.load(fp)

                    colocation_ids.add_matches(matches)
                    word_renderer.add_words(words)

    else:
        for words in load_files(args):
            matches = match_file(words, structures)
            # just save to temporary file, used for children of a parallel process
            # MUST NOT have more than one file
            if args.match_to_file is not None:
                with open(args.match_to_file, "wb") as fp:
                    pickle.dump((words, matches), fp)
                return

            colocation_ids.add_matches(matches)
            word_renderer.add_words(words)

    # get word renders for lemma/msd
    word_renderer.generate_renders()
    colocation_ids.determine_colocation_dispersions()

    # Representations are only computed when one of these two outputs was
    # requested.
    if args.out or args.out_no_stat:
        colocation_ids.set_representations(word_renderer)

    # Each writer's write_out decides for itself whether it has a file to write.
    Writer.make_output_writer(args, colocation_ids, word_renderer).write_out(
        structures, colocation_ids)
    Writer.make_output_no_stat_writer(args, colocation_ids, word_renderer).write_out(
        structures, colocation_ids)
    Writer.make_all_writer(args, colocation_ids, word_renderer).write_out(
        structures, colocation_ids)
    Writer.make_stats_writer(args, colocation_ids, word_renderer).write_out(
        structures, colocation_ids)
if __name__ == '__main__':
    # Command-line entry point: parse options, configure logging, run main()
    # and log the elapsed wall-clock time.
    parser = argparse.ArgumentParser(
        description='Extract structures from a parsed corpus.')

    # Positional arguments.
    parser.add_argument('structures',
                        help='Structures definitions in xml file')
    parser.add_argument('input',
                        help='input xml file in `ssj500k form`, can list more than one', nargs='+')

    # Output files; each one is optional.
    parser.add_argument('--out',
                        help='Classic output file')
    parser.add_argument('--out-no-stat',
                        help='Output file, but without statistical columns')
    parser.add_argument('--all',
                        help='Additional output file, writes more data')
    parser.add_argument('--stats',
                        help='Output file for statistics')

    # Input handling.
    parser.add_argument('--no-msd-translate',
                        help='MSDs are translated from slovene to english by default',
                        action='store_true')
    parser.add_argument('--skip-id-check',
                        help='Skips checks for ids of <w> and <pc>, if they are in correct format',
                        action='store_true')
    # Bare "--min_freq" means 1; omitting the option means 0.
    parser.add_argument('--min_freq', help='Minimal frequency in output',
                        type=int, default=0, const=1, nargs='?')

    # Diagnostics.
    parser.add_argument('--verbose', help='Enable verbose output to stderr',
                        choices=["warning", "info", "debug"], default="info",
                        const="info", nargs='?')
    parser.add_argument('--count-files',
                        help="Count files: more verbose output", action='store_true')

    # Output layout and ordering.
    parser.add_argument('--multiple-output',
                        help='Generate one output for each syntactic structure',
                        action='store_true')
    parser.add_argument('--sort-by',
                        help="Sort by this column (index)", type=int, default=-1)
    parser.add_argument('--sort-reversed',
                        help="Sort in reversed order", action='store_true')

    parser.add_argument('--pc-tag',
                        help='Tag for separators, usually pc or c', default="pc")

    # Parallel execution; --match-to-file is set by the parent process on the
    # children it spawns.
    parser.add_argument('--parallel',
                        help='Run in multiple processes, should speed things up')
    parser.add_argument('--match-to-file', help='Do not use!')

    # NOTE: ``args`` stays a module-level name on purpose; ColocationIds
    # reads ``args.min_freq`` from it.
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose.upper())

    start = time.time()
    main(args.structures, args)
    logging.info("TIME: {}".format(time.time() - start))