You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
cjvt-srl-tagging/tools/gen_json_fix_errors.py

295 lines
8.9 KiB

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import pickle
from pathlib import Path
from parser.parser import Parser
import configparser
import json
import sys
import logging
from multiprocessing import Pool
# parse config
config = configparser.ConfigParser()
config.read("tools.cfg")
ORIGPATH = Path(config["tools"]["giga"])
INPATH = Path(config["tools"]["giga_srl_errors"])
OUTPATH = Path(config["tools"]["giga_json"])
INTERNAL_DATA = Path(config["tools"]["internal_data"])
DEBUG = config["tools"]["debug"] == "True"
CPU_CORES = int(config["tools"]["cpu_cores"])

# Make sure the log file exists before handing it to logging.
LOGFILE = Path(config["tools"]["logfile"]).absolute()
LOGFILE.touch(exist_ok=True)
# Path.resolve() returns a new Path rather than modifying LOGFILE, so keep it.
LOGFILE = LOGFILE.resolve()
logging.basicConfig(filename=str(LOGFILE), level=logging.INFO)

# Sentence IDs (keys of the JSON output) whose SRL relations must be rebuilt,
# one per line. Read with an explicit encoding and close the file afterwards.
with open(INTERNAL_DATA / "sentences_with_less_than_token.txt", encoding="utf-8") as error_file:
    error_sentences = [line.rstrip("\n") for line in error_file]
def get_origfile(filename, origpath=None):
    """Return the original file whose name matches *filename* up to the first dot.

    For example, ``abc.tsv`` matches ``abc.xml``.

    :param filename: ``Path`` whose name is matched against the originals.
    :param origpath: directory to search; defaults to ``ORIGPATH``.
    :returns: the first matching ``Path`` found in the directory.
    :raises FileNotFoundError: if no file in the directory shares the stem.
    """
    if origpath is None:
        origpath = ORIGPATH
    # The stem of the file we are looking for does not change inside the loop.
    wanted = filename.name.split(".")[0]
    for origfile in Path(origpath).iterdir():
        if origfile.name.split(".")[0] == wanted:
            return origfile
    raise FileNotFoundError(
        "no original file matching {!r} in {}".format(wanted, origpath)
    )
def extract_sentences(line_reader):
    """Group tab-separated, UTF-8 encoded byte lines into sentences.

    :param line_reader: iterable of ``bytes`` lines (e.g. ``fp.readlines()``
        of a file opened in binary mode).
    :yields: one list per sentence; each element is the list of
        tab-separated fields of one token line. A line without any tab
        (e.g. a blank line) ends the current sentence.
    """
    acc = []
    for raw in line_reader:
        # Strip only the line terminator ("\n" or "\r\n"). Slicing off the last
        # character would eat data on an unterminated final line and leave "\r"
        # glued to the last field of CRLF input.
        line = raw.decode("utf-8").rstrip("\r\n").split("\t")
        if len(line) == 1:  # sentence separator
            yield acc
            acc = []
        else:
            acc.append(line)
    # Don't lose the final sentence when the input lacks a trailing blank line.
    if acc:
        yield acc
def to_sentence(sentence_arr):
    """Rebuild a sentence's text by space-joining the second field of every token."""
    forms = (fields[1] for fields in sentence_arr)
    return " ".join(forms)
def match_sentence_id(sentence, orig_dict):
    """Return the key of the entry in *orig_dict* whose tokens spell *sentence*.

    An entry's text is rebuilt by space-joining the third field (index 2) of
    each element of its ``"tokens"`` list. The first match is returned.

    :raises KeyError: with *sentence* as its argument when nothing matches.
    """
    for k, e in orig_dict.items():
        if sentence == " ".join(token[2] for token in e["tokens"]):
            return k
    # Carry the unmatched sentence so failures are diagnosable.
    raise KeyError(sentence)
def match_sentence_id_giga(sentence, orig_dict):
    """Return the key of the entry in *orig_dict* whose ``"text"`` equals *sentence*.

    The first match is returned.

    :raises KeyError: with *sentence* as its argument when nothing matches.
    """
    for k, e in orig_dict.items():
        if sentence == e["text"]:
            return k
    # Carry the unmatched sentence so failures are diagnosable.
    raise KeyError(sentence)
def get_dep_rel(token):
    """Return the first semantic-role relation recorded on *token*, or ``None``.

    Columns from index 14 onwards hold one label per predicate of the
    sentence ("_" meaning no relation); this layout looks like CoNLL-2009
    APRED columns (TODO confirm). The result is a dict with the label
    (``"arg"``), the position of the predicate among the sentence's
    predicates (``"from"``), and this token's index (``"dep"``).

    NOTE: only the first non-"_" column is reported, so a token that is an
    argument of several predicates yields a single relation.
    """
    logging.debug(token)
    arg_columns = token[14:]
    for pred_no, label in enumerate(arg_columns):
        if label == "_":
            continue
        return {"arg": label, "from": pred_no, "dep": token[0]}
    return None
