File: | perceptron_tagger.cc |
Warning: | line 261, column 7 Branch condition evaluates to a garbage value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | #include <apertium/perceptron_tagger.h> | |||
2 | ||||
3 | #include <apertium/mtx_reader.h> | |||
4 | #include <apertium/exception.h> | |||
5 | #include <algorithm> | |||
6 | #include <map> | |||
7 | #include <set> | |||
8 | ||||
9 | namespace Apertium { | |||
10 | ||||
// Construct a tagger configured by the given flags; spec and weights are
// populated later via read_spec()/deserialise() or by training.
PerceptronTagger::PerceptronTagger(TaggerFlags flags) : StreamTagger(flags) {};
12 | ||||
// Nothing to release explicitly; members clean up through their own destructors.
PerceptronTagger::~PerceptronTagger() {};
14 | ||||
// Tag an entire stream: delegates to the sentence-level driver, which splits
// the input into sentences (honouring the SentSeg flag) and calls back into
// tagSentence() for each one.
void PerceptronTagger::tag(Stream &in, std::ostream &out) {
  SentenceStream::SentenceTagger::tag(in, out, TheFlags.getSentSeg());
}
18 | ||||
// Load the feature-extraction specification (MTX file) into this->spec.
void PerceptronTagger::read_spec(const std::string &filename) {
  MTXReader(spec).read(filename);
}
22 | ||||
23 | std::ostream & | |||
24 | operator<<(std::ostream &out, PerceptronTagger const &pt) { | |||
25 | out << "== Spec ==\n"; | |||
26 | out << pt.spec; | |||
27 | out << "== Weights " << pt.weights.size() << " ==\n"; | |||
28 | out << pt.weights; | |||
29 | return out; | |||
30 | } | |||
31 | ||||
// Decode one sentence with beam search: at each token, every surviving
// partial tagging on the agenda is extended with every analysis of the
// token, scored with the perceptron weights, and the beam_width best
// extensions are kept. Returns the tagging of the best-scoring item.
TaggedSentence
PerceptronTagger::tagSentence(const Sentence &untagged_sent) {
  const size_t sent_len = untagged_sent.size();

  // The agenda is the current beam of partial taggings, seeded with a
  // single empty hypothesis.
  std::vector<AgendaItem> agenda;
  agenda.reserve(spec.beam_width);
  agenda.push_back(AgendaItem());
  agenda.back().tagged.reserve(sent_len);

  // Scratch feature vector, cleared and refilled per wordoid.
  UnaryFeatureVec feat_vec_delta;
  std::vector<Analysis>::const_iterator analys_it;
  std::vector<AgendaItem>::const_iterator agenda_it;
  std::vector<Morpheme>::const_iterator wordoid_it;

  for (size_t token_idx = 0; token_idx < sent_len; token_idx++) {
    const std::vector<Analysis> &analyses =
      untagged_sent[token_idx].TheLexicalUnit->TheAnalyses;

    std::vector<AgendaItem> new_agenda;
    new_agenda.reserve(spec.beam_width * analyses.size());

    // Unambiguous (1 analysis) or unknown (0 analyses) tokens cannot change
    // the relative ranking of hypotheses, so extend every item in place and
    // skip scoring and beam pruning entirely.
    if (analyses.size() == 1) {
      extendAgendaAll(agenda, analyses.front());
      continue;
    } else if (analyses.size() == 0) {
      extendAgendaAll(agenda, Optional<Analysis>());
      continue;
    }

    // Cross product: every agenda item x every analysis of this token.
    for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
      for (analys_it = analyses.begin(); analys_it != analyses.end(); analys_it++) {
        const std::vector<Morpheme> &wordoids = analys_it->TheMorphemes;

        new_agenda.push_back(*agenda_it);
        AgendaItem &new_agenda_item = new_agenda.back();
        new_agenda_item.tagged.push_back(*analys_it);

        // Score each wordoid (sub-token morpheme) of the analysis and
        // accumulate into the hypothesis score.
        for (wordoid_it = wordoids.begin(); wordoid_it != wordoids.end(); wordoid_it++) {
          int wordoid_idx = wordoid_it - wordoids.begin();
          feat_vec_delta.clear();
          spec.get_features(new_agenda_item.tagged, untagged_sent,
                            token_idx, wordoid_idx, feat_vec_delta);
          if (TheFlags.getDebug()) {
            FeatureVec fv(feat_vec_delta);
            std::cerr << "Token " << token_idx << "\t\tWordoid " << wordoid_idx << "\n";
            std::cerr << fv;
            std::cerr << "Score: " << weights * feat_vec_delta << "\n";
          }
          new_agenda_item.score += weights * feat_vec_delta;
        }
      }
    }
    // Apply the beam: keep only the beam_width best hypotheses.
    // operator<(AgendaItem, AgendaItem) is inverted (higher score sorts
    // first), so partial_sort_copy retains the highest-scoring items.
    if (TheFlags.getDebug()) {
      std::cerr << "-- Before beam: --\n" << new_agenda;
    }
    size_t new_agenda_size = std::min((size_t)spec.beam_width, new_agenda.size());
    agenda.resize(new_agenda_size);
    std::partial_sort_copy(new_agenda.begin(), new_agenda.end(),
                           agenda.begin(), agenda.end());
    if (TheFlags.getDebug()) {
      std::cerr << "-- After beam: --\n" << agenda;
    }
  }

  spec.clearCache();
  // Best hypothesis is at the front after the final partial sort.
  return agenda.front().tagged;
}
100 | ||||
// Emit one tagged lexical unit; simply forwards to the base-class formatter.
void PerceptronTagger::outputLexicalUnit(
    const LexicalUnit &lexical_unit, const Optional<Analysis> analysis,
    std::ostream &output) {
  StreamTagger::outputLexicalUnit(lexical_unit, analysis, output);
}
106 | ||||
// One averaged-perceptron training step over a single (tagged, untagged)
// sentence pair, using beam search with early update: if the gold-standard
// prefix falls off the beam mid-sentence, the weights are updated
// immediately and the rest of the sentence is abandoned.
// Returns true when the sentence was skipped because the gold analysis was
// not among the untagged analyses (only with --skip-on-error); false
// otherwise. The caller accumulates this as a skip count.
bool PerceptronTagger::trainSentence(
    const TrainingSentence &sentence,
    FeatureVecAverager &avg_weights)
{
  const TaggedSentence &tagged_sent = sentence.first;
  const Sentence &untagged_sent = sentence.second;
  assert(tagged_sent.size() == untagged_sent.size());
  const size_t sent_len = tagged_sent.size();

  // Beam of partial taggings, seeded with one empty hypothesis.
  std::vector<TrainingAgendaItem> agenda;
  agenda.reserve(spec.beam_width);
  agenda.push_back(TrainingAgendaItem());
  agenda.back().tagged.reserve(sent_len);
  // Points at the agenda item whose tagging matches the gold prefix.
  std::vector<TrainingAgendaItem>::const_iterator correct_agenda_it
    = agenda.begin();

  // Gold-standard hypothesis, rebuilt alongside the beam so its feature
  // vector is available for perceptron updates.
  TrainingAgendaItem correct_sentence;
  correct_sentence.tagged.reserve(sent_len);

  UnaryFeatureVec feat_vec_delta;
  std::vector<Analysis>::const_iterator analys_it;
  std::vector<TrainingAgendaItem>::const_iterator agenda_it;
  std::vector<Morpheme>::const_iterator wordoid_it;

  for (size_t token_idx = 0; token_idx < sent_len; token_idx++) {
    //std::cerr << "Token idx: " << token_idx << "\n";
    const TaggedToken &tagged_tok(tagged_sent[token_idx]);
    const StreamedType &untagged_tok(untagged_sent[token_idx]);
    correct_sentence.tagged.push_back(tagged_tok);

    const std::vector<Analysis> &analyses =
      untagged_tok.TheLexicalUnit->TheAnalyses;

    std::vector<TrainingAgendaItem> new_agenda;
    new_agenda.reserve(spec.beam_width * analyses.size());

    if (analyses.size() <= 1 || !tagged_tok) {
      // Case |analyses| = 0, nothing we can do
      // Case !tagged_tok, |analyses| > 0, no point penalising a guess which
      // can only be incorrect when there's no correct answer
      // Case |analyses| = 1, everything will cancel out anyway
      if (analyses.size() == 1) {
        extendAgendaAll(agenda, analyses.front());
        continue;
      } else {
        extendAgendaAll(agenda, Optional<Analysis>());
        continue;
      }
    }

    // Extend every agenda item with every analysis, scoring as we go, and
    // track which extension (if any) continues the gold prefix.
    bool correct_available = false;
    for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
      //std::cerr << *agenda_it;
      for (analys_it = analyses.begin(); analys_it != analyses.end(); analys_it++) {
        const std::vector<Morpheme> &wordoids = analys_it->TheMorphemes;

        new_agenda.push_back(*agenda_it);
        TrainingAgendaItem &new_agenda_item = new_agenda.back();
        new_agenda_item.tagged.push_back(*analys_it);

        for (wordoid_it = wordoids.begin(); wordoid_it != wordoids.end(); wordoid_it++) {
          int wordoid_idx = wordoid_it - wordoids.begin();
          feat_vec_delta.clear();
          spec.get_features(new_agenda_item.tagged, untagged_sent,
                            token_idx, wordoid_idx, feat_vec_delta);
          // Unlike decoding, training keeps the accumulated feature vector
          // (vec) so the perceptron update can be computed later.
          new_agenda_item.vec += feat_vec_delta;
          new_agenda_item.score += weights * feat_vec_delta;
          if (agenda_it == correct_agenda_it && *analys_it == *tagged_tok) {
            correct_sentence = new_agenda_item;
            correct_available = true;
          }
        }
      }
    }
    if (!correct_available) {
      // Gold analysis not present in the untagged input's ambiguity class.
      if (TheFlags.getSkipErrors()) {
        return true;
      } else {
        std::stringstream what_;
        what_ << "Tagged analysis unavailable in untagged/ambigous input.\n";
        what_ << "Available:\n";
        for (analys_it = analyses.begin(); analys_it != analyses.end(); analys_it++) {
          what_ << *analys_it << "\n";
        }
        what_ << "Required: " << *tagged_tok << "\n";
        what_ << "Rerun with --skip-on-error to skip this sentence.";
        throw Apertium::Exception::PerceptronTagger::CorrectAnalysisUnavailable(what_);
      }
    }
    // Apply the beam
    //std::cerr << "-- Before beam: --\n" << new_agenda;
    size_t new_agenda_size = std::min((size_t)spec.beam_width, new_agenda.size());
    agenda.resize(new_agenda_size);
    std::partial_sort_copy(new_agenda.begin(), new_agenda.end(),
                           agenda.begin(), agenda.end());
    //std::cerr << "-- After beam: --\n" << agenda;

    // Early update "fallen off the beam"
    bool any_match = false;
    for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) {
      if (agenda_it->tagged == correct_sentence.tagged) {
        correct_agenda_it = agenda_it;
        any_match = true;
        break;
      }
    }
    if (!any_match) {
      // The gold prefix was pruned: update weights now (penalise the current
      // best hypothesis, reward the gold one) and stop on this sentence.
      /*std::cerr << "Early update time!\n";
      std::cerr << "Before:\n" << weights << "\n";
      std::cerr << "Incorrect:\n" << agenda.front().vec << "\n";
      std::cerr << "Correct:\n" << correct_sentence.vec << "\n";*/
      avg_weights -= agenda.front().vec;
      avg_weights += correct_sentence.vec;
      avg_weights.incIteration();
      //std::cerr << "After:\n" << weights << "\n";
      return false;
    }
  }
  // Normal update
  /*std::cerr << "Best match:\n" << agenda.front().tagged << "\n\n";
  std::cerr << "Correct:\n" << correct_sentence.tagged << "\n\n";*/
  if (agenda.front().tagged != correct_sentence.tagged) {
    // Full-sentence decode disagreed with gold: standard perceptron update.
    /*std::cerr << "Normal update time!\n";
    std::cerr << "Before:\n" << weights << "\n";
    std::cerr << "Incorrect:\n" << agenda.front().vec << "\n";
    std::cerr << "Correct:\n" << correct_sentence.vec << "\n";*/
    avg_weights -= agenda.front().vec;
    avg_weights += correct_sentence.vec;
    avg_weights.incIteration();
    //std::cerr << "After:\n" << weights << "\n";
  }
  return false;
}
240 | ||||
// Single-stream training is not supported for the perceptron tagger; the
// two-stream overload below is the real entry point.
void PerceptronTagger::train(Stream&) {} // dummy
242 | ||||
243 | void PerceptronTagger::train( | |||
244 | Stream &tagged, | |||
245 | Stream &untagged, | |||
246 | int iterations) { | |||
247 | FeatureVecAverager avg_weights(weights); | |||
248 | TrainingCorpus tc(tagged, untagged, TheFlags.getSkipErrors(), TheFlags.getSentSeg()); | |||
249 | size_t avail_skipped; | |||
| ||||
250 | for (int i = 0; i < iterations; i++) { | |||
251 | std::cerr << "Iteration " << i + 1 << " of " << iterations << "\n"; | |||
252 | avail_skipped = 0; | |||
253 | tc.shuffle(); | |||
254 | std::vector<TrainingSentence>::const_iterator si; | |||
255 | for (si = tc.sentences.begin(); si != tc.sentences.end(); si++) { | |||
256 | avail_skipped += trainSentence(*si, avg_weights); | |||
257 | spec.clearCache(); | |||
258 | } | |||
259 | } | |||
260 | avg_weights.average(); | |||
261 | if (avail_skipped) { | |||
| ||||
262 | std::cerr << "Skipped " << tc.skipped << " sentences due to token " | |||
263 | << "misalignment and " << avail_skipped << " sentences due to " | |||
264 | << "tagged token being unavailable in untagged file out of " | |||
265 | << tc.sentences.size() << " total sentences.\n"; | |||
266 | } | |||
267 | //std::cerr << *this; | |||
268 | } | |||
269 | ||||
// Write the model (spec, then weights) to a binary stream; the order must
// match deserialise() below.
void PerceptronTagger::serialise(std::ostream &serialised) const
{
  spec.serialise(serialised);
  weights.serialise(serialised);
};
275 | ||||
// Read the model back in the same order serialise() wrote it: spec first,
// then weights.
void PerceptronTagger::deserialise(std::istream &serialised)
{
  spec.deserialise(serialised);
  weights.deserialise(serialised);
};
281 | ||||
282 | template <typename T> void | |||
283 | PerceptronTagger::extendAgendaAll( | |||
284 | std::vector<T> &agenda, | |||
285 | Optional<Analysis> analy) { | |||
286 | typename std::vector<T>::iterator agenda_it; | |||
287 | for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) { | |||
288 | agenda_it->tagged.push_back(analy); | |||
289 | } | |||
290 | } | |||
291 | ||||
292 | std::ostream& | |||
293 | operator<<(std::ostream &out, const TaggedSentence &tagged) { | |||
294 | TaggedSentence::const_iterator tsi; | |||
295 | for (tsi = tagged.begin(); tsi != tagged.end(); tsi++) { | |||
296 | if (*tsi) { | |||
297 | out << **tsi; | |||
298 | } else { | |||
299 | out << "*"; | |||
300 | } | |||
301 | out << " "; | |||
302 | } | |||
303 | return out; | |||
304 | } | |||
305 | ||||
306 | std::ostream& | |||
307 | operator<<(std::ostream &out, const PerceptronTagger::TrainingAgendaItem &tai) { | |||
308 | out << "Score: " << tai.score << "\n"; | |||
309 | out << "Sentence: " << tai.tagged << "\n"; | |||
310 | out << "\n"; | |||
311 | out << "Vector:\n" << tai.vec; | |||
312 | return out; | |||
313 | } | |||
314 | ||||
315 | std::ostream& | |||
316 | operator<<(std::ostream &out, const std::vector<PerceptronTagger::TrainingAgendaItem> &agenda) { | |||
317 | std::vector<PerceptronTagger::TrainingAgendaItem>::const_iterator agenda_it; | |||
318 | for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) { | |||
319 | out << *agenda_it; | |||
320 | } | |||
321 | out << "\n\n"; | |||
322 | return out; | |||
323 | } | |||
324 | ||||
325 | std::ostream& | |||
326 | operator<<(std::ostream &out, const PerceptronTagger::AgendaItem &ai) { | |||
327 | out << "Score: " << ai.score << "\n"; | |||
328 | out << "Sentence: " << ai.tagged << "\n"; | |||
329 | return out; | |||
330 | } | |||
331 | ||||
332 | std::ostream& | |||
333 | operator<<(std::ostream &out, const std::vector<PerceptronTagger::AgendaItem> &agenda) { | |||
334 | std::vector<PerceptronTagger::AgendaItem>::const_iterator agenda_it; | |||
335 | for (agenda_it = agenda.begin(); agenda_it != agenda.end(); agenda_it++) { | |||
336 | out << *agenda_it; | |||
337 | } | |||
338 | out << "\n\n"; | |||
339 | return out; | |||
340 | } | |||
341 | ||||
342 | bool operator<(const PerceptronTagger::AgendaItem &a, | |||
343 | const PerceptronTagger::AgendaItem &b) { | |||
344 | return a.score > b.score; | |||
345 | }; | |||
346 | } |