Welcome to mirror list, hosted at ThFree Co, Russian Federation.

normalize_test.cc « interpolate « lm - github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: fe220f36274ae3a297992cb385537b6c90a04f7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include "lm/interpolate/normalize.hh"

#include "lm/interpolate/interpolate_info.hh"
#include "lm/interpolate/merge_probabilities.hh"
#include "lm/common/ngram_stream.hh"
#include "util/stream/chain.hh"
#include "util/stream/multi_stream.hh"

#define BOOST_TEST_MODULE NormalizeTest
#include <boost/test/unit_test.hpp>

namespace lm { namespace interpolate { namespace {

// log without backoff
const float kInputs[] = {-0.3, 1.2, -9.8, 4.0, -7.0, 0.0};

class WriteInput {
  public:
    WriteInput() {}
    void Run(const util::stream::ChainPosition &to) {
      util::stream::Stream out(to);
      for (WordIndex i = 0; i < sizeof(kInputs) / sizeof(float); ++i, ++out) {
        memcpy(out.Get(), &i, sizeof(WordIndex));
        memcpy((uint8_t*)out.Get() + sizeof(WordIndex), &kInputs[i], sizeof(float));
      }
      out.Poison();
    }
};

void CheckOutput(const util::stream::ChainPosition &from) {
  NGramStream<float> in(from);
  float sum = 0.0;
  for (WordIndex i = 0; i < sizeof(kInputs) / sizeof(float) - 1 /* <s> at the end */; ++i) {
    sum += pow(10.0, kInputs[i]);
  }
  sum = log10(sum);
  BOOST_REQUIRE(in);
  BOOST_CHECK_CLOSE(kInputs[0] - sum, in->Value(), 0.0001);
  BOOST_REQUIRE(++in);
  BOOST_CHECK_CLOSE(kInputs[1] - sum, in->Value(), 0.0001);
  BOOST_REQUIRE(++in);
  BOOST_CHECK_CLOSE(kInputs[2] - sum, in->Value(), 0.0001);
  BOOST_REQUIRE(++in);
  BOOST_CHECK_CLOSE(kInputs[3] - sum, in->Value(), 0.0001);
  BOOST_REQUIRE(++in);
  BOOST_CHECK_CLOSE(kInputs[4] - sum, in->Value(), 0.0001);
  BOOST_REQUIRE(++in);
  BOOST_CHECK_CLOSE(kInputs[5] - sum, in->Value(), 0.0001);
  BOOST_CHECK(!++in);
}

BOOST_AUTO_TEST_CASE(Unigrams) {
  InterpolateInfo info;
  info.lambdas.push_back(2.0);
  info.lambdas.push_back(-0.1);
  info.orders.push_back(1);
  info.orders.push_back(1);

  BOOST_CHECK_EQUAL(0, MakeEncoder(info, 1).EncodedLength());

  // No backoffs.
  util::stream::Chains blank(0);
  util::FixedArray<util::stream::ChainPositions> models_by_order(2);
  models_by_order.push_back(blank);
  models_by_order.push_back(blank);

  util::stream::Chains merged_probabilities(1);
  util::stream::Chains probabilities_out(1);
  util::stream::Chains backoffs_out(0);

  merged_probabilities.push_back(util::stream::ChainConfig(sizeof(WordIndex) + sizeof(float) + sizeof(float), 2, 24));
  probabilities_out.push_back(util::stream::ChainConfig(sizeof(WordIndex) + sizeof(float), 2, 100));

  merged_probabilities[0] >> WriteInput();
  Normalize(info, models_by_order, merged_probabilities, probabilities_out, backoffs_out);

  util::stream::ChainPosition checker(probabilities_out[0].Add());

  merged_probabilities >> util::stream::kRecycle;
  probabilities_out >> util::stream::kRecycle;

  CheckOutput(checker);
  probabilities_out.Wait();
}

}}} // namespaces