def handle_file_old(infile_tpl):
    """Convert one SRL tsv file into a JSON file of semantic-role relations.

    Not used by the error-fixing flow at the bottom of this script.

    :param infile_tpl: sequence whose element ``[1]`` is the input ``Path``
        (element ``[0]`` is ignored).

    Writes ``OUTPATH/<input name>.json``, mapping each sentence ID to a list
    of ``{"arg", "from", "dep"}`` dicts in which ``"from"`` is the token index
    of the governing predicate. The tsv has no sentence IDs, so each sentence
    is matched by text against the parsed original file.
    """
    # NOTE(review): ``par`` is not defined in the visible source; presumably a
    # module-level ``Parser`` instance is expected. Confirm before calling.
    infile = infile_tpl[1]
    outfile = (OUTPATH / infile.name).with_suffix(".json")
    origfile = get_origfile(infile)
    orig_dict = par.parse_tei(origfile)
    # Only the raw lines are needed, so close the input before processing.
    with infile.open("rb") as fp:
        lines = fp.readlines()
    outdata = {}
    for sentence_arr in extract_sentences(lines):
        # tsv dropped sentence ids, match the ID using original data
        sid = match_sentence_id(to_sentence(sentence_arr), orig_dict)
        relations = []
        predicates = []  # token indices of predicates, in sentence order
        for token in sentence_arr:
            if token[12] == "Y":  # column 12 == "Y" marks a predicate token
                predicates.append(token[0])
            deprel = get_dep_rel(token)
            if deprel is not None:
                relations.append(deprel)
        # deprel["from"] points to the n-th predicate; replace it with that
        # predicate's token index.
        for deprel in relations:
            deprel["from"] = predicates[deprel["from"]]
        outdata[sid] = relations
        if DEBUG:
            print(to_sentence(sentence_arr))
            print(outdata[sid])
            print(sid)
            print()
            print()
    with outfile.open("w") as fp:
        json.dump(outdata, fp)
    logging.info("SRL relations written to: {}".format(outfile))
def fix_json(srl_gen, error_sentence, orig_json_data):
    """Rebuild the SRL relations of one sentence in already-loaded JSON data.

    The next sentence pulled from *srl_gen* is taken to be the corrected SRL
    output for *error_sentence*. This is NOT verified, so callers must feed
    error sentences in the same order as the corrected SRL file.

    :param srl_gen: generator yielding ``(sentence_text, sentence_arr)`` pairs
        and then ``None`` when it has no more sentences.
    :param error_sentence: sentence ID (key of *orig_json_data*) to rebuild.
    :param orig_json_data: dict mapping sentence IDs to relation lists; the
        entry for *error_sentence* is replaced in place.
    :returns: *orig_json_data*.
    :raises KeyError: if *error_sentence* is not a key of *orig_json_data*.
    :raises ValueError: if *srl_gen* has no sentences left.
    """
    item = next(srl_gen, None)
    if item is None:
        raise ValueError(
            "ran out of SRL sentences while fixing {!r}".format(error_sentence)
        )
    _, sentence_arr = item
    sid = error_sentence
    if sid not in orig_json_data:
        raise KeyError(sid)
    relations = []
    predicates = []  # token indices of predicates, in sentence order
    for token in sentence_arr:
        if token[12] == "Y":  # column 12 == "Y" marks a predicate token
            predicates.append(token[0])
        deprel = get_dep_rel(token)
        if deprel is not None:
            relations.append(deprel)
    # deprel["from"] points to the n-th predicate; replace it with that
    # predicate's token index.
    for deprel in relations:
        deprel["from"] = predicates[deprel["from"]]
    orig_json_data[sid] = relations
    if DEBUG:
        print(to_sentence(sentence_arr))
        print(orig_json_data[sid])
        print(sid)
        print()
        print()
    return orig_json_data
def count_orig_file_sentences(filename):
    """Pickle ``(index, path, sentence count)`` for one original file.

    :param filename: pair ``(index, Path)``; the index is printed as progress
        and stored unchanged.

    The result goes to ``INTERNAL_DATA/orig_chunks/<file name>``. Files that
    already have a cached result are skipped. The count is ``len`` of what
    ``par.parse_tei`` returns. NOTE(review): ``par`` is not defined in the
    visible source.
    """
    idx, orig_path = filename[0], filename[1]
    cache_file = os.path.join(INTERNAL_DATA, 'orig_chunks', orig_path.name)
    if os.path.exists(cache_file):
        return
    print(idx)
    parsed = par.parse_tei(orig_path)
    with open(cache_file, 'wb') as output:
        record = (idx, orig_path, len(parsed))
        pickle.dump(record, output)
