commit ccb9ab7b91c8eba77c519effc59d9e1296d546ef Author: Daniel Swanson Date: Mon Aug 2 17:36:36 2021 -0500 split read functions for TransducerExe diff --git a/lttoolbox/fst_processor.cc b/lttoolbox/fst_processor.cc index 3b675c8..cb8c7ce 100644 --- a/lttoolbox/fst_processor.cc +++ b/lttoolbox/fst_processor.cc @@ -993,12 +993,11 @@ FSTProcessor::load(FILE *input) alphabet.read(input, true); uint64_t tr_count = read_le_64(input); - Alphabet temp; for (uint64_t i = 0; i < tr_count; i++) { uint32_t s = read_le_32(input); uint32_t c = read_le_32(input); UString name = UString{str_write.get(s, c)}; - transducers[name].read(input, temp); + transducers[name].read(input); } } } else { @@ -1021,7 +1020,7 @@ FSTProcessor::load(FILE *input) while(len > 0) { UString name = Compression::string_read(input); - transducers[name].read(input, temp); + transducers[name].read_compressed(input, temp); len--; } } diff --git a/lttoolbox/transducer_exe.cc b/lttoolbox/transducer_exe.cc index 3b6fe60..259f71e 100644 --- a/lttoolbox/transducer_exe.cc +++ b/lttoolbox/transducer_exe.cc @@ -40,10 +40,9 @@ TransducerExe::~TransducerExe() } void -TransducerExe::read(FILE* input, Alphabet& alphabet) +TransducerExe::read_compressed(FILE* input, Alphabet& alphabet) { bool read_weights = false; // only matters for pre-mmap - bool mmap = false; fpos_t pos; fgetpos(input, &pos); char header[4]{}; @@ -54,94 +53,108 @@ TransducerExe::read(FILE* input, Alphabet& alphabet) throw std::runtime_error("Transducer has features that are unknown to this version of lttoolbox - upgrade!"); } read_weights = (features & TDF_WEIGHTS); - mmap = (features & TDF_MMAP); } else { // no header fsetpos(input, &pos); } - if (mmap) { - read_le_64(input); // total size - initial = read_le_64(input); - state_count = read_le_64(input); - final_count = read_le_64(input); - transition_count = read_le_64(input); - - finals = new Final[final_count]; - for (uint64_t i = 0; i < final_count; i++) { - finals[i].state = read_le_64(input); 
- finals[i].weight = read_le_double(input); - } + initial = Compression::multibyte_read(input); + final_count = Compression::multibyte_read(input); - offsets = new uint64_t[state_count+1]; - for (uint64_t i = 0; i < state_count+1; i++) { - offsets[i] = read_le_64(input); + uint64_t base_state = 0; + double base_weight = 0.0; + finals = new Final[final_count]; + for (uint64_t i = 0; i < final_count; i++) { + base_state += Compression::multibyte_read(input); + if (read_weights) { + base_weight += Compression::long_multibyte_read(input); } + finals[i].state = base_state; + finals[i].weight = base_weight; + } - transitions = new Transition[transition_count]; - for (uint64_t i = 0; i < transition_count; i++) { - transitions[i].isym = read_le_s32(input); - transitions[i].osym = read_le_s32(input); - transitions[i].dest = read_le_64(input); - transitions[i].weight = read_le_double(input); - } - } else { - initial = Compression::multibyte_read(input); - final_count = Compression::multibyte_read(input); - - uint64_t base_state = 0; - double base_weight = 0.0; - finals = new Final[final_count]; - for (uint64_t i = 0; i < final_count; i++) { - base_state += Compression::multibyte_read(input); + state_count = Compression::multibyte_read(input); + offsets = new uint64_t[state_count+1]; + transition_count = 0; + std::vector<int32_t> isyms, osyms; + std::vector<uint64_t> dests; + std::vector<double> weights; + for (uint64_t i = 0; i < state_count; i++) { + offsets[i] = transition_count; + std::map<int32_t, std::vector<std::pair<int32_t, std::pair<uint64_t, double>>>> temp; + uint64_t count = Compression::multibyte_read(input); + transition_count += count; + int32_t tag_base = 0; + for (uint64_t t = 0; t < count; t++) { + tag_base += Compression::multibyte_read(input); + uint64_t dest = (i + Compression::multibyte_read(input)) % state_count; if (read_weights) { - base_weight += Compression::long_multibyte_read(input); + base_weight = Compression::long_multibyte_read(input); } - finals[i].state = base_state; - finals[i].weight = base_weight; + auto sym = 
alphabet.decode(tag_base); + temp[sym.first].push_back(make_pair(sym.second, + make_pair(dest, base_weight))); } - - state_count = Compression::multibyte_read(input); - offsets = new uint64_t[state_count+1]; - transition_count = 0; - std::vector<int32_t> isyms, osyms; - std::vector<uint64_t> dests; - std::vector<double> weights; - for (uint64_t i = 0; i < state_count; i++) { - offsets[i] = transition_count; - std::map<int32_t, std::vector<std::pair<int32_t, std::pair<uint64_t, double>>>> temp; - uint64_t count = Compression::multibyte_read(input); - transition_count += count; - int32_t tag_base = 0; - for (uint64_t t = 0; t < count; t++) { - tag_base += Compression::multibyte_read(input); - uint64_t dest = (i + Compression::multibyte_read(input)) % state_count; - if (read_weights) { - base_weight = Compression::long_multibyte_read(input); - } - auto sym = alphabet.decode(tag_base); - temp[sym.first].push_back(make_pair(sym.second, - make_pair(dest, base_weight))); - } - for (auto& it : temp) { - for (auto& it2 : it.second) { - isyms.push_back(it.first); - osyms.push_back(it2.first); - dests.push_back(it2.second.first); - weights.push_back(it2.second.second); - } + for (auto& it : temp) { + for (auto& it2 : it.second) { + isyms.push_back(it.first); + osyms.push_back(it2.first); + dests.push_back(it2.second.first); + weights.push_back(it2.second.second); } } - offsets[state_count] = transition_count; - transitions = new Transition[transition_count]; - for (uint64_t i = 0; i < transition_count; i++) { - transitions[i].isym = isyms[i]; - transitions[i].osym = osyms[i]; - transitions[i].dest = dests[i]; - transitions[i].weight = weights[i]; + } + offsets[state_count] = transition_count; + transitions = new Transition[transition_count]; + for (uint64_t i = 0; i < transition_count; i++) { + transitions[i].isym = isyms[i]; + transitions[i].osym = osyms[i]; + transitions[i].dest = dests[i]; + transitions[i].weight = weights[i]; + } +} + +void +TransducerExe::read(FILE* input) +{ + fpos_t pos; + fgetpos(input, &pos); + char header[4]{}; + auto l = 
fread_unlocked(header, 1, 4, input); + if (l == 4 && strncmp(header, HEADER_TRANSDUCER, 4) == 0) { + auto features = read_le_64(input); + if (features >= TDF_UNKNOWN) { + throw std::runtime_error("Transducer has features that are unknown to this version of lttoolbox - upgrade!"); } + } else { + throw std::runtime_error("Unable to read transducer header!"); + } + + read_le_64(input); // total size + initial = read_le_64(input); + state_count = read_le_64(input); + final_count = read_le_64(input); + transition_count = read_le_64(input); + + finals = new Final[final_count]; + for (uint64_t i = 0; i < final_count; i++) { + finals[i].state = read_le_64(input); + finals[i].weight = read_le_double(input); + } + + offsets = new uint64_t[state_count+1]; + for (uint64_t i = 0; i < state_count+1; i++) { + offsets[i] = read_le_64(input); + } + + transitions = new Transition[transition_count]; + for (uint64_t i = 0; i < transition_count; i++) { + transitions[i].isym = read_le_s32(input); + transitions[i].osym = read_le_s32(input); + transitions[i].dest = read_le_64(input); + transitions[i].weight = read_le_double(input); } } diff --git a/lttoolbox/transducer_exe.h b/lttoolbox/transducer_exe.h index cef9a30..6b7d93b 100644 --- a/lttoolbox/transducer_exe.h +++ b/lttoolbox/transducer_exe.h @@ -59,7 +59,8 @@ private: public: TransducerExe(); ~TransducerExe(); - void read(FILE* input, Alphabet& alphabet); + void read_compressed(FILE* input, Alphabet& alphabet); + void read(FILE* input); void* init(void* ptr); };