commit 7e4470877c189287dc5326d1426f9d58774c0843 Author: Daniel Swanson Date: Sat Aug 7 10:12:26 2021 -0500 most of unigram 1 and unigram 2 diff --git a/apertium/tagger_data_exe.cc b/apertium/tagger_data_exe.cc index 3e1dc3e..09ff2fc 100644 --- a/apertium/tagger_data_exe.cc +++ b/apertium/tagger_data_exe.cc @@ -406,16 +406,16 @@ TaggerDataExe::get_ambiguity_class(const std::set& tags) } bool -TaggerDataExe::search(str_int* ptr, uint64_t count, StringRef key, uint64_t& val) +TaggerDataExe::search(str_int* ptr, uint64_t count, UString_view key, + uint64_t& val) { - UString_view k = str_write.get(key); int64_t l = 0, r = count-1, m; while (l <= r) { m = (l + r) / 2; - if (str_write.get(ptr[m].s) == k) { + if (str_write.get(ptr[m].s) == key) { val = ptr[m].i; return true; - } else if (str_write.get(ptr[m].s) < k) { + } else if (str_write.get(ptr[m].s) < key) { l = m + 1; } else { r = m - 1; @@ -426,19 +426,17 @@ TaggerDataExe::search(str_int* ptr, uint64_t count, StringRef key, uint64_t& val bool TaggerDataExe::search(str_str_int* ptr, uint64_t count, - StringRef key1, StringRef key2, uint64_t& val) + UString_view key1, UString_view key2, uint64_t& val) { - UString_view k1 = str_write.get(key1); - UString_view k2 = str_write.get(key2); int64_t l = 0, r = count-1, m; while (l <= r) { m = (l + r) / 2; - if (str_write.get(ptr[m].s1) == k1 && str_write.get(ptr[m].s2) == k2) { + if (str_write.get(ptr[m].s1) == key1 && str_write.get(ptr[m].s2) == key2) { val = ptr[m].i; return true; - } else if ((str_write.get(ptr[m].s1) < k1) || - (str_write.get(ptr[m].s1) == k1 && - str_write.get(ptr[m].s2) < k2)) { + } else if ((str_write.get(ptr[m].s1) < key1) || + (str_write.get(ptr[m].s1) == key1 && + str_write.get(ptr[m].s2) < key2)) { l = m + 1; } else { r = m - 1; diff --git a/apertium/tagger_data_exe.h b/apertium/tagger_data_exe.h index 5c42d8a..8506736 100644 --- a/apertium/tagger_data_exe.h +++ b/apertium/tagger_data_exe.h @@ -153,9 +153,9 @@ public: return lsw_d[i*N*N + j*N + k]; } - bool search(str_int* ptr, uint64_t count, StringRef key, uint64_t& val); - bool search(str_str_int* ptr, uint64_t count, StringRef k1, StringRef k2, - uint64_t& val); + bool search(str_int* ptr, uint64_t count, UString_view key, uint64_t& val); + bool search(str_str_int* ptr, uint64_t count, UString_view k1, + UString_view k2, uint64_t& val); bool search(int_int* ptr, uint64_t count, uint64_t key, uint64_t& val); }; diff --git a/apertium/tagger_exe.cc b/apertium/tagger_exe.cc index 31be67e..035cdc4 100644 --- a/apertium/tagger_exe.cc +++ b/apertium/tagger_exe.cc @@ -82,10 +82,8 @@ TaggerExe::read_tagger_word(InputFile& input) uint64_t ca_tag_keof = tde.tag_index_count; uint64_t ca_tag_kundef = tde.tag_index_count; - tde.search(tde.tag_index, tde.tag_index_count, - tde.str_write.add("TAG_kEOF"_u), ca_tag_keof); - tde.search(tde.tag_index, tde.tag_index_count, - tde.str_write.add("TAG_kUNDEF"_u), ca_tag_kundef); + tde.search(tde.tag_index, tde.tag_index_count, "TAG_kEOF"_u, ca_tag_keof); + tde.search(tde.tag_index, tde.tag_index_count, "TAG_kUNDEF"_u, ca_tag_kundef); size_t index = 0; tagger_word_buffer.push_back(new TaggerWord()); @@ -179,6 +177,103 @@ TaggerExe::read_tagger_word(InputFile& input) return read_tagger_word(input); } +long double +TaggerExe::score_unigram1(UString_view lu) +{ + long double s = 1; + uint64_t c = 0; + if (tde.search(tde.uni1, tde.uni1_count, lu, c)) { + s += c; + } + return s; +} + +void +TaggerExe::build_uni2_counts() +{ + for (uint64_t i = 0; i < tde.uni2_count; i++) { + UString_view key = tde.str_write.get(tde.uni2[i].s1); + uni2_counts[key].first++; + uni2_counts[key].second += tde.uni2[i].i; + } +} + +long double +TaggerExe::score_unigram2(UString_view lu) +{ + auto loc = lu.find_first_of('<'); + if (loc == UString_view::npos) { + return 0.5; + } + UString_view lemma = lu.substr(0, loc); + UString_view analysis = lu.substr(loc); + long double tokenCount_r_a = 1; + long double tokenCount_a = 1; + long double typeCount_a = 1; + if (uni2_counts.find(analysis) != uni2_counts.end()) { + uint64_t n; + if (tde.search(tde.uni2, tde.uni2_count, analysis, lemma, n)) { + tokenCount_r_a += n; + typeCount_a = 0; + } + typeCount_a += uni2_counts[analysis].first; + tokenCount_a += uni2_counts[analysis].second; + } + return (tokenCount_r_a * tokenCount_a) / (tokenCount_a + typeCount_a); +} + +void +TaggerExe::tag_unigram(InputFile& input, UFILE* output, int model) +{ + if (model == 2) { + build_uni2_counts(); + } + while (!input.eof()) { + write(input.readBlank(true), output); + UChar32 c = input.get(); + if (c == '\0') { + u_fputc(c, output); + if (null_flush) { + u_fflush(output); + } + continue; + } else if (c == U_EOF) { + break; + } + // readBlank() guarantees the next char is thus ^ + UString lu = input.readBlock('^', '$'); + if (lu[lu.size()-1] != '$' && (input.peek() == '\0' || input.eof())) { + write(lu, output); + continue; + } + lu = lu.substr(1, lu.size()-2); + vector pieces = StringUtils::split_escape(lu, '/'); + // TODO: superficial and reordering options + size_t selected = 1; + long double score = 0; + if (pieces.size() == 1) { + u_fprintf(output, "^*%S$", lu.c_str()); + continue; + } + for (size_t i = 1; i < pieces.size(); i++) { + long double s = 0; + switch (model) { + case 1: + s = score_unigram1(pieces[i]); break; + case 2: + s = score_unigram2(pieces[i]); break; + default: + break; + } + if (s > score) { + score = s; + selected = i; + } + } + // write(pieces[selected], output); + } +} + void TaggerExe::tag_hmm(InputFile& input, UFILE* output) { @@ -196,8 +291,7 @@ TaggerExe::tag_hmm(InputFile& input, UFILE* output) set tags, pretags; uint64_t eos; - tde.search(tde.tag_index, tde.tag_index_count, tde.str_write.add("TAG_SENT"_u), - eos); + tde.search(tde.tag_index, tde.tag_index_count, "TAG_SENT"_u, eos); tags.insert(eos); alpha[0][eos] = 1; @@ -254,8 +348,7 @@ TaggerExe::tag_hmm(InputFile& input, UFILE* output) cerr <<"Problem with word '"<< cur_word->get_superficial_form()<<"' "<get_string_tags()<<"\n"; } uint64_t eof = tde.tag_index_count; - tde.search(tde.tag_index, tde.tag_index_count, - tde.str_write.add("TAG_kEOF"_u), eof); + tde.search(tde.tag_index, tde.tag_index_count, "TAG_kEOF"_u, eof); for (uint64_t t = 0; t < best[words.size()%2][tag].size(); t++) { if (false) { // (TheFlags.getFirst()) { //write(words[t].get_all_chosen_tag_first(best[words.size()%2][tag][t], diff --git a/apertium/tagger_exe.h b/apertium/tagger_exe.h index 833985d..28e315d 100644 --- a/apertium/tagger_exe.h +++ b/apertium/tagger_exe.h @@ -39,10 +39,16 @@ private: void build_match_finals(); std::vector prefer_rules; void build_prefer_rules(); + + std::map> uni2_counts; + void build_uni2_counts(); + long double score_unigram1(UString_view lu); + long double score_unigram2(UString_view lu); public: TaggerDataExe tde; Apertium::StreamedType read_streamed_type(InputFile& input); TaggerWord* read_tagger_word(InputFile& input); + void tag_unigram(InputFile& input, UFILE* output, int model); void tag_hmm(InputFile& input, UFILE* output); void load(FILE* in); };