Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/nplm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/testNeuralLM.cpp')
-rw-r--r--src/testNeuralLM.cpp12
1 files changed, 6 insertions, 6 deletions
diff --git a/src/testNeuralLM.cpp b/src/testNeuralLM.cpp
index abaab34..a2aa5e3 100644
--- a/src/testNeuralLM.cpp
+++ b/src/testNeuralLM.cpp
@@ -20,7 +20,7 @@ using namespace Eigen;
using namespace nplm;
void score(neuralLM &lm, int minibatch_size, vector<int>& start, vector< vector<int> > &ngrams,
- vector<double> &out) {
+ vector<user_data_t> &out) {
if (ngrams.size() == 0) return;
int ngram_size = ngrams[0].size();
@@ -29,7 +29,7 @@ void score(neuralLM &lm, int minibatch_size, vector<int>& start, vector< vector<
// Score one n-gram at a time. This is how the LM would be queried from a decoder.
for (int sent_id=0; sent_id<start.size()-1; sent_id++)
{
- double sent_log_prob = 0.0;
+ user_data_t sent_log_prob = 0.0;
for (int j=start[sent_id]; j<start[sent_id+1]; j++)
sent_log_prob += lm.lookup_ngram(ngrams[j]);
out.push_back(sent_log_prob);
@@ -38,7 +38,7 @@ void score(neuralLM &lm, int minibatch_size, vector<int>& start, vector< vector<
else
{
// Score a whole minibatch at a time.
- Matrix<double,1,Dynamic> log_probs(ngrams.size());
+ Matrix<user_data_t,1,Dynamic> log_probs(ngrams.size());
Matrix<int,Dynamic,Dynamic> minibatch(ngram_size, minibatch_size);
minibatch.setZero();
@@ -52,7 +52,7 @@ void score(neuralLM &lm, int minibatch_size, vector<int>& start, vector< vector<
for (int sent_id=0; sent_id<start.size()-1; sent_id++)
{
- double sent_log_prob = 0.0;
+ user_data_t sent_log_prob = 0.0;
for (int j=start[sent_id]; j<start[sent_id+1]; j++)
sent_log_prob += log_probs[j];
out.push_back(sent_log_prob);
@@ -157,7 +157,7 @@ int main (int argc, char *argv[])
start.push_back(ngrams.size());
int num_threads = 1;
- vector< vector<double> > sent_log_probs(num_threads);
+ vector< vector<user_data_t> > sent_log_probs(num_threads);
/*
// Test thread safety
@@ -169,7 +169,7 @@ int main (int argc, char *argv[])
*/
score(lm, minibatch_size, start, ngrams, sent_log_probs[0]);
- vector<double> log_likelihood(num_threads);
+ vector<user_data_t> log_likelihood(num_threads);
std::fill(log_likelihood.begin(), log_likelihood.end(), 0.0);
for (int i=0; i<sent_log_probs[0].size(); i++) {
for (int t=0; t<num_threads; t++)