Index: branches/weighted-transfer/apertium-weights-learner/learner.py =================================================================== --- branches/weighted-transfer/apertium-weights-learner/learner.py (revision 70826) +++ branches/weighted-transfer/apertium-weights-learner/learner.py (revision 70827) @@ -74,7 +74,7 @@ pattern_line += ' \n' return pattern_line -def get_weights(rule_group, pattern, sent_line, tixfname, binfname, output_stream): +def get_weights(rule_group, pattern, sent_line, tixfname, binfname): total = 0. weights_list = [] @@ -90,7 +90,7 @@ weights_line += ' \n' + weights_tail with open('tmpweights.w1x', 'w', encoding='utf-8') as wfile: wfile.write(weights_line) - translation = translate(sent_line, tixfname, binfname, 'tmpweights.w1x', output_stream) + translation = translate(sent_line, tixfname, binfname, 'tmpweights.w1x') score = model.score(translation.lower(), bos = True, eos = True) weights_list.append([translation, math.exp(score), focus_rule[1]]) for translation, score, rule_id in weights_list: @@ -113,7 +113,7 @@ wfile.write(' \n') wfile.write(weights_tail) -def translate(sent_line, tixfname, binfname, weightsfname, output_stream): +def translate(sent_line, tixfname, binfname, weightsfname): pipe = pipes.Template() pipe.append('lt-proc -b {}'.format('.'.join((binfname, 'autobil.bin'))), '--') pipe.append('apertium-transfer -bw {} {} {}'.format(weightsfname, '.'.join((tixfname, 't1x')), '.'.join((binfname, 't1x.bin'))), '--') @@ -177,8 +177,7 @@ if '*' not in sent_line: coverage_list, parsed_line = coverage.process_line(sent_line, cat_dict, pattern_FST, - output_stream, - False, True, False) + None, False, True, False) # check if any rule in any coverage is ambiguous, # take first coverage and first rule if any for coverage_item in coverage_list: @@ -188,7 +187,7 @@ #for item in parsed_line: # print(str(item) + '\n') #print(coverage.coverage_to_groups(coverage_item) + '\n\n') - weights_list = get_weights(ambiguous_rules[rule_number], pattern, sent_line, tixfname, binfname, output_stream) + weights_list = get_weights(ambiguous_rules[rule_number], pattern, sent_line, tixfname, binfname) #print(weights_list) for translation, score, rule_id in weights_list: weights_dict[rule_number][rule_id].setdefault(tuple(pattern), 0.)