Index: branches/weighted-transfer/apertium-weights-learner/tools/simpletok.py
===================================================================
--- branches/weighted-transfer/apertium-weights-learner/tools/simpletok.py (revision 72246)
+++ branches/weighted-transfer/apertium-weights-learner/tools/simpletok.py (revision 72247)
@@ -7,7 +7,7 @@
 # or for sending them to language model training
 beforepunc_re = re.compile(r'([¿("/])(\w)')
 afterpunc_re = re.compile(r'(\w)([;:,.!?)"/—])')
-quot_re = re.compile("[«»`'“”„‘’‛]")
+quot_re = re.compile("[«»`'“”„‘’‛]|\"\"")
 numfix_re = re.compile('([0-9]) ([,.:][0-9])')
 beforedash_re = re.compile(r'(\W)-(\w)')
 afterdash_re = re.compile(r'(\w)-(\W)')
Index: branches/weighted-transfer/apertium-weights-learner/twlconfig.py
===================================================================
--- branches/weighted-transfer/apertium-weights-learner/twlconfig.py (revision 72246)
+++ branches/weighted-transfer/apertium-weights-learner/twlconfig.py (revision 72247)
@@ -1,7 +1,17 @@
-# full path to source corpus from which to learn the rules
-#source_corpus = "/home/nm/source/apertium/weighted-transfer/apertium-weights-learner/data/2007-en-100.txt"
-source_corpus = "/home/nm/source/apertium/weighted-transfer/apertium-weights-learner/data/new-software-sample.txt"
+# learning mode, must be one of the following:
+# "mono": learn weights from monolingual corpus using pretrained language model
+# "parallel": learn weights from parallel corpus (no language model required)
+mode = "mono"
+#mode = "parallel"
+
+# full path to source language corpus from which to learn the rules
+#source_language_corpus = "/home/nm/source/apertium/weighted-transfer/apertium-weights-learner/data/2007-en-100.txt"
+source_language_corpus = "/home/nm/source/apertium/weighted-transfer/apertium-weights-learner/data/new-software-sample.txt"
+#source_language_corpus = "/home/nm/source/apertium/weighted-transfer/apertium-weights-learner/data/nc-v7.es-en.en.100.txt"
+
+# full path to target language corpus (only for parallel mode)
+#target_language_corpus = "/home/nm/source/apertium/weighted-transfer/apertium-weights-learner/data/nc-v7.es-en.es.100.txt"
 
 # name of apertium language pair (not translation direction)
 apertium_pair_name = "en-es"
 
@@ -12,7 +22,7 @@
 source = "en"
 target = "es"
 
-# full path to kenlm language model
+# full path to kenlm language model (only for mono mode)
 # may be either arpa (text format) or mmap (binary)
 # but mmap is strongly preferred as it loads and scores faster
 language_model = "/media/nm/storage/es-news-tokenized.mmap"
@@ -19,3 +29,6 @@
 
 # full path to a folder where to store intermediate data and results
 data_folder = "/home/nm/source/apertium/weighted-transfer/apertium-weights-learner/data/"
+
+# optional common prefix for all intermediate and resulting files
+#fname_common_prefix = "en-es-newscommentary"
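Note on the new config shape: twlearner.py reads this module by plain attribute access, so an optional setting such as fname_common_prefix can simply stay commented out. A minimal sketch of that lookup pattern (the helper name resolve_prefix is illustrative, not from the commit):

    import os
    import twlconfig

    def resolve_prefix(data_folder, fallback_basename):
        # A missing optional setting surfaces as AttributeError,
        # which make_prefix below catches to fall back to corpus names.
        try:
            return os.path.join(data_folder, twlconfig.fname_common_prefix)
        except AttributeError:
            return os.path.join(data_folder, fallback_basename)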
Index: branches/weighted-transfer/apertium-weights-learner/twlearner.py
===================================================================
--- branches/weighted-transfer/apertium-weights-learner/twlearner.py (revision 72246)
+++ branches/weighted-transfer/apertium-weights-learner/twlearner.py (revision 72247)
@@ -24,7 +24,8 @@
     if __name__ == "__main__":
         print("lxml library not found. Falling back to xml.etree,\n"
               "though it's highly recommended that you install lxml\n"
-              "as it works dramatically faster than xml.etree.")
+              "as it works dramatically faster than xml.etree.\n"
+              "Also, it supports pretty print.")
     using_lxml = False
 
 tmpweights_fname = 'tmpweights.w1x'
@@ -53,11 +54,23 @@
     return tixbasepath, binbasepath, cat_dict, pattern_FST, ambiguous_rules, rule_id_map, rule_xmls
 
-def make_prefix(corpus, data_folder):
+def make_prefix(data_folder, source_corpus, target_corpus=None):
     """
     Make common prefix for all intermediate files.
     """
-    trimmed_basename = os.path.splitext(os.path.basename(corpus))[0]
-    prefix = os.path.join(data_folder, trimmed_basename)
+    try:
+        fname_common_prefix = twlconfig.fname_common_prefix
+        prefix = os.path.join(data_folder, fname_common_prefix)
+    except AttributeError:
+        source_basename = os.path.basename(source_corpus)
+        trimmed_source_basename = os.path.splitext(source_basename)[0]
+        if target_corpus is None:
+            trimmed_basename = trimmed_source_basename
+        else:
+            target_basename = os.path.basename(target_corpus)
+            trimmed_target_basename = os.path.splitext(target_basename)[0]
+            trimmed_basename = '{}-{}'.format(trimmed_source_basename,
+                                              trimmed_target_basename)
+        prefix = os.path.join(data_folder, trimmed_basename)
     return prefix
 
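For illustration, the prefix derivation under the two calling conventions, with fname_common_prefix unset (paths are hypothetical):

    # mono mode:
    make_prefix('/tmp/data/', '/corpora/news.en.txt')
    # -> '/tmp/data/news.en'
    # parallel mode: source and target basenames are joined with '-'
    make_prefix('/tmp/data/', '/corpora/nc.en.txt', '/corpora/nc.es.txt')
    # -> '/tmp/data/nc.en-nc.es'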
+ """ sentence_segments, prev = [], 0 for i, rule_group_number, pattern in pattern_list: ambig_chunks_count += 1 - list_with_chunk = sum([chunk[0] - for chunk in coverage_item[prev:i+1]], []) + list_with_chunk = sum([chunk[0] for chunk in coverage_item[prev:i+1]], []) piece_of_line = '^' + '$ ^'.join(list_with_chunk) + '$' sentence_segments.append([rule_group_number, pattern, piece_of_line]) prev = i+1 @@ -154,14 +190,10 @@ if sentence_segments != []: if prev <= len(coverage_item): # add up the tail of the sentence - list_with_chunk = sum([chunk[0] - for chunk in coverage_item[prev:]], []) + list_with_chunk = sum([chunk[0] for chunk in coverage_item[prev:]], []) piece_of_line = ' ^' + '$ ^'.join(list_with_chunk) + '$' sentence_segments[-1][2] += piece_of_line - #print(sentence_segments) - #print() - # first, translate each segment with default rules for sentence_segment in sentence_segments: sentence_segment.append(translator.translate(sentence_segment[2])) @@ -169,37 +201,31 @@ # second, translate each segment with each of the rules, # and make full sentence, where other segments are translated with default rules for j, sentence_segment in enumerate(sentence_segments): - translation_list = translate_ambiguous(weighted_translator, ambiguous_rules[sentence_segment[0]], - sentence_segment[1], sentence_segment[2], rule_id_map) + translation_list = translate_ambiguous_segment(weighted_translator, + ambiguous_rules[sentence_segment[0]], + sentence_segment[1], + sentence_segment[2], rule_id_map) output_list = [] for rule, translation in translation_list: - translated_sentence = ' '.join(sentence_segment[3] for sentence_segment in sentence_segments[:j]) +\ + translated_sentence = ' '.join(sentence_segment[3] + for sentence_segment + in sentence_segments[:j]) +\ ' ' + translation + ' ' +\ - ' '.join(sentence_segment[3] for sentence_segment in sentence_segments[j+1:]) + ' '.join(sentence_segment[3] + for sentence_segment + in sentence_segments[j+1:]) output_list.append('{}\t{}'.format(rule, translated_sentence.strip(' '))) - # store results to a file + # store results to file # first, print rule group number, pattern, and number of rules in the group print('{}\t^{}$\t{}'.format(sentence_segment[0], '$ ^'.join(sentence_segment[1]), len(output_list)), file=ofile) # then, output all the translations in the following way: rule number, then translated sentence print('\n'.join(output_list), file=ofile) - lines_count += 1 - if lines_count % 100 == 0: - print('\n{} total lines\n{} total sentences'.format(lines_count, total_sents_count)) - print('{} ambiguous sentences\n{} ambiguous chunks'.format(ambig_sents_count, ambig_chunks_count)) - print('{} botched coverages\nanother {:.4f} elapsed'.format(botched_coverages, clock() - lbtime)) - gc.collect() - lbtime = clock() + return ambig_chunks_count - # clean up temporary weights file - if os.path.exists(tmpweights_fname): - os.remove(tmpweights_fname) - - print('Done in {:.2f}'.format(clock() - btime)) - return ofname - -def translate_ambiguous(weighted_translator, rule_group, pattern, sent_line, rule_id_map): +def translate_ambiguous_segment(weighted_translator, rule_group, + pattern, sent_line, rule_id_map): """ Translate sent_line for each rule in rule_group. 
""" @@ -215,10 +241,8 @@ if rule == focus_rule: et_pattern = make_et_pattern(et_rule, pattern) - if using_lxml: - etree.ElementTree(oroot).write(tmpweights_fname, pretty_print=True, encoding='utf-8', xml_declaration=True) - else: - etree.ElementTree(oroot).write(tmpweights_fname, encoding='utf-8', xml_declaration=True) + etree.ElementTree(oroot).write(tmpweights_fname, + encoding='utf-8', xml_declaration=True) # translate using created weights file translation = weighted_translator.translate(sent_line, tmpweights_fname) @@ -361,24 +385,182 @@ et_newpattern = make_et_pattern(et_newrule, prev_pattern, total_pattern_weight) if using_lxml: - etree.ElementTree(oroot).write(ofname, pretty_print=True, encoding='utf-8', xml_declaration=True) + # lxml supports pretty print + etree.ElementTree(oroot).write(ofname, pretty_print=True, + encoding='utf-8', xml_declaration=True) else: - etree.ElementTree(oroot).write(ofname, encoding='utf-8', xml_declaration=True) + etree.ElementTree(oroot).write(ofname, + encoding='utf-8', xml_declaration=True) print('Done in {:.2f}'.format(clock() - btime)) return ofname -if __name__ == "__main__": - if not os.path.exists(twlconfig.data_folder): - os.makedirs(twlconfig.data_folder) - prefix = make_prefix(twlconfig.source_corpus, twlconfig.data_folder) +def detect_ambiguous_parallel(source_corpus, target_corpus, prefix, + cat_dict, pattern_FST, ambiguous_rules, + tixfname, binfname, rule_id_map): + """ + Find ambiguous chunks. + Translate them in all possible ways. + Score them, and store the results. + """ + print('Looking for ambiguous chunks, translating and scoring them.') + btime = clock() - tbtime = clock() + # make output file name + ofname = prefix + '-chunk-weights.txt' + # initialize translators + # for translation with no weights + translator = partialTranslator(tixfname, binfname) + # for weighted translation + weighted_translator = weightedPartialTranslator(tixfname, binfname) + + # initialize statistics + lines_count, ambig_chunks_count = 0, 0 + botched_coverages = 0 + lbtime = clock() + + with open(source_corpus, 'r', encoding='utf-8') as sfile, \ + open(target_corpus, 'r', encoding='utf-8') as tfile, \ + open(ofname, 'w', encoding='utf-8') as ofile: + + for sl_line, tl_line in zip(sfile, tfile): + + # get coverages + coverage_list = pattern_FST.get_lrlm(sl_line.strip(), cat_dict) + if coverage_list == []: + botched_coverages += 1 + else: + # look for ambiguous chunks + coverage_item = coverage_list[0] + pattern_list = search_ambiguous(ambiguous_rules, coverage_item) + + # translate each chunk with each of the relevant rules + for i, rule_group_number, pattern in pattern_list: + ambig_chunks_count += 1 + pattern_chunk = '^' + '$ ^'.join(pattern) + '$' + translation_list = translate_ambiguous_segment(weighted_translator, + ambiguous_rules[rule_group_number], + pattern_chunk, pattern_chunk, + rule_id_map) + tl_line = normalize(tl_line) + for rule, translation in translation_list: + translation = normalize(translation) + if (translation in tl_line): + #print('{} IN {}'.format(translation, tl_line)) + print(rule_group_number, rule, pattern_chunk, '1.0', + sep='\t', file=ofile) + else: + #print('{} NOT IN {}'.format(translation, tl_line)) + pass + + lines_count += 1 + if lines_count % 100 == 0: + print('\n{} total lines\n{} ambiguous chunks'.format(lines_count, ambig_chunks_count)) + print('{} botched coverages\nanother {:.4f} elapsed'.format(botched_coverages, clock() - lbtime)) + gc.collect() + lbtime = clock() + + # clean up temporary weights file + if 
+def make_et_rule_group(et_rulegroup, pattern_rule_weights, rule_map, rule_xmls):
+    """
+    Add a rule-group element to xml tree with normalized pattern weights.
+    """
+    rule_pattern_weights = {}
+    for pattern, rule_weights in pattern_rule_weights.items():
+        total = sum(weight for rule_number, weight in rule_weights.items())
+        normalized_rule_weights = ((rule_number, weight / total)
+                                   for rule_number, weight
+                                   in rule_weights.items())
+        for rule_number, weight in normalized_rule_weights:
+            rule_pattern_weights.setdefault(rule_number, [])
+            rule_pattern_weights[rule_number].append((pattern, weight))
+
+    for rule_number, pattern_weights in sorted(rule_pattern_weights.items(), key=lambda x: int(x[0])):
+        et_newrule = make_et_rule(rule_number, et_rulegroup, rule_map, rule_xmls)
+        for pattern, weight in pattern_weights:
+            et_newpattern = make_et_pattern(et_newrule, pattern, weight)
+
+def make_xml_transfer_weights_parallel(scores_fname, prefix, rule_map, rule_xmls):
+    """
+    Sum up the weights for each rule-pattern pair,
+    and add the result to the xml weights file.
+    """
+    print('Summing up the weights and making xml rules.')
+    btime = clock()
+
+    # make output file names
+    sorted_scores_fname = prefix + '-chunk-weights-sorted.txt'
+    ofname = prefix + '-rule-weights.w1x'
+
+    # create pipeline
+    pipe = pipes.Template()
+    pipe.append('sort $IN > $OUT', 'ff')
+    pipe.copy(scores_fname, sorted_scores_fname)
+
+    # create empty output xml tree
+    oroot = etree.Element('transfer-weights')
+    et_newrulegroup = etree.SubElement(oroot, 'rule-group')
+    pattern_rule_weights = {}
+
+    with open(sorted_scores_fname, 'r', encoding='utf-8') as ifile:
+        # read and process the first line
+        prev_group_number, rule_number, pattern, weight = ifile.readline().rstrip('\n').split('\t')
+        pattern_rule_weights[pattern] = {}
+        pattern_rule_weights[pattern][rule_number] = float(weight)
+
+        # read and process other lines
+        for line in ifile:
+            group_number, rule_number, pattern, weight = line.rstrip('\n').split('\t')
+
+            if group_number != prev_group_number:
+                # rule group changed: flush the previous one
+                make_et_rule_group(et_newrulegroup, pattern_rule_weights,
+                                   rule_map, rule_xmls)
+                et_newrulegroup = etree.SubElement(oroot, 'rule-group')
+                pattern_rule_weights = {}
+
+            pattern_rule_weights.setdefault(pattern, {})
+            pattern_rule_weights[pattern].setdefault(rule_number, 0.)
+            pattern_rule_weights[pattern][rule_number] += float(weight)
+
+            prev_group_number = group_number
+
+        # flush the last rule group
+        make_et_rule_group(et_newrulegroup, pattern_rule_weights,
+                           rule_map, rule_xmls)
+
+    if using_lxml:
+        # lxml supports pretty print
+        etree.ElementTree(oroot).write(ofname, pretty_print=True,
+                                       encoding='utf-8', xml_declaration=True)
+    else:
+        etree.ElementTree(oroot).write(ofname,
+                                       encoding='utf-8', xml_declaration=True)
+
+    print('Done in {:.2f}'.format(clock() - btime))
+    return ofname
+
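In make_et_rule_group above, the accumulated counts for each pattern are divided by their sum, so the weights of the competing rules for one pattern add up to 1. A worked toy case (counts invented):

    # rules 15 and 16 matched this pattern 3 and 1 times respectively
    rule_weights = {'15': 3.0, '16': 1.0}
    total = sum(rule_weights.values())
    normalized = {rule: w / total for rule, w in rule_weights.items()}
    # normalized == {'15': 0.75, '16': 0.25}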
+ """ + print('Learning rule weights from monolingual corpus with pretrained language model.') + + prefix = make_prefix(twlconfig.data_folder, + twlconfig.source_language_corpus) + # tag corpus tagged_fname = tag_corpus(twlconfig.apertium_pair_data, twlconfig.source, twlconfig.target, - twlconfig.source_corpus, prefix, + twlconfig.source_language_corpus, prefix, twlconfig.data_folder) # load rules, build rule FST @@ -403,7 +585,7 @@ # estimate rule weights for each ambiguous chunk scores_fname = score_sentences(ambig_sentences_fname, model, prefix) - # sum up weigths for rule-pattern and make unprunned xml + # sum up weights for rule-pattern and make unprunned xml weights_fname = make_xml_transfer_weights(scores_fname, prefix, rule_id_map, rule_xmls) @@ -410,4 +592,121 @@ # prune weights file prunned_fname = prune_xml_transfer_weights(using_lxml, weights_fname) +def learn_from_parallel(): + """ + Learn rule weights from parallel corpus (no language model required). + """ + print('Learning rule weights from parallel corpus.') + + prefix = make_prefix(twlconfig.data_folder, + twlconfig.source_language_corpus, + twlconfig.target_language_corpus) + + # tag corpus + tagged_fname = tag_corpus(twlconfig.apertium_pair_data, + twlconfig.source, twlconfig.target, + twlconfig.source_language_corpus, prefix, + twlconfig.data_folder) + + # load rules, build rule FST + tixbasepath, binbasepath, cat_dict, pattern_FST, \ + ambiguous_rules, rule_id_map, rule_xmls = \ + load_rules(twlconfig.apertium_pair_data, + twlconfig.source, twlconfig.target) + + # detect, score and store chunks with ambiguity + scores_fname = detect_ambiguous_parallel(tagged_fname, + twlconfig.target_language_corpus, + prefix, + cat_dict, pattern_FST, + ambiguous_rules, + tixbasepath, binbasepath, + rule_id_map) + + # sum up and normalize weights for rule-pattern and make unprunned xml + weights_fname = make_xml_transfer_weights_parallel(scores_fname, prefix, + rule_id_map, rule_xmls) + + # prune xml weights file + prunned_fname = prune_xml_transfer_weights(using_lxml, weights_fname) + +def validate_config(): + """ + Try reading variables from config file and perform basic sanity checks. + """ + try: + mode = twlconfig.mode + except AttributeError: + print('Undefined mode. 
Please specify either "mono" or "parallel".') + sys.exit(1) + + try: + apertium_pair_name = twlconfig.apertium_pair_name + except AttributeError: + print('Undefined apertium_pair_name.') + sys.exit(1) + + try: + apertium_pair_data = twlconfig.apertium_pair_data + if not os.path.exists(twlconfig.apertium_pair_data): + print('Apertium language pair data folder "{}" not found'.format(twlconfig.apertium_pair_data)) + sys.exit(1) + except AttributeError: + print('Undefined apertium_pair_data.') + sys.exit(1) + + try: + source = twlconfig.source + target = twlconfig.target + except AttributeError: + print('Undefined direction (source and/or target).') + sys.exit(1) + + try: + source_language_corpus = twlconfig.source_language_corpus + if not os.path.exists(twlconfig.source_language_corpus): + print('Source language corpus "{}" not found'.format(twlconfig.source_language_corpus)) + sys.exit(1) + except AttributeError: + print('Undefined source_language_corpus.') + sys.exit(1) + + if mode == "mono": + try: + language_model = twlconfig.language_model + if not os.path.exists(twlconfig.language_model): + print('Language model "{}" not found'.format(twlconfig.language_model)) + sys.exit(1) + except AttributeError: + print('Undefined language model.') + sys.exit(1) + elif mode == "parallel": + try: + tl_corpus = twlconfig.target_language_corpus + if not os.path.exists(twlconfig.target_language_corpus): + print('Target language corpus "{}" not found'.format(twlconfig.target_language_corpus)) + sys.exit(1) + except AttributeError: + print('Undefined target_language_corpus.') + sys.exit(1) + else: + print('Invalid mode {}. Please specify either "mono" or "parallel".'.format(mode)) + + try: + data_folder = twlconfig.data_folder + if not os.path.exists(twlconfig.data_folder): + os.makedirs(twlconfig.data_folder) + except AttributeError: + print('Undefined data_folder.') + sys.exit(1) + +if __name__ == "__main__": + validate_config() + tbtime = clock() + + if twlconfig.mode == "mono": + learn_from_monolingual() + elif twlconfig.mode == "parallel": + learn_from_parallel() + print('Performed in {:.2f}'.format(clock() - tbtime))