Updating for lemma representation of word_form. Also cleaning code, adding tqdm,...
parent 3c669c7901
commit 4c2b5f2b13

Changed file: wani.py (185 lines)
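The gist of the change: load_structures() now also returns lemma_msds, a dict built by get_lemma_features() that maps a POS letter to a lemma-level MSD, and that dict is threaded through ColocationIds.set_representations() down to ComponentRendition.set_representations(), where render_form() can fall back to a lemma-only "word" when no concrete word form was collected. The snippet below is a reader's self-contained sketch of just that fallback decision; the dictionary contents are invented for illustration, only the names and the branch logic come from the diff.

    from collections import namedtuple

    # Hypothetical lemma-level MSDs keyed by POS letter (the real values come from
    # get_lemma_features() over the <lemma_features> element of the structures file).
    lemma_msds = {'N': 'Nc-n'}

    def lemma_only_word(msd):
        # Same stand-in as in the diff: looks enough like a Word for rendering,
        # but carries no text and no lemma of its own.
        WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma')
        return WordLemma(msd=msd, most_frequent_text=lambda *x: None, lemma=None)

    def pick_backup(word_msd):
        # Mirrors the new branch in set_representations(): use a lemma-only backup
        # word if the word's POS has a lemma-level MSD, otherwise no backup.
        return lemma_only_word(lemma_msds[word_msd[0]]) if word_msd[0] in lemma_msds else None

    print(pick_backup('Ncfpd'))   # WordLemmaOnly(msd='Nc-n', ..., lemma=None)
    print(pick_backup('Vmpr3s'))  # None

The backup word is simply appended to the candidate list in render_form(), so the existing selection loop handles the fallback without a separate code path.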
@@ -1,7 +1,7 @@
from xml.etree import ElementTree
import re
from enum import Enum
from collections import defaultdict
from collections import defaultdict, namedtuple
import sys
import logging
import argparse
@@ -12,6 +12,7 @@ import concurrent.futures
import tempfile

from msd_translate import MSD_TRANSLATE
from tqdm import tqdm


MAX_NUM_COMPONENTS = 5
@@ -217,7 +218,7 @@ class ComponentRendition:
        return self.rendition is rendition

    @staticmethod
    def set_representations(matches, structure, word_renderer):
    def set_representations(matches, structure, word_renderer, lemma_msds):
        representations = {
            c.idx: [[], None] if c.representation.isit(Rendition.WordForm) else [True, ""]
            for c in structure.components
@@ -225,24 +226,34 @@ class ComponentRendition:
        found_agreements = {}
        word_component_id = {}

        # doprint = structure.id == '1' and matches[0]['1'].text.startswith('evrop') and matches[0]['2'].text.startswith('prv')
        doprint = False

        def render_all(component_id, lst):
            matches.representations[component_id] = "/".join([w.text for w in set(lst)])
        def render_all(component_id, lst, _bw):
            rep = "/".join([w.text for w in set(lst)]) if len(lst) > 0 else None
            matches.representations[component_id] = rep

        def render_form(component_id, lst):
            sorted_lst = sorted(set(lst), key=lst.count)
            for word in sorted_lst:
        def render_form(component_id, lst, backup_word):
            if backup_word is not None:
                lst.append(backup_word)

            lst_ctr = []
            for word in lst:
                lst_ctr.append((word.msd, word.lemma))
            sorted_lst = sorted(set(lst_ctr), key=lst.count)

            if len(lst) > 3:
                a = 3

            for word_msd, word_lemma in sorted_lst:
                if component_id in found_agreements:
                    other_component_id, other_word, agreements = found_agreements[component_id]
                    print(word.lemma, other_word.lemma, component_id, other_component_id, word.msd, word.msd)
                    agr = are_agreements_ok(word.msd, other_word.lemma, other_word.msd, agreements)
                    agr = are_agreements_ok(word_msd, other_word.lemma, other_word.msd, agreements)
                    if agr is None:
                        continue

                    matches.representations[other_component_id] = agr

                matches.representations[word_component_id[word.id]] = word.most_frequent_text(word_renderer)
                if word_lemma is not None:
                    matches.representations[component_id] = word_renderer.render(word_lemma, word_msd)

                break

                # othw = are_agreements_ok(word, found_agreements)
@@ -255,10 +266,7 @@ class ComponentRendition:
                if ow_msd[0] != w2_msd[0]:
                    continue

                print(w1_msd, w2_msd)
                if check_agreement(w1_msd, w2_msd, agreements):
                    if doprint:
                        print("GOOD :)")
                    return w2_txt

        def check_msd(word, selectors):
@@ -279,7 +287,6 @@ class ComponentRendition:
                # if not in msd, some strange msd was tries, skipping...
                if agr_case not in TAGSET[t1]:
                    logging.warning("Cannot do agreement: {} for msd {} not found!".format(agr_case, msd1))
                    print("BAAAD")
                    return False

                v1 = TAGSET[t1].index(agr_case)
@@ -293,7 +300,6 @@ class ComponentRendition:
                t2 = msd2[0]
                if agr_case not in TAGSET[t2]:
                    logging.warning("Cannot do agreement: {} for msd {} not found!".format(agr_case, msd2))
                    print("BAAAD")
                    return False
                v2 = TAGSET[t2].index(agr_case)
                if v2 + 1 >= len(msd2):
@@ -328,44 +334,26 @@ class ComponentRendition:
                    assert(rep.isit(Rendition.WordForm))
                    wf_type, more = rep.more

                    if wf_type is WordFormSelection.All:
                        add = True
                        func = render_all
                    elif wf_type is WordFormSelection.Msd:
                    if wf_type is WordFormSelection.Msd:
                        add = check_msd(w, more)
                        func = render_form
                    elif wf_type is WordFormSelection.All:
                        add = True
                        func = render_all
                    elif wf_type is WordFormSelection.Any:
                        add = True
                        func = render_form
                    else:
                        assert(wf_type is WordFormSelection.Agreement)
                        other_w, agreements = more
                        found_agreements[other_w] = (w_id, w.lemma, agreements)
                        found_agreements[other_w] = (w_id, w, agreements)

                        add = True
                        func = lambda *x: None

                    representations[w_id][1] = func
                    if add:
                        representations[w_id][0].append(w)
                        representations[w_id][1] = func

        if doprint:
            print(len(matches), len(found_agreements))

        # for w1i, w2i, agreements in found_agreements:
        #     w1, w2 = words[w1i], words[w2i]
        #     if doprint:
        #         print("? ", w1.msd, w2.msd, end="")

        #     if w2i not in bad_words:
        #
        #         if check_agreement(w1, w2, agreements):
        #             representations[w1i][0].append(w1.text)
        #             if doprint:
        #                 print(" :)")
        #         elif doprint:
        #             print(" :(")
        #     elif doprint:
        #         print(" :((")

        # just need to set representation to first group,
        # but in correct order, agreements last!
@@ -379,20 +367,12 @@ class ComponentRendition:

        for w_id, w in representation_sorted_words:
            data = representations[w_id]
            if doprint:
                print([(r.text, r.lemma, r.msd) for r in data[0]])

            if type(data[1]) is str:
                matches.representations[w_id] = None if data[0] else data[1]
            elif len(data[0]) == 0:
                matches.representations[w_id] = None
            else:
                data[1](str(w_id), data[0])
                backup_word = lemma_only_word(lemma_msds[w.msd[0]]) if w.msd[0] in lemma_msds else None
                data[1](str(w_id), data[0], backup_word)

        if doprint:
            print(matches.representations)
            print('--')

    def __str__(self):
        return str(self.rendition)

@@ -468,7 +448,7 @@ def build_morphology_regex(restriction):
                return False
        return True

    return " ".join(rgx), matcher
    return rgx, matcher


def build_lexis_regex(restriction):
@@ -493,7 +473,8 @@ class Restriction:
        restriction_type = restriction_tag.get('type')
        if restriction_type == "morphology":
            self.type = RestrictionType.Morphology
            self.present, self.matcher = build_morphology_regex(list(restriction_tag))
            present, self.matcher = build_morphology_regex(list(restriction_tag))
            self.present = " ".join(present)
        elif restriction_type == "lexis":
            self.type = RestrictionType.Lexis
            self.present, self.matcher = build_lexis_regex(list(restriction_tag))
@@ -822,17 +803,44 @@ class SyntacticStructure:
        #     to_ret.append((m, self.check_agreements(m)))


def build_structures(filename):
    structures = []
def load_structures(filename):
    with open(filename, 'r') as fp:
        et = ElementTree.XML(fp.read())
        for structure in et.iter('syntactic_structure'):
            to_append = SyntacticStructure.from_xml(structure)
            if to_append is None:
                continue
            structures.append(to_append)

    return build_structures(et), get_lemma_features(et)

def build_structures(et):
    structures = []
    for structure in et.iter('syntactic_structure'):
        to_append = SyntacticStructure.from_xml(structure)
        if to_append is None:
            continue
        structures.append(to_append)
    return structures

def get_lemma_features(et):
    lf = et.find('lemma_features')
    if lf is None:
        return {}

    result = {}
    for pos in lf.iter('POS'):
        rgx_list, _ = build_morphology_regex(pos)
        rgx_str = ""
        for position in rgx_list:
            if position == ".":
                rgx_str += " "
            elif len(position) == 1:
                rgx_str += position
            elif len(position) == 3 and position[0] == "[" and position[2] == "]":
                rgx_str += position[1]
            else:
                raise RuntimeError("Strange rgx for lemma_feature...")

        assert(rgx_str[0].isupper())
        result[rgx_str[0]] = rgx_str.strip().replace(' ', '-')

    return result

def get_msd(comp):
    d = dict(comp.items())
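To make the flattening in get_lemma_features() concrete: build_morphology_regex() returns a per-position regex list, and the loop above collapses it into an MSD-like string keyed by the POS letter. Below is a standalone rerun of that loop on a hypothetical regex list; the input values are invented, only the transformation rules are copied from the code above.

    # Hypothetical per-position regex: literal letters stay, "." marks an
    # unspecified position, "[x]" keeps the single letter inside the brackets.
    rgx_list = ["N", "c", ".", "[n]"]

    rgx_str = ""
    for position in rgx_list:
        if position == ".":
            rgx_str += " "
        elif len(position) == 1:
            rgx_str += position
        elif len(position) == 3 and position[0] == "[" and position[2] == "]":
            rgx_str += position[1]
        else:
            raise RuntimeError("Strange rgx for lemma_feature...")

    # Unspecified positions become '-' in the stored lemma-level MSD.
    print({rgx_str[0]: rgx_str.strip().replace(' ', '-')})  # {'N': 'Nc-n'}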
@@ -844,6 +852,10 @@ def get_msd(comp):
        logging.error(d, file=sys.stderr)
        raise NotImplementedError("MSD?")

def lemma_only_word(msd):
    WordLemma = namedtuple('WordLemmaOnly', 'msd most_frequent_text lemma')
    return WordLemma(msd=msd, most_frequent_text=lambda *x: None, lemma=None)

class Word:
    def __init__(self, xml, do_msd_translate):
        self.lemma = xml.get('lemma')
@@ -882,6 +894,7 @@ class WordMsdRenderer:
    def __init__(self):
        self.all_words = []
        self.rendered_words = {}
        self.frequent_words = {}

    def add_words(self, words):
        self.all_words.extend(words)
@@ -893,19 +906,30 @@ class WordMsdRenderer:

        for lemma, ld in data.items():
            self.rendered_words[lemma] = {}
            freq_words = defaultdict(int)

            for msd, texts in ld.items():
                rep = max(set(texts), key=texts.count)
                self.rendered_words[lemma][msd] = rep

                self.rendered_words[lemma][msd] = (rep, len(texts))

                for txt in texts:
                    freq_words[(msd, txt)] += 1

            self.frequent_words[lemma] = []
            for (msd, txt), n in sorted(freq_words.items(), key=lambda x: -x[1]):
                self.frequent_words[lemma].append((msd, txt, n))

    def render(self, lemma, msd):
        if lemma in self.rendered_words:
            if msd in self.rendered_words[lemma]:
                return self.rendered_words[lemma][msd]
                return self.rendered_words[lemma][msd][0]

    def available_words(self, lemma):
        if lemma in self.rendered_words:
            for msd in self.rendered_words[lemma].keys():
                yield (msd, self.rendered_words[lemma][msd])
        if lemma in self.frequent_words:
            # print("--")
            for msd, text, _ in self.frequent_words[lemma]:
                # print(lemma, msd, text, _)
                yield (msd, text)

def is_root_id(id_):
    return len(id_.split('.')) == 3
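A rough usage sketch of the reworked renderer API above, with hand-filled tables instead of the corpus-driven ones that generate_renders() builds: render() returns the most frequent text for an exact (lemma, msd) pair, while available_words() now yields (msd, text) pairs for a lemma in descending frequency order.

    # Standalone illustration only; the field layouts mirror the diff
    # ((text, count) per msd, and (msd, text, count) triples per lemma).
    class MiniRenderer:
        def __init__(self):
            self.rendered_words = {'pes': {'Ncmsn': ('pes', 7)}}
            self.frequent_words = {'pes': [('Ncmsn', 'pes', 7), ('Ncmsg', 'psa', 3)]}

        def render(self, lemma, msd):
            if lemma in self.rendered_words:
                if msd in self.rendered_words[lemma]:
                    return self.rendered_words[lemma][msd][0]

        def available_words(self, lemma):
            if lemma in self.frequent_words:
                for msd, text, _ in self.frequent_words[lemma]:
                    yield (msd, text)

    r = MiniRenderer()
    print(r.render('pes', 'Ncmsn'))        # 'pes'
    print(list(r.available_words('pes')))  # [('Ncmsn', 'pes'), ('Ncmsg', 'psa')]

The lemma 'pes' and its forms are made-up example data, not taken from the tool's corpus.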
@@ -1010,7 +1034,6 @@ class Writer:
        elif self.all:
            return [word.id, word.text, word.lemma, word.msd]
        else:
            # print("1", word)
            if representation is None:
                return [word.lemma, word.lemma, "lemma_fallback"]
            else:
@@ -1137,17 +1160,21 @@ class ColocationIds:
                if group:
                    break

    def set_representations(self, structures, word_renderer):
    def set_representations(self, structures, word_renderer, lemma_msds):
        components_dict = {structure.id: structure for structure in structures}
        for _1, sm in self.data.items():
            ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer)
        idx = 1
        for _1, sm in tqdm(self.data.items()):
            if idx == 120:
                a = 3
            ComponentRendition.set_representations(sm, components_dict[sm.structure_id], word_renderer, lemma_msds)
            idx += 1


def match_file(words, structures):
    matches = {s.id: [] for s in structures}

    for idx, s in enumerate(structures):
        logging.info("{}/{}: {:7s}".format(idx, len(structures), s.id))
    for idx, s in tqdm(list(enumerate(structures))):
        # logging.info("{}/{}: {:7s}".format(idx, len(structures), s.id))
        for w in words:
            mhere = s.match(w)
            logging.debug("  GOT: {}".format(len(mhere)))
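The tqdm changes above are purely cosmetic: tqdm(iterable) yields the same items as the underlying iterable while printing a progress bar, so wrapping self.data.items() or list(enumerate(structures)) does not alter what the loops do. A trivial standalone example (assumes the tqdm package is installed):

    from tqdm import tqdm

    data = {i: i * i for i in range(100000)}
    total = 0
    for key, value in tqdm(data.items()):  # same iteration, plus a progress bar
        total += value
    print(total)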
@@ -1162,13 +1189,19 @@ def match_file(words, structures):


def main(input_file, structures_file, args):
    structures = build_structures(structures_file)
    structures, lemma_msds = load_structures(structures_file)
    for s in structures:
        logging.debug(str(s))

    colocation_ids = ColocationIds()
    word_renderer = WordMsdRenderer()

    # if True:
    #     with open("match_word.p", "rb") as fp:
    #         words, matches = pickle.load(fp)
    #     colocation_ids.add_matches(matches)
    #     word_renderer.add_words(words)

    if args.parallel:
        num_parallel = int(args.parallel)

@@ -1215,7 +1248,7 @@ def main(input_file, structures_file, args):
    # get word renders for lemma/msd
    word_renderer.generate_renders()
    # figure out representations!
    colocation_ids.set_representations(structures, word_renderer)
    colocation_ids.set_representations(structures, word_renderer, lemma_msds)

    if args.all:
        Writer.make_all_writer(args).write_out(structures, colocation_ids)