Welcome to mirror list, hosted at ThFree Co, Russian Federation.

tune_instances_test.cc « interpolate « lm - github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 73a4620fa1c73b6a47319ab9fce8f8f76289a581 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include "lm/interpolate/tune_instances.hh"

#include "util/file.hh"
#include "util/file_stream.hh"
#include "util/stream/chain.hh"
#include "util/stream/config.hh"
#include "util/stream/typed_stream.hh"
#include "util/string_piece.hh"

#define BOOST_TEST_MODULE InstanceTest
#include <boost/test/unit_test.hpp>

#include <vector>

#include <math.h>

namespace lm { namespace interpolate { namespace {

// End-to-end check of Instances on the two toy models.  A one-line tuning
// corpus ("c") is written to a temporary file, the toy0/toy1 model paths are
// handed to Instances together with deliberately tiny buffer sizes, and the
// resulting unigram matrix, per-instance backoffs and extension stream are
// compared against hard-coded expected values.
//
// The test data directory defaults to ../common/test_data and is overridden
// by the first command-line argument when exactly one argument is supplied.
BOOST_AUTO_TEST_CASE(Toy) {
  // Percent tolerance used by every BOOST_CHECK_CLOSE below.
  const double tolerance_percent = 0.001;

  // Tuning corpus: the single line "c".  The writer is scoped so that it is
  // gone before the descriptor is rewound for reading.
  util::scoped_fd corpus(util::MakeTemp("temporary"));
  {
    util::FileStream writer(corpus.get());
    writer << "c\n";
  }
  util::SeekOrThrow(corpus.get(), 0);

  // Locate the directory holding the toy models.
  std::string data_dir("../common/test_data");
  if (boost::unit_test::framework::master_test_suite().argc == 2) {
    data_dir = boost::unit_test::framework::master_test_suite().argv[1];
  }

  // The models live in a subdirectory named after the host byte order.
#if BYTE_ORDER == LITTLE_ENDIAN
  const char *const byte_order = "little";
#elif BYTE_ORDER == BIG_ENDIAN
  const char *const byte_order = "big";
#else
#error "Unsupported byte order."
#endif
  data_dir += "/";
  data_dir += byte_order;
  data_dir += "endian/";

  // The paths are kept in named strings for the whole test.
  // NOTE(review): StringPiece presumably only references their storage.
  std::string toy0_path = data_dir + "toy0";
  std::string toy1_path = data_dir + "toy1";
  std::vector<StringPiece> models;
  models.push_back(toy0_path);
  models.push_back(toy1_path);

  // Tiny buffer sizes.
  InstancesConfig cfg;
  cfg.model_read_chain_mem = 100;
  cfg.extension_write_chain_mem = 100;
  cfg.lazy_memory = 100;
  cfg.sort.temp_prefix = "temporary";
  cfg.sort.buffer_size = 100;
  cfg.sort.total_memory = 1024;

  // Instances receives the released corpus descriptor.
  Instances instances(corpus.release(), models, cfg);

  BOOST_CHECK_EQUAL(1, instances.BOS());

  // Unigram matrix: row = vocabulary id, column = model.  Expected values
  // are log10 probabilities converted to natural log via M_LN10.
  const Matrix &unigrams = instances.LNUnigrams();

  // <unk>=0
  BOOST_CHECK_CLOSE(-0.90309 * M_LN10, unigrams(0, 0), tolerance_percent);
  BOOST_CHECK_CLOSE(-1 * M_LN10, unigrams(0, 1), tolerance_percent);
  // <s>=1 doesn't matter as long as it doesn't cause NaNs.
  BOOST_CHECK(!isnan(unigrams(1, 0)));
  BOOST_CHECK(!isnan(unigrams(1, 1)));
  // a = 2
  BOOST_CHECK_CLOSE(-0.46943438 * M_LN10, unigrams(2, 0), tolerance_percent);
  BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, unigrams(2, 1), tolerance_percent);
  // </s> = 3
  BOOST_CHECK_CLOSE(-0.5720968 * M_LN10, unigrams(3, 0), tolerance_percent);
  BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, unigrams(3, 1), tolerance_percent);
  // c = 4
  BOOST_CHECK_CLOSE(-0.90309 * M_LN10, unigrams(4, 0), tolerance_percent); // <unk>
  BOOST_CHECK_CLOSE(-0.7659168 * M_LN10, unigrams(4, 1), tolerance_percent);
  // too lazy to do b = 5.

  // Two instances:
  //   0: <s> predicts c
  //   1: <s> c predicts </s>
  BOOST_REQUIRE_EQUAL(2, instances.NumInstances());

  // Backoffs of <s>
  BOOST_CHECK_CLOSE(-0.30103 * M_LN10, instances.LNBackoffs(0)(0), tolerance_percent);
  BOOST_CHECK_CLOSE(-0.30103 * M_LN10, instances.LNBackoffs(0)(1), tolerance_percent);

  // Backoffs of <s> c
  BOOST_CHECK_CLOSE(0.0, instances.LNBackoffs(1)(0), tolerance_percent);
  BOOST_CHECK_CLOSE((-0.30103 - 0.30103) * M_LN10, instances.LNBackoffs(1)(1), tolerance_percent);

  // Pull the extensions back out through a small chain.
  util::stream::ChainConfig chain_config(instances.ReadExtensionsEntrySize(), 2, 300);
  util::stream::Chain chain(chain_config);
  instances.ReadExtensions(chain);
  util::stream::TypedStream<Extension> stream(chain.Add());
  chain >> util::stream::kRecycle;

  // The extensions are (in order of instance, vocab id, and model as they
  // should be sorted):
  //   <s> a from both models 0 and 1 (so two instances)
  //   <s> c from model 1
  //   <s> b from model 0
  //   c </s> from model 1
  // Magic probabilities come from querying the models directly.
  struct ExpectedExtension {
    int instance;
    int word;
    int model;
    double ln_prob;
  };
  const ExpectedExtension expected[] = {
    {0, 2 /* a */,    0, -0.37712017 * M_LN10},  // <s> a from model 0
    {0, 2 /* a */,    1, -0.4301247 * M_LN10},   // <s> a from model 1
    {0, 4 /* c */,    1, -0.4740302 * M_LN10},   // <s> c from model 1
    {0, 5 /* b */,    0, -0.41574955 * M_LN10},  // <s> b from model 0
    {1, 3 /* </s> */, 1, -0.09113217 * M_LN10}   // c </s> from model 1
  };
  const std::size_t expected_count = sizeof(expected) / sizeof(expected[0]);

  for (std::size_t row = 0; row < expected_count; ++row) {
    // The first record is already current; later ones require advancing.
    if (row == 0) {
      BOOST_REQUIRE(stream);
    } else {
      BOOST_REQUIRE(++stream);
    }
    const ExpectedExtension &want = expected[row];
    BOOST_CHECK_EQUAL(want.instance, stream->instance);
    BOOST_CHECK_EQUAL(want.word, stream->word);
    BOOST_CHECK_EQUAL(want.model, stream->model);
    BOOST_CHECK_CLOSE(want.ln_prob, stream->ln_prob, tolerance_percent);
  }

  // Nothing may follow the five expected records.
  BOOST_CHECK(!++stream);
}

}}} // namespaces