Welcome to mirror list, hosted at ThFree Co, Russian Federation.

tune_weights.cc « interpolate « lm - github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0d1667ef3bbc99c65e31c2f70c0e5460557f4e9e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include "tune_weights.hh"

#include "tune_derivatives.hh"
#include "tune_instances.hh"

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas" // Older gcc doesn't have "-Wunused-local-typedefs" and complains.
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#include <Eigen/Dense>
#pragma GCC diagnostic pop
#include <boost/program_options.hpp>

#include <iostream>

namespace lm { namespace interpolate {
void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config, std::vector<float> &weights_out) {
  Instances instances(tune_file, model_names, config);
  Vector weights = Vector::Constant(model_names.size(), 1.0 / model_names.size());
  Vector gradient;
  Matrix hessian;
  for (std::size_t iteration = 0; iteration < 10 /*TODO fancy stopping criteria */; ++iteration) {
    std::cerr << "Iteration " << iteration << ": weights =";
    for (Vector::Index i = 0; i < weights.rows(); ++i) {
      std::cerr << ' ' << weights(i);
    }
    std::cerr << std::endl;
    std::cerr << "Perplexity = " << Derivatives(instances, weights, gradient, hessian) << std::endl;
    // TODO: 1.0 step size was too big and it kept getting unstable.  More math.
    weights -= 0.7 * hessian.inverse() * gradient;
  }
  weights_out.assign(weights.data(), weights.data() + weights.size());
}
}} // namespaces