commit 4fd5e75ba496e318e1fccf403ee115c3ac05530d
Author: vivekvardhanadepu
Date:   Tue Aug 10 21:43:23 2021 +0530

    non-parallel training added to lexical_selection_training.py

diff --git a/config.toml.example b/config.toml.example
index c8ad006..d099955 100644
--- a/config.toml.example
+++ b/config.toml.example
@@ -8,19 +8,19 @@ IS_PARALLEL = false
 CORPUS = "europarl-v7"
 
 # source language[it should match with the language codes of apertium]
-SL = "spa"
+SL = "eng"
 
 # target language[it should match with the language codes of apertium]
-TL = "eng"
+TL = "spa"
 
 # language pair code(as per apertium language codes)
 PAIR = "eng-spa"
 
 # source corpus
-CORPUS_SL = "europarl-v7.eng-spa.spa"
+CORPUS_SL = "europarl-v7.eng-spa.eng"
 
 # target corpus
-CORPUS_TL = "europarl-v7.eng-spa.eng"
+CORPUS_TL = "europarl-v7.eng-spa.spa"
 
 # apertium-lex-tools scripts
 # LEX_TOOLS = "../apertium-lex-tools/scripts"
@@ -40,5 +40,5 @@ MAX_RULES = 3
 # crisphold
 CRISPHOLD = 1.5
 
-# TL binary language model[not required for parallel training]
-TL_MODEL = "europarl-v7.eng-spa.eng.lm"
\ No newline at end of file
+# TL binary language model(.lm or .blm)[not required for parallel training]
+TL_MODEL = "europarl-v7.eng-spa.spa.lm"
\ No newline at end of file
diff --git a/lexical_selection_training.py b/lexical_selection_training.py
index 0181932..d558839 100644
--- a/lexical_selection_training.py
+++ b/lexical_selection_training.py
@@ -1,6 +1,7 @@
 # lexical training script
 import os
 import sys
+import gzip
 import shutil
 from subprocess import Popen, PIPE, call
 
@@ -98,7 +99,7 @@ def parallel_training(config, cache_dir, log):
         cache_dir, 'rules_all.txt')
     ngrams_all = os.path.join(
         cache_dir, 'ngrams_all.txt')
-    rules = f"{config['CORPUS']}-{config['SL']}-{config['TL']}.ngrams-lm-{MIN}.xml"
+    rules = f"{config['CORPUS']}.{config['SL']}-{config['TL']}.ngrams-lm-{MIN}.xml"
 
     if os.path.isfile(rules):
         if not query(f"Do you want to overwrite '{rules}'"):
@@ -300,13 +301,22 @@ def parallel_training(config, cache_dir, log):
 
 
 def non_parallel_training(config, cache_dir, log):
-    MIN = 1
+    # MIN = 1
 
     # file names
     sl_tagged = os.path.join(
         cache_dir, f"{config['CORPUS']}.tagged.{config['SL']}")
     lines = os.path.join(cache_dir, f"{config['CORPUS']}.lines")
-    rules = f"{config['CORPUS']}-{config['SL']}-{config['TL']}.ngrams-lm-{MIN}.xml"
+    tl_lm = os.path.join(cache_dir, f"{config['CORPUS']}.{config['TL']}.lm")
+    biltrans = os.path.join(cache_dir, f"{config['CORPUS']}.{config['SL']}-{config['TL']}.biltrans")
+    ambig = os.path.join(cache_dir, f"{config['CORPUS']}.{config['SL']}-{config['TL']}.ambig")
+    multi_trimmed = os.path.join('./', f"{config['CORPUS']}.{config['SL']}-{config['TL']}.multi-trimmed")
+    ranked = os.path.join(cache_dir, f"{config['CORPUS']}.{config['SL']}-{config['TL']}.ranked")
+    annotated = os.path.join(cache_dir, f"{config['CORPUS']}.{config['SL']}-{config['TL']}.annotated")
+    lex_freq = os.path.join(cache_dir, f"{config['CORPUS']}.{config['SL']}-{config['TL']}.freq")
+    ngrams = os.path.join(cache_dir, f"{config['CORPUS']}.{config['SL']}-{config['TL']}.ngrams")
+    patterns = os.path.join(cache_dir, f"{config['CORPUS']}.{config['SL']}-{config['TL']}.patterns")
+    rules = f"{config['CORPUS']}.{config['SL']}-{config['TL']}.ngrams-lm.xml"
 
     if os.path.isfile(rules):
         if not query(f"Do you want to overwrite '{rules}'"):
@@ -357,8 +367,79 @@ def non_parallel_training(config, cache_dir, log):
 
     os.remove(clean_tagged)
 
+    if 'TL_MODEL' in config:
+        tl_lm = config['TL_MODEL']
+    else:
+        call([os.path.join(os.environ['IRSTLM'], 'bin/build-lm.sh'), '-i', config['CORPUS_TL'], '-o',
+              tl_lm+'.gz', '-t', 'tmp'], stdout=log, stderr=log)
+
+        with gzip.open(tl_lm+'.gz', 'rb') as f_in, open(tl_lm, 'wb') as f_out:
+            shutil.copyfileobj(f_in, f_out)
+
+        os.remove(tl_lm+'.gz')
+        # os.remove('tmp_tl')
+
+    sl_tl_autobil = os.path.join(
+        config['LANG_DATA'], f"{config['SL']}-{config['TL']}.autobil.bin")
+    t1x = os.path.join(config['LANG_DATA'], f"apertium-{config['PAIR']}.{config['SL']}-{config['TL']}.t1x")
+    t2x = os.path.join(config['LANG_DATA'], f"apertium-{config['PAIR']}.{config['SL']}-{config['TL']}.t2x")
+    t3x = os.path.join(config['LANG_DATA'], f"apertium-{config['PAIR']}.{config['SL']}-{config['TL']}.t3x")
+    t1x_bin = os.path.join(
+        config['LANG_DATA'], f"{config['SL']}-{config['TL']}.t1x.bin")
+    t2x_bin = os.path.join(
+        config['LANG_DATA'], f"{config['SL']}-{config['TL']}.t2x.bin")
+    t3x_bin = os.path.join(
+        config['LANG_DATA'], f"{config['SL']}-{config['TL']}.t3x.bin")
+    sl_tl_autogen = os.path.join(
+        config['LANG_DATA'], f"{config['SL']}-{config['TL']}.autogen.bin")
+    sl_tl_autopgen = os.path.join(
+        config['LANG_DATA'], f"{config['SL']}-{config['TL']}.autopgen.bin")
+
+    with open(sl_tagged) as f_in:
+        with open(biltrans, 'w') as f_out:
+            call(['multitrans', '-b', '-t', sl_tl_autobil], stdin=f_in, stdout=f_out, stderr=log)
+
+        f_in.seek(0)
+        with open(multi_trimmed, 'w') as f_out:
+            call(['multitrans', '-m', '-t', sl_tl_autobil], stdin=f_in, stdout=f_out, stderr=log)
+
+    with open(ambig, 'w') as f_out:
+        call(['paste', lines, biltrans], stdout=f_out, stderr=log)
+
+    with open(multi_trimmed) as f_in, open(ranked, 'w') as f_out:
+        cmds = [['apertium-transfer', '-b', t1x, t1x_bin], ['apertium-interchunk', t2x, t2x_bin],
+                ['apertium-postchunk', t3x, t3x_bin], ['lt-proc', '-g', sl_tl_autogen], ['lt-proc', '-p', sl_tl_autopgen],
+                ['irstlm-ranker', tl_lm, multi_trimmed, '-f']]
+        pipe(cmds, f_in, f_out, log).wait()
+
+    with open(annotated, 'w') as f_out:
+        call(['paste', multi_trimmed, ranked], stdout=f_out, stderr=log)
+
+    # extract frac freq
+    mod = import_module('biltrans-extract-frac-freq')
+    extract_frac_freq = getattr(mod, 'biltrans_extract_frac_freq')
+    with open(lex_freq, 'w') as f, redirect_stdout(f), redirect_stderr(log):
+        extract_frac_freq(ambig, annotated)
+
+    # ngrams
+    mod = import_module('biltrans-count-patterns-ngrams')
+    count_patterns_ngrams = getattr(mod, 'biltrans_count_patterns_ngrams')
+    with open(ngrams, 'w') as f, redirect_stdout(f), redirect_stderr(log):
+        count_patterns_ngrams(lex_freq, ambig, annotated)
+
+    # patterns
+    mod = import_module('ngram-pruning-frac')
+    ngram_pruning_frac = getattr(mod, 'ngram_pruning_frac')
+    with open(patterns, 'w') as f, redirect_stdout(f), redirect_stderr(log):
+        ngram_pruning_frac(lex_freq, ngrams)
+
+    # extracting rules
+    mod = import_module('ngrams-to-rules')
+    ngrams_to_rules = getattr(mod, 'ngrams_to_rules')
+    with open(rules, 'w') as f, redirect_stdout(f), redirect_stderr(log):
+        ngrams_to_rules(patterns, config['CRISPHOLD'])
 
 
 def main(config_file):
     print("validating configuration....")
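
Note on helpers: the non-parallel branch calls query() and pipe(), which this diff does not
touch; both are defined earlier in lexical_selection_training.py. As a rough sketch only, not
the script's actual implementation, a pipe() helper compatible with the call
pipe(cmds, f_in, f_out, log).wait() above could chain the commands with subprocess.Popen:

    from subprocess import Popen, PIPE

    def pipe(cmds, f_in, f_out, log):
        # Chain cmds into one pipeline: the first command reads f_in, the
        # last writes f_out, and every stage sends stderr to log.
        procs = []
        prev = f_in
        for i, cmd in enumerate(cmds):
            stdout = f_out if i == len(cmds) - 1 else PIPE
            p = Popen(cmd, stdin=prev, stdout=stdout, stderr=log)
            if i > 0:
                prev.close()  # drop our handle so upstream sees SIGPIPE if downstream exits
            prev = p.stdout
            procs.append(p)
        return procs[-1]  # caller .wait()s on the last stage

With that shape, the transfer/ranking block behaves like the shell pipeline
apertium-transfer -b ... | apertium-interchunk ... | apertium-postchunk ... |
lt-proc -g ... | lt-proc -p ... | irstlm-ranker ..., reading multi_trimmed and writing ranked.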
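Note on the extraction steps: the four import_module() calls load hyphenated apertium-lex-tools
scripts (e.g. biltrans-extract-frac-freq.py) as Python modules. This works because the
path-based finder matches modules by file name even when the name is not a valid identifier,
provided the scripts directory (LEX_TOOLS in config.toml.example) is on sys.path. A minimal
standalone illustration of the pattern, with hypothetical paths and file names:

    import sys
    from contextlib import redirect_stdout, redirect_stderr
    from importlib import import_module

    sys.path.insert(0, '../apertium-lex-tools/scripts')  # cf. LEX_TOOLS in config.toml.example

    # 'biltrans-extract-frac-freq' is not a valid identifier, but
    # import_module() still finds biltrans-extract-frac-freq.py on sys.path.
    mod = import_module('biltrans-extract-frac-freq')
    extract_frac_freq = getattr(mod, 'biltrans_extract_frac_freq')

    # Capture everything the script prints into a file, mirroring how the
    # training code redirects each extraction stage.
    with open('corpus.freq', 'w') as f, open('training.log', 'a') as log, \
            redirect_stdout(f), redirect_stderr(log):
        extract_frac_freq('corpus.ambig', 'corpus.annotated')  # hypothetical inputs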