diff --git a/src/lstm/recodebeam.cpp b/src/lstm/recodebeam.cpp
index 2a8a0fcdca..ccecfcd2d8 100644
--- a/src/lstm/recodebeam.cpp
+++ b/src/lstm/recodebeam.cpp
@@ -34,6 +34,15 @@ const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {
 
 static const char *kNodeContNames[] = {"Anything", "OnlyDup", "NoDup"};
 
+// Setting to enable the diplopia removal functionality.
+// TODO: make this a config setting.
+static const bool kRemoveDiplopia = true;
+// The maximum diplopia gap is the maximum number of timesteps
+// which the peak values of possible diplopia candidates can be apart
+// in order to be considered as genuine diplopia.
+// TODO: make this a config setting.
+static const int kMaxDiplopiaGap = 2;
+
 // Prints debug details of the node.
 void RecodeNode::Print(int null_char, const UNICHARSET &unicharset,
                        int depth) const {
@@ -88,15 +97,385 @@ void RecodeBeamSearch::Decode(const NetworkIO &output, double dict_ratio,
   if (lstm_choice_mode) {
     timesteps.clear();
   }
+
+  // Compute and save the top-N choices for all timesteps.
+  save_topn_.clear();
   for (int t = 0; t < width; ++t) {
-    ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
-    DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
-               charset);
+    ComputeAndSaveTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
+  }
+
+  // Eliminate potential diplopia cases if enabled.
+  if (kRemoveDiplopia) {
+    FindAndRemoveDiplopia(width);
+  }
+
+  for (int t = 0; t < width; ++t) {
+    FinalizeTopNFlags(t, output.NumFeatures(), kBeamWidths[0]);
+    DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert, charset);
     if (lstm_choice_mode) {
       SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t);
     }
   }
 }
