diff options
Diffstat (limited to 'src/testNeuralLM.cpp')
-rw-r--r-- | src/testNeuralLM.cpp | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/src/testNeuralLM.cpp b/src/testNeuralLM.cpp index abaab34..a2aa5e3 100644 --- a/src/testNeuralLM.cpp +++ b/src/testNeuralLM.cpp @@ -20,7 +20,7 @@ using namespace Eigen; using namespace nplm; void score(neuralLM &lm, int minibatch_size, vector<int>& start, vector< vector<int> > &ngrams, - vector<double> &out) { + vector<user_data_t> &out) { if (ngrams.size() == 0) return; int ngram_size = ngrams[0].size(); @@ -29,7 +29,7 @@ void score(neuralLM &lm, int minibatch_size, vector<int>& start, vector< vector< // Score one n-gram at a time. This is how the LM would be queried from a decoder. for (int sent_id=0; sent_id<start.size()-1; sent_id++) { - double sent_log_prob = 0.0; + user_data_t sent_log_prob = 0.0; for (int j=start[sent_id]; j<start[sent_id+1]; j++) sent_log_prob += lm.lookup_ngram(ngrams[j]); out.push_back(sent_log_prob); @@ -38,7 +38,7 @@ void score(neuralLM &lm, int minibatch_size, vector<int>& start, vector< vector< else { // Score a whole minibatch at a time. - Matrix<double,1,Dynamic> log_probs(ngrams.size()); + Matrix<user_data_t,1,Dynamic> log_probs(ngrams.size()); Matrix<int,Dynamic,Dynamic> minibatch(ngram_size, minibatch_size); minibatch.setZero(); @@ -52,7 +52,7 @@ void score(neuralLM &lm, int minibatch_size, vector<int>& start, vector< vector< for (int sent_id=0; sent_id<start.size()-1; sent_id++) { - double sent_log_prob = 0.0; + user_data_t sent_log_prob = 0.0; for (int j=start[sent_id]; j<start[sent_id+1]; j++) sent_log_prob += log_probs[j]; out.push_back(sent_log_prob); @@ -157,7 +157,7 @@ int main (int argc, char *argv[]) start.push_back(ngrams.size()); int num_threads = 1; - vector< vector<double> > sent_log_probs(num_threads); + vector< vector<user_data_t> > sent_log_probs(num_threads); /* // Test thread safety @@ -169,7 +169,7 @@ int main (int argc, char *argv[]) */ score(lm, minibatch_size, start, ngrams, sent_log_probs[0]); - vector<double> log_likelihood(num_threads); + vector<user_data_t> log_likelihood(num_threads); std::fill(log_likelihood.begin(), log_likelihood.end(), 0.0); for (int i=0; i<sent_log_probs[0].size(); i++) { for (int t=0; t<num_threads; t++) |