File: perceptron_tagger.cc
Warning: line 261, column 7: Branch condition evaluates to a garbage value
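The flagged branch is the `if (avail_skipped)` check in `PerceptronTagger::train` in the listing below: `avail_skipped` was declared without an initialiser and is only assigned inside the iteration loop, so a call with `iterations <= 0` reaches the branch with an indeterminate value. A minimal sketch of the pattern, assuming `iterations` can indeed be non-positive; `train_sketch` and its body are a hypothetical reduction, not the real code:

    #include <cstddef>

    void train_sketch(int iterations) {
      std::size_t skipped;            // declared without an initialiser
      for (int i = 0; i < iterations; i++) {
        skipped = 0;                  // only assigned when the loop body runs
      }
      if (skipped) {                  // garbage read whenever iterations <= 0
        // ... report skipped sentences ...
      }
    }
    // Fix applied in the listing below: initialise at the declaration,
    // i.e. std::size_t skipped = 0;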
#include <apertium/perceptron_tagger.h>

#include <apertium/mtx_reader.h>
#include <apertium/exception.h>
#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <set>
#include <sstream>

namespace Apertium {

PerceptronTagger::PerceptronTagger(TaggerFlags flags) : StreamTagger(flags) {}

PerceptronTagger::~PerceptronTagger() {}

void PerceptronTagger::tag(Stream &in, std::ostream &out) {
  SentenceStream::SentenceTagger::tag(in, out, TheFlags.getSentSeg());
}

void PerceptronTagger::read_spec(const std::string &filename) {
  MTXReader(spec).read(filename);
}

std::ostream &
operator<<(std::ostream &out, PerceptronTagger const &pt) {
  out << "== Spec ==\n";
  out << pt.spec;
  out << "== Weights " << pt.weights.size() << " ==\n";
  out << pt.weights;
  return out;
}

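// Decode one sentence with a beam search: each AgendaItem is a partial
// tagging plus its cumulative perceptron score, and at each token the agenda
// is extended with every analysis and then pruned back to spec.beam_width.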
TaggedSentence
PerceptronTagger::tagSentence(const Sentence &untagged_sent) {
  const size_t sent_len = untagged_sent.size();

  std::vector<AgendaItem> agenda;
  agenda.reserve(spec.beam_width);
  agenda.push_back(AgendaItem());
  agenda.back().tagged.reserve(sent_len);

  UnaryFeatureVec feat_vec_delta;
  std::vector<Analysis>::const_iterator analys_it;
  std::vector<AgendaItem>::const_iterator agenda_it;
  std::vector<Morpheme>::const_iterator wordoid_it;

  for (size_t token_idx = 0; token_idx < sent_len; token_idx++) {
    const std::vector<Analysis> &analyses =
        untagged_sent[token_idx].TheLexicalUnit->TheAnalyses;

    std::vector<AgendaItem> new_agenda;
    new_agenda.reserve(spec.beam_width * analyses.size());

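    // Unambiguous (one analysis) and unknown (no analyses) tokens extend
    // every hypothesis identically, so no scoring or pruning is needed.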
    if (analyses.size() == 1) {
      extendAgendaAll(agenda, analyses.front());
      continue;
    } else if (analyses.size() == 0) {
      extendAgendaAll(agenda, Optional<Analysis>());
      continue;
    }

    for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
      for (analys_it = analyses.begin(); analys_it != analyses.end(); analys_it++) {
        const std::vector<Morpheme> &wordoids = analys_it->TheMorphemes;

        new_agenda.push_back(*agenda_it);
        AgendaItem &new_agenda_item = new_agenda.back();
        new_agenda_item.tagged.push_back(*analys_it);

        for (wordoid_it = wordoids.begin(); wordoid_it != wordoids.end(); wordoid_it++) {
          int wordoid_idx = wordoid_it - wordoids.begin();
          feat_vec_delta.clear();
          spec.get_features(new_agenda_item.tagged, untagged_sent,
                            token_idx, wordoid_idx, feat_vec_delta);
          if (TheFlags.getDebug()) {
            FeatureVec fv(feat_vec_delta);
            std::cerr << "Token " << token_idx << "\t\tWordoid " << wordoid_idx << "\n";
            std::cerr << fv;
            std::cerr << "Score: " << weights * feat_vec_delta << "\n";
          }
          new_agenda_item.score += weights * feat_vec_delta;
        }
      }
    }
    // Apply the beam
    if (TheFlags.getDebug()) {
      std::cerr << "-- Before beam: --\n" << new_agenda;
    }
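    // Keep only the best beam_width hypotheses; operator< below orders by
    // descending score, so the front of the agenda is the best tagging.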
    size_t new_agenda_size = std::min((size_t)spec.beam_width, new_agenda.size());
    agenda.resize(new_agenda_size);
    std::partial_sort_copy(new_agenda.begin(), new_agenda.end(),
                           agenda.begin(), agenda.end());
    if (TheFlags.getDebug()) {
      std::cerr << "-- After beam: --\n" << agenda;
    }
  }

  spec.clearCache();
  return agenda.front().tagged;
}

void PerceptronTagger::outputLexicalUnit(
    const LexicalUnit &lexical_unit, const Optional<Analysis> analysis,
    std::ostream &output) {
  StreamTagger::outputLexicalUnit(lexical_unit, analysis, output);
}

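// Run one sentence of averaged-perceptron training. Returns true if the
// sentence was skipped because the gold analysis was unavailable; performs
// an "early update" and returns false if the gold prefix falls off the beam.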
bool PerceptronTagger::trainSentence(
    const TrainingSentence &sentence,
    FeatureVecAverager &avg_weights)
{
  const TaggedSentence &tagged_sent = sentence.first;
  const Sentence &untagged_sent = sentence.second;
  assert(tagged_sent.size() == untagged_sent.size());
  const size_t sent_len = tagged_sent.size();

  std::vector<TrainingAgendaItem> agenda;
  agenda.reserve(spec.beam_width);
  agenda.push_back(TrainingAgendaItem());
  agenda.back().tagged.reserve(sent_len);
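  // Tracks which agenda entry still matches the gold-standard tagging prefix,
  // so the gold hypothesis can be extended alongside the beam.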
  std::vector<TrainingAgendaItem>::const_iterator correct_agenda_it
      = agenda.begin();

  TrainingAgendaItem correct_sentence;
  correct_sentence.tagged.reserve(sent_len);

  UnaryFeatureVec feat_vec_delta;
  std::vector<Analysis>::const_iterator analys_it;
  std::vector<TrainingAgendaItem>::const_iterator agenda_it;
  std::vector<Morpheme>::const_iterator wordoid_it;

  for (size_t token_idx = 0; token_idx < sent_len; token_idx++) {
    //std::cerr << "Token idx: " << token_idx << "\n";
    const TaggedToken &tagged_tok(tagged_sent[token_idx]);
    const StreamedType &untagged_tok(untagged_sent[token_idx]);
    correct_sentence.tagged.push_back(tagged_tok);

    const std::vector<Analysis> &analyses =
        untagged_tok.TheLexicalUnit->TheAnalyses;

    std::vector<TrainingAgendaItem> new_agenda;
    new_agenda.reserve(spec.beam_width * analyses.size());

    if (analyses.size() <= 1 || !tagged_tok) {
      // Case |analyses| = 0, nothing we can do
      // Case !tagged_tok, |analyses| > 0, no point penalising a guess which
      // can only be incorrect when there's no correct answer
      // Case |analyses| = 1, everything will cancel out anyway
      if (analyses.size() == 1) {
        extendAgendaAll(agenda, analyses.front());
        continue;
      } else {
        extendAgendaAll(agenda, Optional<Analysis>());
        continue;
      }
    }

    bool correct_available = false;
    for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
      //std::cerr << *agenda_it;
      for (analys_it = analyses.begin(); analys_it != analyses.end(); analys_it++) {
        const std::vector<Morpheme> &wordoids = analys_it->TheMorphemes;

        new_agenda.push_back(*agenda_it);
        TrainingAgendaItem &new_agenda_item = new_agenda.back();
        new_agenda_item.tagged.push_back(*analys_it);

        for (wordoid_it = wordoids.begin(); wordoid_it != wordoids.end(); wordoid_it++) {
          int wordoid_idx = wordoid_it - wordoids.begin();
          feat_vec_delta.clear();
          spec.get_features(new_agenda_item.tagged, untagged_sent,
                            token_idx, wordoid_idx, feat_vec_delta);
          new_agenda_item.vec += feat_vec_delta;
          new_agenda_item.score += weights * feat_vec_delta;
          if (agenda_it == correct_agenda_it && *analys_it == *tagged_tok) {
            correct_sentence = new_agenda_item;
            correct_available = true;
          }
        }
      }
    }
    if (!correct_available) {
      if (TheFlags.getSkipErrors()) {
        return true;
      } else {
        std::stringstream what_;
        what_ << "Tagged analysis unavailable in untagged/ambiguous input.\n";
        what_ << "Available:\n";
        for (analys_it = analyses.begin(); analys_it != analyses.end(); analys_it++) {
          what_ << *analys_it << "\n";
        }
        what_ << "Required: " << *tagged_tok << "\n";
        what_ << "Rerun with --skip-on-error to skip this sentence.";
        throw Apertium::Exception::PerceptronTagger::CorrectAnalysisUnavailable(what_);
      }
    }
    // Apply the beam
    //std::cerr << "-- Before beam: --\n" << new_agenda;
    size_t new_agenda_size = std::min((size_t)spec.beam_width, new_agenda.size());
    agenda.resize(new_agenda_size);
    std::partial_sort_copy(new_agenda.begin(), new_agenda.end(),
                           agenda.begin(), agenda.end());
    //std::cerr << "-- After beam: --\n" << agenda;

    // Early update "fallen off the beam"
    bool any_match = false;
    for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
      if (agenda_it->tagged == correct_sentence.tagged) {
        correct_agenda_it = agenda_it;
        any_match = true;
        break;
      }
    }
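    // The gold prefix fell off the beam: update the weights towards it now
    // and abandon the rest of the sentence (early update).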
    if (!any_match) {
      /*std::cerr << "Early update time!\n";
      std::cerr << "Before:\n" << weights << "\n";
      std::cerr << "Incorrect:\n" << agenda.front().vec << "\n";
      std::cerr << "Correct:\n" << correct_sentence.vec << "\n";*/
      avg_weights -= agenda.front().vec;
      avg_weights += correct_sentence.vec;
      avg_weights.incIteration();
      //std::cerr << "After:\n" << weights << "\n";
      return false;
    }
  }
  // Normal update
  /*std::cerr << "Best match:\n" << agenda.front().tagged << "\n\n";
  std::cerr << "Correct:\n" << correct_sentence.tagged << "\n\n";*/
  if (agenda.front().tagged != correct_sentence.tagged) {
    /*std::cerr << "Normal update time!\n";
    std::cerr << "Before:\n" << weights << "\n";
    std::cerr << "Incorrect:\n" << agenda.front().vec << "\n";
    std::cerr << "Correct:\n" << correct_sentence.vec << "\n";*/
    avg_weights -= agenda.front().vec;
    avg_weights += correct_sentence.vec;
    avg_weights.incIteration();
    //std::cerr << "After:\n" << weights << "\n";
  }
  return false;
}

void PerceptronTagger::train(Stream&) {} // dummy override; real training uses the two-stream overload below

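// Full training loop: shuffle the corpus each iteration, run the perceptron
// over every sentence, then apply weight averaging (avg_weights.average()).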
void PerceptronTagger::train(
    Stream &tagged,
    Stream &untagged,
    int iterations) {
  FeatureVecAverager avg_weights(weights);
  TrainingCorpus tc(tagged, untagged, TheFlags.getSkipErrors(), TheFlags.getSentSeg());
  // Initialise: with iterations <= 0 the loop below never runs, and the
  // flagged branch "if (avail_skipped)" (line 261 in the analysed build)
  // would otherwise read a garbage value.
  size_t avail_skipped = 0;
  for (int i = 0; i < iterations; i++) {
    std::cerr << "Iteration " << i + 1 << " of " << iterations << "\n";
    avail_skipped = 0;
    tc.shuffle();
    std::vector<TrainingSentence>::const_iterator si;
    for (si = tc.sentences.begin(); si != tc.sentences.end(); si++) {
      avail_skipped += trainSentence(*si, avg_weights);
      spec.clearCache();
    }
  }
  avg_weights.average();
  if (avail_skipped) {
    std::cerr << "Skipped " << tc.skipped << " sentences due to token "
              << "misalignment and " << avail_skipped << " sentences due to "
              << "tagged token being unavailable in untagged file out of "
              << tc.sentences.size() << " total sentences.\n";
  }
  //std::cerr << *this;
}

void PerceptronTagger::serialise(std::ostream &serialised) const
{
  spec.serialise(serialised);
  weights.serialise(serialised);
}

void PerceptronTagger::deserialise(std::istream &serialised)
{
  spec.deserialise(serialised);
  weights.deserialise(serialised);
}

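// Append the same (possibly absent) analysis to every hypothesis on the
// agenda; used for tokens where no scoring decision is needed.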
template <typename T> void
PerceptronTagger::extendAgendaAll(
    std::vector<T> &agenda,
    Optional<Analysis> analy) {
  typename std::vector<T>::iterator agenda_it;
  for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
    agenda_it->tagged.push_back(analy);
  }
}

std::ostream&
operator<<(std::ostream &out, const TaggedSentence &tagged) {
  TaggedSentence::const_iterator tsi;
  for (tsi = tagged.begin(); tsi != tagged.end(); tsi++) {
    if (*tsi) {
      out << **tsi;
    } else {
      out << "*";
    }
    out << " ";
  }
  return out;
}

std::ostream&
operator<<(std::ostream &out, const PerceptronTagger::TrainingAgendaItem &tai) {
  out << "Score: " << tai.score << "\n";
  out << "Sentence: " << tai.tagged << "\n";
  out << "\n";
  out << "Vector:\n" << tai.vec;
  return out;
}

std::ostream&
operator<<(std::ostream &out, const std::vector<PerceptronTagger::TrainingAgendaItem> &agenda) {
  std::vector<PerceptronTagger::TrainingAgendaItem>::const_iterator agenda_it;
  for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
    out << *agenda_it;
  }
  out << "\n\n";
  return out;
}

std::ostream&
operator<<(std::ostream &out, const PerceptronTagger::AgendaItem &ai) {
  out << "Score: " << ai.score << "\n";
  out << "Sentence: " << ai.tagged << "\n";
  return out;
}

std::ostream&
operator<<(std::ostream &out, const std::vector<PerceptronTagger::AgendaItem> &agenda) {
  std::vector<PerceptronTagger::AgendaItem>::const_iterator agenda_it;
  for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
    out << *agenda_it;
  }
  out << "\n\n";
  return out;
}

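// Note the inverted comparison: "less" means a higher score, so an ascending
// sort (as in std::partial_sort_copy above) puts the best hypothesis first.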
bool operator<(const PerceptronTagger::AgendaItem &a,
               const PerceptronTagger::AgendaItem &b) {
  return a.score > b.score;
}
} // namespace Apertium