+
+// Determines the top_n choices for one timestep and appends them, best first, to save_topn_.
+void RecodeBeamSearch::ComputeAndSaveTopN(const float *outputs, int num_outputs, int top_n) {
+
+  top_heap_.clear();
+  for (int i = 0; i < num_outputs; ++i) {
+    if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) {
+      TopPair entry(outputs[i], i);
+      top_heap_.Push(&entry);
+      if (top_heap_.size() > top_n) {
+        top_heap_.Pop(&entry);
+      }
+    }
+  }
+
+  std::vector<TopPair> topn_for_step;
+  topn_for_step.resize(top_heap_.size());
+
+  int i = top_heap_.size();
+  while (!top_heap_.empty()) {
+    TopPair entry;
+    top_heap_.Pop(&entry);
+    TopPair save_entry(entry.key(), entry.data());
+    --i;
+    topn_for_step[i] = save_entry;
+  }
+
+  save_topn_.push_back(topn_for_step);
+}
+
+// Searches for potential diplopia cases and, for each one found, zeros the
+// saved score of the lower-scored of the two candidate characters.
+void RecodeBeamSearch::FindAndRemoveDiplopia(int width) {
+
+  int latest = -1;
+  int timestep = LocateDiplopia(width, latest);
+  while (timestep >= 0) {
+    latest = timestep;
+    RemoveChosenDiplopiaChar(width);
+    timestep = LocateDiplopia(width, latest);
+  }
+
+}
+
+// Returns the first timestep after latest with apparent diplopia, or -1 if there is none.
+int RecodeBeamSearch::LocateDiplopia(int width, int latest) {
+
+  int retval = -1;
+
+  for (int t = latest + 1; t < width; ++t) {
+    const std::vector<TopPair> &topn_for_step = save_topn_[t];
+    int nbr_entries_topn = topn_for_step.size();
+    first_diplopia_char_ = -1;
+    second_diplopia_char_ = -1;
+    if (nbr_entries_topn >= 2) {
+      if (topn_for_step[0].data() != null_char_ && topn_for_step[1].data() != null_char_) {
+        first_diplopia_char_ = topn_for_step[0].data();
+        second_diplopia_char_ = topn_for_step[1].data();
+        DetermineDiplopiaCharToRemove(width, t);
+        if (diplopia_gap_ <= kMaxDiplopiaGap) {
+          retval = t;
+        }
+      }
+    }
+    if (retval >= 0) {
+      break;
+    }
+  }
+
+  return retval;
+}
+
+// Measures both diplopia candidates around the given timestep, selects the
+// one with the lower peak score for removal (the second candidate on a tie),
+// leaves its extent in start_diplopia_/end_diplopia_ and stores the distance
+// between the two peaks in diplopia_gap_.
+void RecodeBeamSearch::DetermineDiplopiaCharToRemove(int width, int timestep) {
+
+  DetermineDiplopiaDimensions(width, timestep, first_diplopia_char_);
+  int start_first_diplopia = start_diplopia_;
+  int end_first_diplopia = end_diplopia_;
+  float first_diplopia_max = diplopia_max_;
+  int first_diplopia_max_timestep = diplopia_max_timestep_;
+
+  DetermineDiplopiaDimensions(width, timestep, second_diplopia_char_);
+  int start_second_diplopia = start_diplopia_;
+  int end_second_diplopia = end_diplopia_;
+  float second_diplopia_max = diplopia_max_;
+  int second_diplopia_max_timestep = diplopia_max_timestep_;
+
+  if (first_diplopia_max >= second_diplopia_max) {
+    chosen_diplopia_char_ = second_diplopia_char_;
+    start_diplopia_ = start_second_diplopia;
+    end_diplopia_ = end_second_diplopia;
+    diplopia_max_ = second_diplopia_max;
+  } else {
+    chosen_diplopia_char_ = first_diplopia_char_;
+    start_diplopia_ = start_first_diplopia;
+    end_diplopia_ = end_first_diplopia;
+    diplopia_max_ = first_diplopia_max;
+  }
+
+  if (first_diplopia_max_timestep <= second_diplopia_max_timestep) {
+    diplopia_gap_ = second_diplopia_max_timestep - first_diplopia_max_timestep;
+  } else {
+    diplopia_gap_ = first_diplopia_max_timestep - second_diplopia_max_timestep;
+  }
+}
+
+// Demotes chosen_diplopia_char_ in the saved top-N list of every timestep in
+// [start_diplopia_, end_diplopia_]: the entry is moved to the end of the list
+// and its score is set to zero. The width argument is currently unused.
+void RecodeBeamSearch::RemoveChosenDiplopiaChar(int width) {
+
+  for (int t = start_diplopia_; t <= end_diplopia_; ++t) {
+    int nbr_entries_topn = save_topn_[t].size();
+    for (int i = 0; i < nbr_entries_topn; ++i) {
+      if (save_topn_[t][i].data() != null_char_) {
+        if (save_topn_[t][i].data() == chosen_diplopia_char_){
+          for (int j = i; j < nbr_entries_topn - 1; j++) {
+            save_topn_[t][j].data() = save_topn_[t][j + 1].data();
+            save_topn_[t][j].key() = save_topn_[t][j + 1].key();
+          }
+          save_topn_[t][nbr_entries_topn - 1].data() = chosen_diplopia_char_;
+          save_topn_[t][nbr_entries_topn - 1].key() = 0.0F;
+          break;
+        }
+      }
+    }
+  }
+
+}
+
+// Follows the saved scores of diplopia_char backwards and forwards from
+// timestep to find the hump it belongs to. Results are left in
+// start_diplopia_ and end_diplopia_ (first and last timestep of the hump),
+// diplopia_max_ (peak score) and diplopia_max_timestep_ (where the peak is).
+// Timesteps outside [0, width) count as score zero.
+void RecodeBeamSearch::DetermineDiplopiaDimensions(int width, int timestep, int diplopia_char) {
+
+  enum CharShape {
+    CS_AT_PEAK,      // at peak at current timestep
+    CS_PEAK_LATER,   // peak is at a later timestep
+    CS_PEAK_EARLIER, // peak is at an earlier timestep
+    CS_IN_VALLEY,    // in valley at current timestep
+    CS_COUNT
+  };
+
+  CharShape shape;
+
+  float key_current = GetKeyForTimestep(width, timestep, diplopia_char);
+  float key_prev = key_current;
+
+  // Skip over plateaus on both sides. The explicit range checks stop the
+  // scan when key_current is 0, the value reported for out-of-range steps.
+  int later_ctr = timestep + 1;
+  float key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+  while (later_ctr < width && key_later == key_current) {
+    ++later_ctr;
+    key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+  }
+  int earlier_ctr = timestep - 1;
+  float key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+  while (earlier_ctr >= 0 && key_earlier == key_current) {
+    --earlier_ctr;
+    key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+  }
+
+  if (key_later < key_current && key_earlier < key_current) {
+    shape = CS_AT_PEAK;
+  } else if (key_later > key_current && key_earlier < key_current) {
+    shape = CS_PEAK_LATER;
+  } else if (key_later < key_current && key_earlier > key_current) {
+    shape = CS_PEAK_EARLIER;
+  } else {
+    shape = CS_IN_VALLEY;
+  }
+
+  switch (shape) {
+
+    case CS_AT_PEAK:
+
+      diplopia_max_ = key_current;
+      diplopia_max_timestep_ = timestep;
+
+      start_diplopia_ = timestep;
+      key_prev = key_current;
+      earlier_ctr = timestep - 1;
+      key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      while (key_earlier > 0.0F && key_earlier <= key_prev) {
+        start_diplopia_ = earlier_ctr;
+        key_prev = key_earlier;
+        --earlier_ctr;
+        key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      }
+
+      end_diplopia_ = timestep;
+      key_prev = key_current;
+      later_ctr = timestep + 1;
+      key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      while (key_later > 0.0F && key_later <= key_prev) {
+        end_diplopia_ = later_ctr;
+        key_prev = key_later;
+        ++later_ctr;
+        key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      }
+      break;
+
+    case CS_PEAK_LATER:
+
+      start_diplopia_ = timestep;
+      key_prev = key_current;
+      earlier_ctr = timestep - 1;
+      key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      while (key_earlier > 0.0F && key_earlier <= key_prev) {
+        start_diplopia_ = earlier_ctr;
+        key_prev = key_earlier;
+        --earlier_ctr;
+        key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      }
+
+      key_prev = key_current;
+      later_ctr = timestep + 1;
+      key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      while (key_later > 0.0F && key_later >= key_prev) {
+        key_prev = key_later;
+        ++later_ctr;
+        key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      }
+
+      diplopia_max_ = key_prev;
+      diplopia_max_timestep_ = later_ctr - 1;
+
+      // later_ctr is one past the peak and may equal width, so start from the
+      // peak itself; the loop below extends the range while scores descend.
+      end_diplopia_ = later_ctr - 1;
+      while (key_later > 0.0F && key_later <= key_prev) {
+        end_diplopia_ = later_ctr;
+        key_prev = key_later;
+        ++later_ctr;
+        key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      }
+      break;
+
+    case CS_PEAK_EARLIER:
+
+      end_diplopia_ = timestep;
+      key_prev = key_current;
+      later_ctr = timestep + 1;
+      key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      while (key_later > 0.0F && key_later <= key_prev) {
+        end_diplopia_ = later_ctr;
+        key_prev = key_later;
+        ++later_ctr;
+        key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      }
+
+      key_prev = key_current;
+      earlier_ctr = timestep - 1;
+      key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      while (key_earlier > 0.0F && key_earlier >= key_prev) {
+        key_prev = key_earlier;
+        --earlier_ctr;
+        key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      }
+
+      diplopia_max_ = key_prev;
+      diplopia_max_timestep_ = earlier_ctr + 1;
+
+      // earlier_ctr is one before the peak and may be -1, so start from the
+      // peak itself; the loop below extends the range while scores descend.
+      start_diplopia_ = earlier_ctr + 1;
+      while (key_earlier > 0.0F && key_earlier <= key_prev) {
+        start_diplopia_ = earlier_ctr;
+        key_prev = key_earlier;
+        --earlier_ctr;
+        key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      }
+      break;
+
+    case CS_IN_VALLEY:
+
+      diplopia_max_ = key_current;
+      diplopia_max_timestep_ = timestep;
+
+      start_diplopia_ = timestep;
+      key_prev = key_current;
+      earlier_ctr = timestep - 1;
+      key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      while (key_earlier > 0.0F && key_earlier == key_prev) {
+        start_diplopia_ = earlier_ctr;
+        key_prev = key_earlier;
+        --earlier_ctr;
+        key_earlier = GetKeyForTimestep(width, earlier_ctr, diplopia_char);
+      }
+
+      end_diplopia_ = timestep;
+      key_prev = key_current;
+      later_ctr = timestep + 1;
+      key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      while (key_later > 0.0F && key_later == key_prev) {
+        end_diplopia_ = later_ctr;
+        key_prev = key_later;
+        ++later_ctr;
+        key_later = GetKeyForTimestep(width, later_ctr, diplopia_char);
+      }
+      break;
+  }
+
+}
+
+// Returns the saved top-N score of diplopia_char at the given timestep, or
+// 0 when the timestep is outside [0, width), the character is not in the
+// saved list for that timestep, or the character is the null character.
+float RecodeBeamSearch::GetKeyForTimestep(int width, int timestep, int diplopia_char) {
+
+  float retval = 0.0F;
+  if (timestep >= width || timestep < 0 ) {
+    return retval;
+  }
+  const std::vector<TopPair> &topn_for_step = save_topn_[timestep];
+  int nbr_entries_topn = topn_for_step.size();
+  for (int i = 0; i < nbr_entries_topn; ++i) {
+    if (topn_for_step[i].data() != null_char_) {
+      if (topn_for_step[i].data() == diplopia_char){
+        retval = topn_for_step[i].key();
+        break;
+      }
+    }
+  }
+  return retval;
+
+}
+
+// Fills top_n_flags_ with enum values for the status of each character at timestep t.
+void RecodeBeamSearch::FinalizeTopNFlags(int t, int num_outputs, int top_n) {
+
+  top_n_flags_.clear();
+  top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
+
+  const std::vector<TopPair> &topn_for_step = save_topn_[t];
+  int nbr_entries = topn_for_step.size();
+  for (int i = 0; i < nbr_entries; ++i) {
+    TopPair entry = topn_for_step[i];
+    if (i > 1) {
+      top_n_flags_[entry.data()] = TN_TOPN;
+    } else {
+      top_n_flags_[entry.data()] = TN_TOP2;
+      if (i == 0) {
+        top_code_ = entry.data();
+      } else {
+        second_code_ = entry.data();
+      }
+    }
+  }
+
+  top_n_flags_[null_char_] = TN_TOP2;
+}
+
 void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float> &output,
                               double dict_ratio, double cert_offset,
                               double worst_dict_cert,