def count_srl_file_sentences(filename):
    """Pickle ``(index, path, sentence count)`` for one SRL tsv file.

    :param filename: pair ``(index, Path)``; the index is printed as progress
        and stored unchanged.

    Sentences are counted as lines that contain nothing but a newline. The
    result goes to ``INTERNAL_DATA/srl_chunks/<file name>``. Files that
    already have a cached result are skipped.
    """
    idx, srl_path = filename[0], filename[1]
    cache_file = os.path.join(INTERNAL_DATA, 'srl_chunks', srl_path.name)
    if os.path.exists(cache_file):
        return
    print(idx)
    # Read as UTF-8 explicitly: extract_sentences decodes these files as UTF-8,
    # and the locale's default codec may fail on the same bytes.
    with srl_path.open("r", encoding="utf-8") as fp:
        num_sentences = sum(1 for line in fp if line == "\n")
    with open(cache_file, 'wb') as output:
        pickle.dump((idx, srl_path, num_sentences), output)
def srl_error_fix_generator(infile):
    """Yield ``(sentence_text, sentence_arr)`` for every sentence of *infile*.

    After the last sentence a single ``None`` is yielded as an end-of-data
    sentinel for callers that advance the generator with ``next()``.
    """
    with infile.open("rb") as srl_file:
        raw_lines = srl_file.readlines()
        sentences = extract_sentences(raw_lines)
        for arr in sentences:
            text = to_sentence(arr)
            yield text, arr
        yield None
def srl_sentences_generator(infile, curr_index, sen_start_index):
    """Yield sentences of *infile*, starting at global index *sen_start_index*.

    :param infile: SRL tsv ``Path``.
    :param curr_index: global index of the first sentence in *infile*.
    :param sen_start_index: global index of the first sentence to yield;
        earlier sentences of the file are skipped.
    :yields: ``(sentence_text, sentence_arr)`` pairs, then a final ``None``.
    """
    # Read everything up front so the file is closed even when a caller stops
    # after receiving the ``None`` sentinel and never exhausts the generator.
    with infile.open("rb") as fp:
        lines = fp.readlines()
    for sentence_arr in extract_sentences(lines):
        if curr_index < sen_start_index:
            curr_index += 1
        else:
            yield to_sentence(sentence_arr), sentence_arr
    yield None
def srl_multiple_files_sentences_generator(sentence_id, lookback=10):
    """Yield SRL sentences across consecutive files, starting near *sentence_id*.

    Iteration starts *lookback* sentences before *sentence_id* (clamped at 0),
    so neighbouring sentences are available too.

    Relies on a module-level ``srl_file_sizes`` that is not defined in the
    visible source (NOTE(review): confirm its origin). Judging by the indexing
    below, each entry's ``[1]`` is the file path, ``[2]`` its sentence count
    and ``[3]`` the global index of its first sentence. Entries are presumably
    ordered by that start index.

    :yields: ``(sentence_text, sentence_arr)`` pairs, then a final ``None``.
    :raises ValueError: if no file covers the start index.
    """
    start = max(0, sentence_id - lookback)
    for i, srl_file in enumerate(srl_file_sizes):
        if srl_file[3] <= start < srl_file[3] + srl_file[2]:
            srl_files = srl_file_sizes[i:]
            break
    else:
        # Without this, the loop below would fail with an opaque
        # UnboundLocalError on ``srl_files``.
        raise ValueError("no SRL file contains sentence index {}".format(start))
    for file_info in srl_files:
        srl_gen = srl_sentences_generator(file_info[1], file_info[3], start)
        el = next(srl_gen)
        while el is not None:
            yield el
            el = next(srl_gen)
    yield None
# Group consecutive error sentence IDs by their 9-character prefix (the prefix
# names the JSON file they live in). Input order is preserved so the corrected
# SRL sentences can be consumed sequentially. A new group starts whenever the
# prefix changes; an empty error list yields no groups at all.
error_sentences_grouped = []
for name in error_sentences:
    if error_sentences_grouped and error_sentences_grouped[-1][0][:9] == name[:9]:
        error_sentences_grouped[-1].append(name)
    else:
        error_sentences_grouped.append([name])
# Corrected SRL output for the error sentences, in the same order as
# error_sentences.
srl_gen = srl_error_fix_generator(INPATH)

# Iterate over all wronged sentences and fix them, one JSON file at a time.
for errors_in_file in error_sentences_grouped:
    outfile = Path(OUTPATH, errors_in_file[0][:9] + '-dedup.json')
    with outfile.open() as json_file:
        print(outfile.name)
        orig_json_data = json.load(json_file)
    for error_sentence in errors_in_file:
        orig_json_data = fix_json(srl_gen, error_sentence, orig_json_data)
    # The JSON is rewritten in place. Write to a temporary file and swap it in
    # so an interrupted run cannot leave a truncated file behind.
    tmpfile = outfile.with_name(outfile.name + '.tmp')
    with tmpfile.open('w') as json_file:
        json.dump(orig_json_data, json_file)
    os.replace(tmpfile, outfile)
    logging.info("SRL relations written to: {}".format(outfile))

# Every corrected SRL sentence should have been consumed by now. Leftovers
# mean the error list and the SRL file are out of step.
leftover = next(srl_gen, None)
if leftover is not None:
    logging.warning(
        "Unused SRL sentences remain in %s; error list and SRL output may be misaligned",
        INPATH,
    )