Welcome to mirror list, hosted at ThFree Co, Russian Federation.

backoff_reunification.cc « interpolate « lm - github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 7885f9aa8cb5bf81903e51c95e4779fdd77faf18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include "backoff_reunification.hh"
#include "../common/model_buffer.hh"
#include "../common/ngram_stream.hh"
#include "../common/ngram.hh"
#include "../common/compare.hh"

#include <algorithm>
#include <cassert>
#include <cstring>

namespace lm {
namespace interpolate {

namespace {
// Chain worker that zips a stream of n-gram probabilities with a parallel
// stream of backoff weights into a single stream of ProbBackoff records.
//
// Both inputs are advanced in lockstep, one record each per output record.
// The code only verifies that the two inputs have the same length; it relies
// on the caller to supply them in matching n-gram order.
class MergeWorker {
public:
  // The first argument is the n-gram order of the records being merged. It is
  // accepted for interface compatibility with the caller, but Run() does not
  // need it, so it is not stored (an unused private field only produced
  // compiler warnings).
  // prob_pos / boff_pos: chain positions to read probabilities and backoff
  // weights from, respectively. They are copied here and only opened in Run().
  MergeWorker(std::size_t /* order */,
              const util::stream::ChainPosition &prob_pos,
              const util::stream::ChainPosition &boff_pos)
      : prob_pos_(prob_pos), boff_pos_(boff_pos) {}

  // Writes one ProbBackoff record to `position` for each pair of input
  // records, then poisons the output stream to signal end of data.
  // Throws (via UTIL_THROW_IF2) if one input runs out before the other.
  void Run(const util::stream::ChainPosition &position) {
    lm::NGramStream<ProbBackoff> stream(position);

    lm::NGramStream<float> prob_input(prob_pos_);
    util::stream::Stream boff_input(boff_pos_);
    for (; prob_input && boff_input; ++prob_input, ++boff_input, ++stream) {
      // Copy the n-gram's words verbatim into the output record.
      std::copy(prob_input->begin(), prob_input->end(), stream->begin());
      // Clamp so the stored value is never above 0.
      // NOTE(review): presumably guards against floating-point error pushing
      // a log-probability slightly above log(1) = 0 -- confirm.
      stream->Value().prob = std::min(0.0f, prob_input->Value());
      // Each backoff record is read as a single float. memcpy out of the
      // untyped stream buffer avoids the aliasing/alignment hazards of
      // dereferencing a reinterpret_cast pointer.
      float backoff;
      std::memcpy(&backoff, boff_input.Get(), sizeof(backoff));
      stream->Value().backoff = backoff;
    }
    UTIL_THROW_IF2(prob_input || boff_input,
                   "Streams were not the same size during merging");
    stream.Poison();
  }

private:
  util::stream::ChainPosition prob_pos_;
  util::stream::ChainPosition boff_pos_;
};
}

// Attaches one MergeWorker per n-gram order. output_chains[i] receives the
// merged (probability, backoff) records for order i + 1, read from
// prob_pos[i] and boff_pos[i].
//
// Since we are *adding* something to the output chain here, we pass in the
// chain itself so that we can safely add a new step to the chain without
// creating a deadlock situation (since creating a new ChainPosition will
// make a new input/output pair---we want that position to be created
// *here*, not before).
void ReunifyBackoff(util::stream::ChainPositions &prob_pos,
                    util::stream::ChainPositions &boff_pos,
                    util::stream::Chains &output_chains) {
  assert(prob_pos.size() == boff_pos.size());
  // output_chains is indexed in lockstep with the inputs below; a shorter
  // output array would otherwise be indexed out of bounds.
  assert(output_chains.size() >= prob_pos.size());

  for (std::size_t i = 0; i < prob_pos.size(); ++i)
    output_chains[i] >> MergeWorker(i + 1, prob_pos[i], boff_pos[i]);
}
}
}