@@ -188,7 +567,7 @@ void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts,
                                                std::vector<int> *ends,
                                                std::vector<int> *char_bounds_,
                                                int maxWidth) {
-  char_bounds_->push_back(0);
+  char_bounds_->push_back((*starts)[0]);
   for (unsigned i = 0; i < ends->size(); ++i) {
     int middle = ((*starts)[i + 1] - (*ends)[i]) / 2;
     char_bounds_->push_back((*ends)[i] + middle);
@@ -243,6 +622,7 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX &line_box,
                                               int lstm_choice_mode) {
   words->truncate(0);
   std::vector<int> unichar_ids;
+  std::vector<int> unichar_codes;
   std::vector<float> certs;
   std::vector<float> ratings;
   std::vector<int> xcoords;
@@ -253,7 +633,7 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX &line_box,
   if (debug) {
     DebugPath(unicharset, best_nodes);
     ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
-                            &xcoords);
+                            &xcoords, nullptr, &unichar_codes);
     tprintf("\nSecond choice path:\n");
     DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
                      xcoords);
@@ -262,7 +642,7 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX &line_box,
   // Coordinates of every chosen character, to match the alternative choices to
   // it.
   ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
-                          &character_boundaries_);
+                          &character_boundaries_, &unichar_codes);
   int num_ids = unichar_ids.size();
   if (debug) {
     DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
@@ -568,8 +948,14 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds(
     const std::vector<const RecodeNode *> &best_nodes,
     std::vector<int> *unichar_ids, std::vector<float> *certs,
     std::vector<float> *ratings, std::vector<int> *xcoords,
-    std::vector<int> *character_boundaries) {
+    std::vector<int> *character_boundaries,
+    std::vector<int> *codes) {
+
   unichar_ids->clear();
+  // codes is optional (it defaults to nullptr in the declaration).
+  if (codes != nullptr) {
+    codes->clear();
+  }
   certs->clear();
   ratings->clear();
   xcoords->clear();
@@ -588,8 +974,8 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds(
       }
       rating -= cert;
     }
-    starts.push_back(t);
     if (t < width) {
+      starts.push_back(t);
       int unichar_id = best_nodes[t]->unichar_id;
       if (unichar_id == UNICHAR_SPACE && !certs->empty() &&
           best_nodes[t]->permuter != NO_PERM) {
@@ -603,6 +989,10 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds(
         rating = 0.0;
       }
       unichar_ids->push_back(unichar_id);
+      // codes is optional; record the code only when the caller asked for it.
+      if (codes != nullptr) {
+        codes->push_back(best_nodes[t]->code);
+      }
       xcoords->push_back(t);
       do {
         double cert = best_nodes[t++]->certainty;
@@ -744,6 +1134,7 @@ void RecodeBeamSearch::DecodeStep(const float *outputs, int t,
                                   double dict_ratio, double cert_offset,
                                   double worst_dict_cert,
                                   const UNICHARSET *charset, bool debug) {
+  current_timestep_ = t;
   if (t == static_cast<int>(beam_.size())) {
     beam_.push_back(new RecodeBeam);
   }
@@ -1157,6 +1548,7 @@ void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id,
     RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert,
                     score, prev, initial_dawgs,
                     ComputeCodeHash(code, false, prev));
+    node.timestep = current_timestep_;
     *best_initial_dawg = node;
   }
 }
@@ -1203,6 +1595,7 @@ void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
     uint64_t hash = ComputeCodeHash(code, dup, prev);
     RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
                     dup, cert, score, prev, d, hash);
+    node.timestep = current_timestep_;
     if (UpdateHeapIfMatched(&node, heap)) {
       return;
     }
diff --git a/src/lstm/recodebeam.h b/src/lstm/recodebeam.h
index 316fb16efa..bc9258f7c9 100644
--- a/src/lstm/recodebeam.h
+++ b/src/lstm/recodebeam.h
@@ -102,7 +102,8 @@ struct RecodeNode {
       , score(0.0f)
       , prev(nullptr)
       , dawgs(nullptr)
-      , code_hash(0) {}
+      , code_hash(0)
+      , timestep(-1) {}
   RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start,
              bool word_start, bool end, bool dup, float cert, float s,
              const RecodeNode *p, DawgPositionVector *d, uint64_t hash)
@@ -117,7 +118,8 @@ struct RecodeNode {
       , score(s)
       , prev(p)
       , dawgs(d)
-      , code_hash(hash) {}
+      , code_hash(hash)
+      , timestep(-1) {}
   // NOTE: If we could use C++11, then this would be a move constructor.
   // Instead we have copy constructor that does a move!! This is because we
   // don't want to copy the whole DawgPositionVector each time, and true
@@ -172,6 +174,8 @@ struct RecodeNode {
   // A hash of all codes in the prefix and this->code as well. Used for
   // duplicate path removal.
   uint64_t code_hash;
+  // The timestep that this node belongs to, or -1 if it has not been set.
+  int timestep;
 };
 
 using RecodePair = KDPairInc<double, RecodeNode>;
@@ -304,7 +308,8 @@ class TESS_API RecodeBeamSearch {
   static void ExtractPathAsUnicharIds(const std::vector<const RecodeNode *> &best_nodes,
                                       std::vector<int> *unichar_ids, std::vector<float> *certs,
                                       std::vector<float> *ratings, std::vector<int> *xcoords,
-                                      std::vector<int> *character_boundaries = nullptr);
+                                      std::vector<int> *character_boundaries = nullptr,
+                                      std::vector<int> *codes = nullptr);
 
   // Sets up a word with the ratings matrix and fake blobs with boxes in the
   // right places.
@@ -319,6 +324,23 @@ class TESS_API RecodeBeamSearch {
   void ComputeSecTopN(std::unordered_set<int> *exList, const float *outputs,
                       int num_outputs, int top_n);
 
+  // Determines the top_n choices for one timestep and appends them to save_topn_.
+  void ComputeAndSaveTopN(const float *outputs, int num_outputs, int top_n);
+  // Searches for potential diplopia cases and demotes the lower scored character.
+  void FindAndRemoveDiplopia(int width);
+  // Returns the first timestep after latest with apparent diplopia, or -1.
+  int LocateDiplopia(int width, int latest);
+  // Determines which of the two diplopia candidates at timestep to remove.
+  void DetermineDiplopiaCharToRemove(int width, int timestep);
+  // Demotes the selected diplopia character over its recorded timestep range.
+  void RemoveChosenDiplopiaChar(int width);
+  // Determines start, end and peak of a specific character around timestep.
+  void DetermineDiplopiaDimensions(int width, int timestep, int diplopia_char);
+  // Returns the saved key (score) for the given character and timestep, 0 if absent.
+  float GetKeyForTimestep(int width, int timestep, int diplopia_char);
+  // Fills top_n_flags_ with enum values for the status of each character.
+  void FinalizeTopNFlags(int t, int num_outputs, int top_n);
+
   // Adds the computation for the current time-step to the beam. Call at each
   // time-step in sequence from left to right. outputs is the activation vector
   // for the current timestep.
@@ -422,6 +444,22 @@ class TESS_API RecodeBeamSearch {
   bool is_simple_text_;
   // The encoded (class label) of the null/reject character.
   int null_char_;
+  // Saved top-N characters and scores for all timesteps (one inner vector
+  // per timestep), for later use in diplopia detection.
+  std::vector<std::vector<TopPair>> save_topn_;
+  // Timestep currently being decoded; copied into newly created heap nodes.
+  int current_timestep_ = -1;
+
+  // Working state used to identify potential diplopia cases.
+  int first_diplopia_char_ = -1;
+  int second_diplopia_char_ = -1;
+  int chosen_diplopia_char_ = -1;
+  int start_diplopia_ = 0;
+  int end_diplopia_ = -1;
+  int diplopia_gap_ = 0;
+  float diplopia_max_ = 0.0f;
+  int diplopia_max_timestep_ = -1;
+
 };
 
 } // namespace tesseract.
diff --git a/test b/test index 2761899921..ebaee164bb 160000 --- a/test +++ b/test @@ -1 +1 @@ -Subproject commit 2761899921c08014cf9dbf3b63592237fb9e6ecb +Subproject commit ebaee164bb39fe55b601b95b92db686d3c7da265