commit 3b3bef2607666ed130b03be743e6e62757497e2e Author: Daniel Swanson Date: Wed Aug 4 17:36:16 2021 -0500 finish loading HMM & LSW and write most of executor for HMM diff --git a/apertium/Makefile.am b/apertium/Makefile.am index 8843844..d8225d4 100644 --- a/apertium/Makefile.am +++ b/apertium/Makefile.am @@ -44,6 +44,7 @@ h_sources = a.h \ tagger_data_hmm.h \ tagger_data_lsw.h \ tagger_data_percep_coarse_tags.h \ + tagger_exe.h \ tagger_flags.h \ tagger_utils.h \ tagger_word.h \ @@ -114,6 +115,7 @@ cc_sources = a.cc \ tagger_data_hmm.cc \ tagger_data_lsw.cc \ tagger_data_percep_coarse_tags.cc \ + tagger_exe.cc \ tagger_flags.cc \ tagger_utils.cc \ tagger_word.cc \ diff --git a/apertium/file_morpho_stream.cc b/apertium/file_morpho_stream.cc index 4f1ed7f..ebbe341 100644 --- a/apertium/file_morpho_stream.cc +++ b/apertium/file_morpho_stream.cc @@ -128,12 +128,12 @@ FileMorphoStream::lrlmClassify(UString const &str, int &ivwords) { if(str[j] == '\\') { - j++; + j++; } else if(str[j] == '>') { - tag = str.substr(i, j-i+1); - i = j; + tag = str.substr(i, j-i+1); + i = j; break; } } @@ -156,59 +156,57 @@ FileMorphoStream::lrlmClassify(UString const &str, int &ivwords) vwords[ivwords]->add_tag(last_type, str.substr(floor, last_pos - floor + 1), td->getPreferRules()); - if(str[last_pos+1] == '+' && last_pos+1 < limit ) - { - floor = last_pos + 1; - last_pos = floor + 1; + if(str[last_pos+1] == '+' && last_pos+1 < limit ) { + floor = last_pos + 1; + last_pos = floor + 1; vwords[ivwords]->set_plus_cut(true); if (((int)vwords.size())<=((int)(ivwords+1))) vwords.push_back(new TaggerWord(true)); ivwords++; - ms.init(me->getInitial()); - } - i = floor++; + ms.init(me->getInitial()); + } + i = floor++; } else { if (debug) { - cerr<<"Warning: There is not coarse tag for the fine tag '"<< str.substr(floor) <<"'\n"; + cerr<<"Warning: There is not coarse tag for the fine tag '"<< str.substr(floor) <<"'\n"; cerr<<" This is because of an incomplete tagset definition or a dictionary 
error\n"; - } + } vwords[ivwords]->add_tag(ca_tag_kundef, str.substr(floor) , td->getPreferRules()); - return; + return; } } else if(i == limit - 1) { if(ms.classifyFinals(me->getFinals()) == -1) { - if(last_pos != floor) - { - vwords[ivwords]->add_tag(last_type, + if(last_pos != floor) { + vwords[ivwords]->add_tag(last_type, str.substr(floor, last_pos - floor + 1), td->getPreferRules()); if(str[last_pos+1] == '+' && last_pos+1 < limit ) { floor = last_pos + 1; - last_pos = floor; + last_pos = floor; vwords[ivwords]->set_plus_cut(true); if (((int)vwords.size())<=((int)(ivwords+1))) vwords.push_back(new TaggerWord(true)); ivwords++; ms.init(me->getInitial()); - } - i = floor++; + } + i = floor++; } else { if (debug) { - cerr<<"Warning: There is not coarse tag for the fine tag '"<< str.substr(floor) <<"'\n"; + cerr<<"Warning: There is not coarse tag for the fine tag '"<< str.substr(floor) <<"'\n"; cerr<<" This is because of an incomplete tagset definition or a dictionary error\n"; - } + } vwords[ivwords]->add_tag(ca_tag_kundef, str.substr(floor) , td->getPreferRules()); - return; + return; } } } diff --git a/apertium/hmm.cc b/apertium/hmm.cc index e5b7960..c18e995 100644 --- a/apertium/hmm.cc +++ b/apertium/hmm.cc @@ -719,14 +719,14 @@ HMM::tagger(MorphoStream &morpho_stream, UFILE* Output) { for (itag=tags.begin(); itag!=tags.end(); itag++) { //For all tag from the current word i=*itag; for (jtag=pretags.begin(); jtag!=pretags.end(); jtag++) { //For all tags from the previous word - j=*jtag; - x = alpha[1-nwpend%2][j]*(tdhmm.getA())[j][i]*(tdhmm.getB())[i][k]; - if (alpha[nwpend%2][i]<=x) { - if (nwpend>1) - best[nwpend%2][i] = best[1-nwpend%2][j]; - best[nwpend%2][i].push_back(i); - alpha[nwpend%2][i] = x; - } + j=*jtag; + x = alpha[1-nwpend%2][j]*(tdhmm.getA())[j][i]*(tdhmm.getB())[i][k]; + if (alpha[nwpend%2][i]<=x) { + if (nwpend>1) + best[nwpend%2][i] = best[1-nwpend%2][j]; + best[nwpend%2][i].push_back(i); + alpha[nwpend%2][i] = x; + } } } @@ -736,21 +736,21 
@@ HMM::tagger(MorphoStream &morpho_stream, UFILE* Output) { prob = alpha[nwpend%2][tag]; if (prob>0) - loli -= log(prob); + loli -= log(prob); else { if (TheFlags.getDebug()) - cerr<<"Problem with word '"<<word->get_superficial_form()<<"' "<<word->get_string_tags()<<"\n"; + cerr<<"Problem with word '"<<word->get_superficial_form()<<"' "<<word->get_string_tags()<<"\n"; } for (unsigned t=0; t +#include #include +#include #include #include #include @@ -210,12 +212,11 @@ void TaggerDataExe::read_compressed_hmm_lsw(FILE* in, bool is_hmm) { // open_class - open_class_count = Compression::multibyte_read(in); - open_class = new int32_t[open_class_count]; - int32_t val = 0; - for (uint64_t i = 0; i < open_class_count; i++) { + std::vector<uint64_t> open_class; + uint64_t val = 0; + for (uint64_t i = Compression::multibyte_read(in); i > 0; i--) { val += Compression::multibyte_read(in); - open_class[i] = val; + open_class.push_back(val); } // forbid_rules @@ -242,16 +243,74 @@ TaggerDataExe::read_compressed_hmm_lsw(FILE* in, bool is_hmm) } // enforce_rules - // TODO + enforce_rules_count = Compression::multibyte_read(in); + enforce_rules_offsets = new uint64_t[enforce_rules_count+1]; + std::vector<uint64_t> enf; + for (uint64_t i = 0; i < enforce_rules_count; i++) { + enforce_rules_offsets[i] = enf.size(); + enf.push_back(Compression::multibyte_read(in)); + for (uint64_t j = Compression::multibyte_read(in); j > 0; j--) { + enf.push_back(Compression::multibyte_read(in)); + } + } + enforce_rules_offsets[enforce_rules_count] = enf.size(); + enforce_rules = new uint64_t[enf.size()]; + for (uint64_t i = 0; i < enf.size(); i++) { + enforce_rules[i] = enf[i]; + } // prefer_rules - // TODO + prefer_rules_count = Compression::multibyte_read(in); + prefer_rules = new StringRef[prefer_rules_count]; + for (uint64_t i = 0; i < prefer_rules_count; i++) { + prefer_rules[i] = str_write.add(Compression::string_read(in)); + } // constants - // TODO + constants_count = Compression::multibyte_read(in); + constants = new 
str_int[constants_count]; + for (uint64_t i = 0; i < constants_count; i++) { + constants[i].s = str_write.add(Compression::string_read(in)); + constants[i].i = Compression::multibyte_read(in); + } // output - // TODO + output_count = Compression::multibyte_read(in); + // +2 in case we need to append open_class + output_offsets = new uint64_t[output_count+2]; + std::vector out; + for (uint64_t i = 0; i < output_count; i++) { + output_offsets[i] = out.size(); + for (uint64_t j = Compression::multibyte_read(in); j > 0; j--) { + out.push_back(Compression::multibyte_read(in)); + } + } + output_offsets[output_count] = out.size(); + open_class_index = output_count; + for (uint64_t i = 0; i < output_count; i++) { + if (output_offsets[i+1] - output_offsets[i] == open_class.size()) { + bool match = true; + for (uint64_t j = 0; j < open_class.size(); j++) { + if (open_class[j] != output[output_offsets[i]+j]) { + match = false; + break; + } + } + if (match) { + open_class_index = i; + break; + } + } + } + if (open_class_index == output_count) { + output_count++; + out.insert(out.end(), open_class.begin(), open_class.end()); + output_offsets[output_count] = out.size(); + } + output = new uint64_t[out.size()]; + for (uint64_t i = 0; i < out.size(); i++) { + output[i] = out[i]; + } if (is_hmm) { // dimensions @@ -284,13 +343,28 @@ TaggerDataExe::read_compressed_hmm_lsw(FILE* in, bool is_hmm) for (uint64_t count = Compression::multibyte_read(in); count > 0; count--) { uint64_t i = Compression::multibyte_read(in); uint64_t j = Compression::multibyte_read(in); - uint64_t k = Compressoin::multibyte_read(in); - d[(i*N*N)+(j*N)+k] = read_compressed_double(in); + uint64_t k = Compression::multibyte_read(in); + lsw_d[(i*N*N)+(j*N)+k] = read_compressed_double(in); } } // pattern list - // TODO + Alphabet temp; + fpos_t pos; + fgetpos(in, &pos); + alpha.read(in, false); + fsetpos(in, &pos); + temp.read(in); + int len = Compression::multibyte_read(in); + if (len == 1) { + 
Compression::string_read(in); + trans.read_compressed(in, temp); + finals_count = Compression::multibyte_read(in); + for (uint64_t i = 0; i < finals_count; i++) { + finals[i].i1 = Compression::multibyte_read(in); + finals[i].i2 = Compression::multibyte_read(in); + } + } // discard discard_count = Compression::multibyte_read(in); @@ -302,3 +376,22 @@ TaggerDataExe::read_compressed_hmm_lsw(FILE* in, bool is_hmm) discard[i] = str_write.add(Compression::string_read(in)); } } + +uint64_t +TaggerDataExe::get_ambiguity_class(const std::set& tags) +{ + uint64_t ret = open_class_index; + uint64_t len = output_offsets[ret+1] - output_offsets[ret]; + for (uint64_t i = 0; i < output_count; i++) { + uint64_t loc_len = output_offsets[i+1] - output_offsets[i]; + if (loc_len < tags.size() || loc_len >= len) { + continue; + } + if (std::includes(output+output_offsets[i], output+output_offsets[i+1], + tags.begin(), tags.end())) { + ret = i; + len = loc_len; + } + } + return ret; +} diff --git a/apertium/tagger_data_exe.h b/apertium/tagger_data_exe.h index 472c3fa..7550d43 100644 --- a/apertium/tagger_data_exe.h +++ b/apertium/tagger_data_exe.h @@ -19,7 +19,10 @@ #define _TAGGER_DATA_EXE_ #include +#include #include +#include +#include struct str_int { StringRef s; @@ -33,15 +36,15 @@ struct str_str_int { }; struct int_int { - int32_t i1; - int32_t i2; + uint64_t i1; + uint64_t i2; }; class TaggerDataExe { private: bool mmapping = false; -protected: +public: StringWriter str_write; /** @@ -74,9 +77,6 @@ protected: /** * HMM & LSW shared data */ - int32_t* open_class = nullptr; - uint64_t open_class_count = 0; - int_int* forbid_rules = nullptr; uint64_t forbid_rules_count = 0; @@ -86,16 +86,29 @@ protected: str_int* tag_index = nullptr; uint64_t tag_index_count = 0; - // enforce_rules + // [ tagi, taj1, tagj2, tagi, tagj1, tagj2, tagj3 ] + // [ 0, 3, 7 ] + uint64_t* enforce_rules = nullptr; + uint64_t* enforce_rules_offsets = nullptr; + uint64_t enforce_rules_count = 0; - // 
prefer_rules + StringRef* prefer_rules = nullptr; + uint64_t prefer_rules_count = 0; str_int* constants = nullptr; uint64_t constants_count = 0; - // output + // output = ambiguity classes + uint64_t* output = nullptr; + uint64_t* output_offsets = nullptr; + uint64_t output_count = 0; + uint64_t open_class_index = 0; // which ambiguity class is open_class - // pattern_list + // patterns + AlphabetExe alpha; + TransducerExe trans; + int_int* finals = nullptr; + uint64_t finals_count = 0; StringRef* discard = nullptr; uint64_t discard_count = 0; @@ -103,21 +116,12 @@ protected: /** * HMM and LSW weight matrices */ - uint64_t M = 0; - uint64_t N = 0; + uint64_t M = 0; // HMM number of ambiguity classes + uint64_t N = 0; // HMM number of known tags double* hmm_a = nullptr; // NxN double* hmm_b = nullptr; // NxM double* lsw_d = nullptr; // NxNxN - /* HMM -vector forbid_rules -vector enforce_rules -vector prefer_rules -Collection output -PatternList plist -vector discard - */ - /* perceptron map, double> weights int beam_width @@ -126,8 +130,6 @@ map, double> weights vector features vector global_defns FeatureDefn global_pred - Collection output - PatternList plist */ public: @@ -136,6 +138,18 @@ public: void read_compressed_unigram3(FILE* in); void read_compressed_hmm_lsw(FILE* in, bool is_hmm=true); void read_compressed_perceptron(FILE* in); + + uint64_t get_ambiguity_class(const std::set& tags); + + inline double getA(uint64_t i, uint64_t j) const { + return hmm_a[i*M + j]; + } + inline double getB(uint64_t i, uint64_t j) const { + return hmm_b[i*N + j]; + } + inline double getD(uint64_t i, uint64_t j, uint64_t k) const { + return lsw_d[i*N*N + j*N + k]; + } }; #endif diff --git a/apertium/tagger_exe.cc b/apertium/tagger_exe.cc new file mode 100644 index 0000000..e27ab3c --- /dev/null +++ b/apertium/tagger_exe.cc @@ -0,0 +1,233 @@ +#include + +#include +#include + +using namespace std; +using namespace Apertium; + +void +TaggerExe::build_match_finals() +{ + for 
(uint64_t i = 0; i < tde.finals_count; i++) { + match_finals[tde.finals[i].i1] = tde.finals[i].i2; + } +} + +StreamedType +TaggerExe::read_streamed_type(InputFile& input) +{ + StreamedType ret; + ret.TheString = input.readBlank(true); + if (!input.eof() && input.peek() == '^') { + input.get(); + ret.TheLexicalUnit = LexicalUnit(); + UChar32 c = input.get(); + while (c != '/' && c != '$') { + ret.TheLexicalUnit->TheSurfaceForm += c; + if (c == '\\') { + ret.TheLexicalUnit->TheSurfaceForm += input.get(); + } + c = input.get(); + } + // maybe error here if surface form is empty or we hit $ + c = input.get(); + if (c == '*') { + input.readBlock(c, '$'); + } else { + input.unget(c); + do { + ret.TheLexicalUnit->TheAnalyses.push_back(Analysis()); + ret.TheLexicalUnit->TheAnalyses.back().read(input); + c = input.get(); + } while (c == '/'); + // error if c != $ + } + } + return ret; +} + +TaggerWord* +TaggerExe::read_tagger_word(InputFile& input) +{ + if (!tagger_word_buffer.empty()) { + TaggerWord* ret = tagger_word_buffer[0]; + tagger_word_buffer.erase(tagger_word_buffer.begin()); + if (ret->isAmbiguous()) { + for (uint64_t i = 0; i < tde.discard_count; i++) { + // TODO: have TaggerWord accept UString_view + UString temp = UString{tde.str_write.get(tde.discard[i])}; + ret->discardOnAmbiguity(temp); + } + } + return ret; + } + + if (input.eof()) { + return nullptr; + } + + size_t index = 0; + tagger_word_buffer.push_back(new TaggerWord()); + tagger_word_buffer[index]->add_ignored_string(input.readBlank(true)); + UChar32 c = input.get(); + if (input.eof() || (null_flush && c == '\0')) { + end_of_file = true; + //tagger_word_buffer[index]->add_tag(ca_tag_keof, ""_u, tde.prefer_rules); + } else { // c == ^ + UString buf = input.readBlock('^', '$'); + if (buf.back() != '$' && + (input.eof() || (null_flush && input.peek() == '\0'))) { + tagger_word_buffer[index]->add_ignored_string(buf); + //tagger_word_buffer[index]->add_tag(ca_tag_keof, ""_u, tde.prefer_rules); + return 
read_tagger_word(input); + } + buf = buf.substr(1, buf.size()-2); + vector pieces = StringUtils::split_escape(buf, '/'); + UString surf = UString{pieces[0]}; + tagger_word_buffer[index]->set_superficial_form(surf); + if (pieces.size() > 1) { + for (auto& it : pieces) { + index = 0; + vector segments = StringUtils::split_escape(it, '+'); + MatchState2 state(&tde.trans); + size_t start = 0; + int tag = -1; + size_t last_pos = 0; + for (size_t i = 0; i < segments.size(); i++) { + if (i != 0) { + state.step('+'); + } + state.step(segments[i], tde.alpha); + int val = state.classifyFinals(match_finals); + if (val != -1) { + tag = val; + last_pos = i+1; + } + if (last_pos == start && + (state.empty() || i == segments.size() - 1)) { + UString tmp = UString{segments[i]}; + for (size_t j = i+1; j < segments.size(); j++) { + tmp += '+'; + tmp += segments[j]; + } + /* if (debug) { + cerr<<"Warning: There is not coarse tag for the fine tag '" << tmp <<"'\n"; + cerr<<" This is because of an incomplete tagset definition or a dictionary error\n"; + }*/ + //tagger_word_buffer[index]->add_tag(ca_tag_kundef, tmp, tde.prefer); + break; + } else if (state.empty() || (i + 1 == segments.size() && val == -1)) { + UString tmp; + for (size_t j = start; j < last_pos; j++) { + if (!tmp.empty()) { + tmp += '+'; + } + tmp += segments[j]; + } + // tagger_word_buffer[index]->add_tag(tag, tmp, tde.prefer); + if (last_pos < segments.size()) { + start = last_pos; + tagger_word_buffer[index]->set_plus_cut(true); + index++; + if (index >= tagger_word_buffer.size()) { + tagger_word_buffer.push_back(new TaggerWord(true)); + } + state.clear(); + } + i = start - 1; + } + if (tag == -1) { + // tag = ca_tag_kundef; + /* if (debug) { + cerr<<"Warning: There is not coarse tag for the fine tag '" << tmp <<"'\n"; + cerr<<" This is because of an incomplete tagset definition or a dictionary error\n"; + }*/ + } + UString tmp; + for (size_t j = start; j < segments.size(); j++) { + if (!tmp.empty()) { + tmp += '+'; 
+ } + tmp += segments[j]; + } + // tagger_word_buffer[index]->add_tag(tag, tmp, tde.prefer); + } + } + } + } + return read_tagger_word(input); +} + +void +TaggerExe::tag_hmm(InputFile& input, UFILE* output) +{ + build_match_finals(); + + vector<vector<double>> alpha(2, vector<double>(tde.N)); + vector<vector<vector<uint64_t>>> best(2, vector<vector<uint64_t>>(tde.N)); + + set<uint64_t> tags, pretags; + + //tags.insert(eos); + //alpha[0][eos] = 1; + + vector<TaggerWord> words; + TaggerWord* cur_word = read_tagger_word(input); + + while (cur_word) { + words.push_back(*cur_word); + + pretags.swap(tags); + //tags = cur_word->get_tags(); + if (tags.empty()) { + // tags = tde.getOpenClass(); + } + + uint64_t cls = tde.get_ambiguity_class(tags); + + //clear_array_double(&alpha[nwpend%2][0], N); + //clear_array_vector(&best[nwpend%2][0], N); + + if (cls < tde.output_count) { + // if it's a new ambiguity class, weights will all be 0 + for (auto& i : tags) { + for (auto& j : pretags) { + int loc = words.size() % 2; + double x = alpha[1-loc][j] * tde.getA(j, i) * tde.getB(i, cls); + if (alpha[loc][i] <= x) { + if (words.size() > 1) { + best[loc][i] = best[1-loc][j]; + } + best[loc][i].push_back(i); + alpha[loc][i] = x; + } + } + } + } + + if (tags.size() == 1) { + uint64_t tag = *tags.begin(); + double prob = alpha[words.size()%2][tag]; + /*if (prob <= 0 && debug) { + cerr << "Problem with word '"<<cur_word->get_superficial_form()<<"' "<<cur_word->get_string_tags()<<"\n"; + } + */ + uint64_t eof = 0; // tde.getTagIndex("TAG_kEOF"_u) + for (uint64_t t = 0; t < best[words.size()%2][tag].size(); t++) { + if (true) { // (TheFlags.getFirst()) { + //write(words[t].get_all_chosen_tag_first(best[words.size()%2][tag][t], + // eof), + // output); + } else { + words[t].set_show_sf(false); // TheFlags.getShowSuperficial() + //write(words[t].get_lexical_form(best[words.size()%2][tag][t], eof), + // output); + } + } + words.clear(); + alpha[0][tag] = 1; + } + delete cur_word; + } +} diff --git a/apertium/tagger_exe.h b/apertium/tagger_exe.h new file mode 100644 index 0000000..fcf049c --- 
/dev/null +++ b/apertium/tagger_exe.h @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2021 Apertium + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License as + * published by the Free Software Foundation; either version 2 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, see . + */ + +#ifndef _TAGGER_EXE_ +#define _TAGGER_EXE_ + +#include +#include +#include +#include +#include +#include +#include + +class TaggerExe { +private: + bool null_flush = true; + bool end_of_file = false; + std::vector tagger_word_buffer; + std::map match_finals; + void build_match_finals(); +public: + TaggerDataExe tde; + Apertium::StreamedType read_streamed_type(InputFile& input); + TaggerWord* read_tagger_word(InputFile& input); + void tag_hmm(InputFile& input, UFILE* output); +}; + +#endif