commit 225ed7e9034290e00a01c23599dc30d1713d1995 Author: Daniel Swanson Date: Thu Aug 5 15:54:01 2021 -0500 HMM nearly working diff --git a/apertium/Makefile.am b/apertium/Makefile.am index d8225d4..3b27891 100644 --- a/apertium/Makefile.am +++ b/apertium/Makefile.am @@ -196,6 +196,7 @@ bin_PROGRAMS = apertium-cleanstream \ apertium-rexpresstag \ apertium-tagger \ apertium-tagger-apply-new-rules \ + apertium-tagger-new \ apertium-tagger-readwords \ apertium-perceptron-trace \ apertium-tmxbuild \ @@ -260,6 +261,9 @@ apertium_postlatex_raw_LDADD= -lapertium$(VERSION_MAJOR) $(lib_LTLIBRARIES) apertium_tagger_SOURCES = apertium_tagger.cc apertium_tagger_LDADD = -lapertium$(VERSION_MAJOR) $(lib_LTLIBRARIES) +apertium_tagger_new_SOURCES = apertium_tagger_new.cc +apertium_tagger_new_LDADD = -lapertium$(VERSION_MAJOR) $(lib_LTLIBRARIES) + apertium_perceptron_trace_SOURCES = apertium_perceptron_trace.cc apertium_perceptron_trace_LDADD = -lapertium$(VERSION_MAJOR) $(lib_LTLIBRARIES) diff --git a/apertium/apertium_tagger_new.cc b/apertium/apertium_tagger_new.cc new file mode 100644 index 0000000..81642e6 --- /dev/null +++ b/apertium/apertium_tagger_new.cc @@ -0,0 +1,19 @@ +#include +#include + +#include + +int main(int argc, char** argv) +{ + LtLocale::tryToSetLocale(); + TaggerExe t; + InputFile input; + UFILE* output = u_finit(stdout, NULL, NULL); + FILE* tg = fopen(argv[1], "rb"); + std::cerr << "about to load\n"; + t.load(tg); + std::cerr << "loaded\n"; + t.tag_hmm(input, output); + std::cerr << "tagged\n"; + return EXIT_SUCCESS; +} diff --git a/apertium/tagger_data_exe.cc b/apertium/tagger_data_exe.cc index 3287ff7..2523e9d 100644 --- a/apertium/tagger_data_exe.cc +++ b/apertium/tagger_data_exe.cc @@ -25,6 +25,12 @@ #include #include +#include + +TaggerDataExe::TaggerDataExe() + : alpha(AlphabetExe(&str_write)) +{} + uint64_t deserialise_int(FILE* in) { uint64_t ret = 0; @@ -291,7 +297,7 @@ TaggerDataExe::read_compressed_hmm_lsw(FILE* in, bool is_hmm) if (output_offsets[i+1] - output_offsets[i] == open_class.size()) { bool match = true; for (uint64_t j = 0; j < open_class.size(); j++) { - if (open_class[j] != output[output_offsets[i]+j]) { + if (open_class[j] != out[output_offsets[i]+j]) { match = false; break; } @@ -358,8 +364,9 @@ TaggerDataExe::read_compressed_hmm_lsw(FILE* in, bool is_hmm) int len = Compression::multibyte_read(in); if (len == 1) { Compression::string_read(in); - trans.read_compressed(in, temp); + trans.read_compressed(in, temp, true); finals_count = Compression::multibyte_read(in); + finals = new int_int[finals_count]; for (uint64_t i = 0; i < finals_count; i++) { finals[i].i1 = Compression::multibyte_read(in); finals[i].i2 = Compression::multibyte_read(in); @@ -395,3 +402,63 @@ TaggerDataExe::get_ambiguity_class(const std::set& tags) } return ret; } + +bool +TaggerDataExe::search(str_int* ptr, uint64_t count, StringRef key, uint64_t& val) +{ + UString_view k = str_write.get(key); + int64_t l = 0, r = count-1, m; + while (l <= r) { + m = (l + r) / 2; + if (str_write.get(ptr[m].s) == k) { + val = ptr[m].i; + return true; + } else if (str_write.get(ptr[m].s) < k) { + l = m + 1; + } else { + r = m - 1; + } + } + return false; +} + +bool +TaggerDataExe::search(str_str_int* ptr, uint64_t count, + StringRef key1, StringRef key2, uint64_t& val) +{ + UString_view k1 = str_write.get(key1); + UString_view k2 = str_write.get(key2); + int64_t l = 0, r = count-1, m; + while (l <= r) { + m = (l + r) / 2; + if (str_write.get(ptr[m].s1) == k1 && str_write.get(ptr[m].s2) == k2) { + val = ptr[m].i; + return true; + } else if ((str_write.get(ptr[m].s1) < k1) || + (str_write.get(ptr[m].s1) == k1 && + str_write.get(ptr[m].s2) < k2)) { + l = m + 1; + } else { + r = m - 1; + } + } + return false; +} + +bool +TaggerDataExe::search(int_int* ptr, uint64_t count, uint64_t key, uint64_t& val) +{ + int64_t l = 0, r = count-1, m; + while (l <= r) { + m = (l + r) / 2; + if (ptr[m].i1 == key) { + val = ptr[m].i2; + return true; + } else if (ptr[m].i1 < key) { + l = m + 1; + } else { + r = m - 1; + } + } + return false; +} diff --git a/apertium/tagger_data_exe.h b/apertium/tagger_data_exe.h index 7550d43..db36365 100644 --- a/apertium/tagger_data_exe.h +++ b/apertium/tagger_data_exe.h @@ -133,6 +133,8 @@ map, double> weights */ public: + TaggerDataExe(); + void read_compressed_unigram1(FILE* in); void read_compressed_unigram2(FILE* in); void read_compressed_unigram3(FILE* in); @@ -150,6 +152,11 @@ public: inline double getD(uint64_t i, uint64_t j, uint64_t k) const { return lsw_d[i*N*N + j*N + k]; } + + bool search(str_int* ptr, uint64_t count, StringRef key, uint64_t& val); + bool search(str_str_int* ptr, uint64_t count, StringRef k1, StringRef k2, + uint64_t& val); + bool search(int_int* ptr, uint64_t count, uint64_t key, uint64_t& val); }; #endif diff --git a/apertium/tagger_exe.cc b/apertium/tagger_exe.cc index e27ab3c..31be67e 100644 --- a/apertium/tagger_exe.cc +++ b/apertium/tagger_exe.cc @@ -3,6 +3,8 @@ #include #include +#include + using namespace std; using namespace Apertium; @@ -14,6 +16,17 @@ TaggerExe::build_match_finals() } } +void +TaggerExe::build_prefer_rules() +{ + prefer_rules = vector(tde.prefer_rules_count); + for (uint64_t i = 0; i < tde.prefer_rules_count; i++) { + UString temp = UString{tde.str_write.get(tde.prefer_rules[i])}; + temp = StringUtils::substitute(temp, "<*>"_u, "(<[^>]+>)+"_u); + prefer_rules[i].compile(temp); + } +} + StreamedType TaggerExe::read_streamed_type(InputFile& input) { @@ -67,19 +80,26 @@ TaggerExe::read_tagger_word(InputFile& input) return nullptr; } + uint64_t ca_tag_keof = tde.tag_index_count; + uint64_t ca_tag_kundef = tde.tag_index_count; + tde.search(tde.tag_index, tde.tag_index_count, + tde.str_write.add("TAG_kEOF"_u), ca_tag_keof); + tde.search(tde.tag_index, tde.tag_index_count, + tde.str_write.add("TAG_kUNDEF"_u), ca_tag_kundef); + size_t index = 0; tagger_word_buffer.push_back(new TaggerWord()); tagger_word_buffer[index]->add_ignored_string(input.readBlank(true)); UChar32 c = input.get(); if (input.eof() || (null_flush && c == '\0')) { end_of_file = true; - //tagger_word_buffer[index]->add_tag(ca_tag_keof, ""_u, tde.prefer_rules); + tagger_word_buffer[index]->add_tag(ca_tag_keof, ""_u, prefer_rules); } else { // c == ^ UString buf = input.readBlock('^', '$'); if (buf.back() != '$' && (input.eof() || (null_flush && input.peek() == '\0'))) { tagger_word_buffer[index]->add_ignored_string(buf); - //tagger_word_buffer[index]->add_tag(ca_tag_keof, ""_u, tde.prefer_rules); + tagger_word_buffer[index]->add_tag(ca_tag_keof, ""_u, prefer_rules); return read_tagger_word(input); } buf = buf.substr(1, buf.size()-2); @@ -87,9 +107,9 @@ TaggerExe::read_tagger_word(InputFile& input) UString surf = UString{pieces[0]}; tagger_word_buffer[index]->set_superficial_form(surf); if (pieces.size() > 1) { - for (auto& it : pieces) { + for (size_t p = 1; p < pieces.size(); p++) { index = 0; - vector segments = StringUtils::split_escape(it, '+'); + vector segments = StringUtils::split_escape(pieces[p], '+'); MatchState2 state(&tde.trans); size_t start = 0; int tag = -1; @@ -111,11 +131,11 @@ TaggerExe::read_tagger_word(InputFile& input) tmp += '+'; tmp += segments[j]; } - /* if (debug) { + if (debug) { cerr<<"Warning: There is not coarse tag for the fine tag '" << tmp <<"'\n"; cerr<<" This is because of an incomplete tagset definition or a dictionary error\n"; - }*/ - //tagger_word_buffer[index]->add_tag(ca_tag_kundef, tmp, tde.prefer); + } + tagger_word_buffer[index]->add_tag(ca_tag_kundef, tmp, prefer_rules); break; } else if (state.empty() || (i + 1 == segments.size() && val == -1)) { UString tmp; @@ -125,7 +145,7 @@ TaggerExe::read_tagger_word(InputFile& input) } tmp += segments[j]; } - // tagger_word_buffer[index]->add_tag(tag, tmp, tde.prefer); + tagger_word_buffer[index]->add_tag(tag, tmp, prefer_rules); if (last_pos < segments.size()) { start = last_pos; tagger_word_buffer[index]->set_plus_cut(true); @@ -137,13 +157,6 @@ TaggerExe::read_tagger_word(InputFile& input) } i = start - 1; } - if (tag == -1) { - // tag = ca_tag_kundef; - /* if (debug) { - cerr<<"Warning: There is not coarse tag for the fine tag '" << tmp <<"'\n"; - cerr<<" This is because of an incomplete tagset definition or a dictionary error\n"; - }*/ - } UString tmp; for (size_t j = start; j < segments.size(); j++) { if (!tmp.empty()) { @@ -151,7 +164,14 @@ TaggerExe::read_tagger_word(InputFile& input) } tmp += segments[j]; } - // tagger_word_buffer[index]->add_tag(tag, tmp, tde.prefer); + if (tag == -1) { + tag = ca_tag_kundef; + if (debug) { + cerr<<"Warning: There is not coarse tag for the fine tag '" << tmp <<"'\n"; + cerr<<" This is because of an incomplete tagset definition or a dictionary error\n"; + } + } + tagger_word_buffer[index]->add_tag(tag, tmp, prefer_rules); } } } @@ -163,14 +183,24 @@ void TaggerExe::tag_hmm(InputFile& input, UFILE* output) { build_match_finals(); + build_prefer_rules(); + vector arr_tg; + for (uint64_t i = 0; i < tde.array_tags_count; i++) { + arr_tg.push_back(UString{tde.str_write.get(tde.array_tags[i])}); + } + TaggerWord::setArrayTags(arr_tg); vector> alpha(2, vector(tde.N)); vector>> best(2, vector>(tde.N)); set tags, pretags; - //tags.insert(eos); - //alpha[0][eos] = 1; + uint64_t eos; + tde.search(tde.tag_index, tde.tag_index_count, tde.str_write.add("TAG_SENT"_u), + eos); + + tags.insert(eos); + alpha[0][eos] = 1; vector words; TaggerWord* cur_word = read_tagger_word(input); @@ -179,8 +209,15 @@ TaggerExe::tag_hmm(InputFile& input, UFILE* output) words.push_back(*cur_word); pretags.swap(tags); + tags.clear(); //tags = cur_word->get_tags(); + for (auto& it : cur_word->get_tags()) { + tags.insert(it); + } if (tags.empty()) { + uint64_t s = tde.output_offsets[tde.open_class_index]; + uint64_t e = tde.output_offsets[tde.open_class_index+1]; + tags.insert(tde.output+s, tde.output+e); // tags = tde.getOpenClass(); } @@ -188,6 +225,10 @@ TaggerExe::tag_hmm(InputFile& input, UFILE* output) //clear_array_double(&alpha[nwpend%2][0], N); //clear_array_vector(&best[nwpend%2][0], N); + for (uint64_t i = 0; i < tde.N; i++) { + alpha[words.size()%2][i] = 0.0; + best[words.size()%2][i].clear(); + } if (cls < tde.output_count) { // if it's a new ambiguity class, weights will all be 0 @@ -209,25 +250,33 @@ TaggerExe::tag_hmm(InputFile& input, UFILE* output) if (tags.size() == 1) { uint64_t tag = *tags.begin(); double prob = alpha[words.size()%2][tag]; - /*if (prob <= 0 && debug) { - cerr << <<"Problem with word '"<get_superficial_form()<<"' "<get_string_tags()<<"\n"; - } - */ - uint64_t eof = 0; // tde.getTagIndex("TAG_kEOF"_u) + if (prob <= 0){// && debug) { + cerr <<"Problem with word '"<< cur_word->get_superficial_form()<<"' "<get_string_tags()<<"\n"; + } + uint64_t eof = tde.tag_index_count; + tde.search(tde.tag_index, tde.tag_index_count, + tde.str_write.add("TAG_kEOF"_u), eof); for (uint64_t t = 0; t < best[words.size()%2][tag].size(); t++) { - if (true) { // (TheFlags.getFirst()) { + if (false) { // (TheFlags.getFirst()) { //write(words[t].get_all_chosen_tag_first(best[words.size()%2][tag][t], // eof), // output); } else { - words[t].set_show_sf(false); // TheFlags.getShowSuperficial() - //write(words[t].get_lexical_form(best[words.size()%2][tag][t], eof), - // output); + words[t].set_show_sf(show_superficial); + write(words[t].get_lexical_form(best[words.size()%2][tag][t], eof), + output); } } words.clear(); alpha[0][tag] = 1; } delete cur_word; + cur_word = read_tagger_word(input); } } + +void +TaggerExe::load(FILE* in) +{ + tde.read_compressed_hmm_lsw(in, true); +} diff --git a/apertium/tagger_exe.h b/apertium/tagger_exe.h index fcf049c..833985d 100644 --- a/apertium/tagger_exe.h +++ b/apertium/tagger_exe.h @@ -19,6 +19,7 @@ #define _TAGGER_EXE_ #include +#include #include #include #include @@ -29,15 +30,21 @@ class TaggerExe { private: bool null_flush = true; + bool debug = false; + bool show_superficial = false; + bool end_of_file = false; std::vector tagger_word_buffer; std::map match_finals; void build_match_finals(); + std::vector prefer_rules; + void build_prefer_rules(); public: TaggerDataExe tde; Apertium::StreamedType read_streamed_type(InputFile& input); TaggerWord* read_tagger_word(InputFile& input); void tag_hmm(InputFile& input, UFILE* output); + void load(FILE* in); }; #endif diff --git a/apertium/tagger_word.cc b/apertium/tagger_word.cc index 16ec8a6..9d1b2ba 100644 --- a/apertium/tagger_word.cc +++ b/apertium/tagger_word.cc @@ -115,6 +115,23 @@ TaggerWord::add_tag(TTag &t, const UString &lf, vector const &prefer_ru } } +void +TaggerWord::add_tag(const uint64_t t, const UString& lf, const vector& prefer_rules) +{ + TTag tg = static_cast(t); + if (tags.find(tg) == tags.end()) { + tags.insert(tg); + lexical_forms[tg] = lf; + } else { + for (auto& it : prefer_rules) { + if (!it.match(lf).empty()) { + lexical_forms[tg] = lf; + break; + } + } + } +} + set& TaggerWord::get_tags() { return tags; @@ -142,6 +159,31 @@ TaggerWord::get_string_tags() { return st; } +UString +TaggerWord::get_lexical_form(const uint64_t t, const uint64_t TAG_kEOF) +{ + UString ret; + if (show_ignored_string) { + ret.append(ignored_string); + } + if (t == TAG_kEOF) { + return ret; + } + if (!previous_plus_cut) { + ret += '^'; + } + if (lexical_forms.empty()) { + ret += '*'; + ret.append(superficial_form); + } else { + ret.append(lexical_forms[static_cast(t)]); + } + if (ret != ignored_string) { + ret += (plus_cut ? '+' : '$'); + } + return ret; +} + UString TaggerWord::get_lexical_form(TTag &t, int const TAG_kEOF) { UString ret; diff --git a/apertium/tagger_word.h b/apertium/tagger_word.h index 560500a..b8b40e6 100644 --- a/apertium/tagger_word.h +++ b/apertium/tagger_word.h @@ -88,6 +88,7 @@ public: * @param lf the lexical form (fine tag) */ virtual void add_tag(TTag &t, const UString &lf, vector const &prefer_rules); + virtual void add_tag(const uint64_t t, const UString& lf, const vector& prefer_rules); /** Get the set of tags of this word. * @return set of tags. @@ -103,6 +104,7 @@ public: * @return the lexical form of tag t */ virtual UString get_lexical_form(TTag &t, int const TAG_kEOF); + virtual UString get_lexical_form(const uint64_t t, const uint64_t teof); UString get_all_chosen_tag_first(TTag &t, int const TAG_kEOF);