Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/nplm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRico Sennrich <rico.sennrich@gmx.ch>2015-04-08 12:08:47 +0300
committerRico Sennrich <rico.sennrich@gmx.ch>2015-04-08 12:08:47 +0300
commit28bdadf328c63ee086e8aa5de23cfe0c11728c5b (patch)
tree376bab92734d4d9e3a37deb32e45cb8324a5620b
parent3dc380d71ab1355ff45de1dad63c3ed00cbf9f0b (diff)
refactor validation perplexity test; check perplexity every 500000 minibatches.
-rw-r--r--src/trainNeuralNetwork.cpp105
1 files changed, 58 insertions, 47 deletions
diff --git a/src/trainNeuralNetwork.cpp b/src/trainNeuralNetwork.cpp
index 5e871ec..97af03b 100644
--- a/src/trainNeuralNetwork.cpp
+++ b/src/trainNeuralNetwork.cpp
@@ -51,6 +51,55 @@ typedef ip::allocator<vec, ip::managed_mapped_file::segment_manager> vecAllocato
typedef long long int data_size_t; // training data can easily exceed 2G instances
+
+void compute_validation_perplexity(int ngram_size, int output_vocab_size, int validation_minibatch_size, int validation_data_size, int num_validation_batches, param & myParam, propagator & prop_validation, Map< Matrix<int,Dynamic,Dynamic> > & validation_data, double & current_learning_rate, double & current_validation_ll)
+{
+ double log_likelihood = 0.0;
+
+ Matrix<double,Dynamic,Dynamic> scores(output_vocab_size, validation_minibatch_size);
+ Matrix<double,Dynamic,Dynamic> output_probs(output_vocab_size, validation_minibatch_size);
+ Matrix<int,Dynamic,Dynamic> minibatch(ngram_size, validation_minibatch_size);
+
+ for (int validation_batch =0;validation_batch < num_validation_batches;validation_batch++)
+ {
+ int validation_minibatch_start_index = validation_minibatch_size * validation_batch;
+ int current_minibatch_size = min(validation_minibatch_size,
+ validation_data_size - validation_minibatch_start_index);
+ minibatch.leftCols(current_minibatch_size) = validation_data.middleCols(validation_minibatch_start_index,
+ current_minibatch_size);
+ prop_validation.fProp(minibatch.topRows(ngram_size-1));
+
+ // Do full forward prop through output word embedding layer
+ start_timer(4);
+ if (prop_validation.skip_hidden)
+ prop_validation.output_layer_node.param->fProp(prop_validation.first_hidden_activation_node.fProp_matrix, scores);
+ else
+ prop_validation.output_layer_node.param->fProp(prop_validation.second_hidden_activation_node.fProp_matrix, scores);
+ stop_timer(4);
+
+ // And softmax and loss. Be careful of short minibatch
+ double minibatch_log_likelihood;
+ start_timer(5);
+ SoftmaxLogLoss().fProp(scores.leftCols(current_minibatch_size),
+ minibatch.row(ngram_size-1),
+ output_probs,
+ minibatch_log_likelihood);
+ stop_timer(5);
+ log_likelihood += minibatch_log_likelihood;
+ }
+
+ cerr << "Validation log-likelihood: "<< log_likelihood << endl;
+ cerr << " perplexity: "<< exp(-log_likelihood/validation_data_size) << endl;
+
+ // If the validation perplexity decreases, halve the learning rate.
+ if (current_validation_ll != 0.0 && log_likelihood < current_validation_ll && myParam.parameter_update != "ADA")
+ {
+ current_learning_rate /= 2;
+ }
+ current_validation_ll = log_likelihood;
+}
+
+
int main(int argc, char** argv)
{
ios::sync_with_stdio(false);
@@ -550,6 +599,13 @@ int main(int argc, char** argv)
cerr << batch <<"...";
}
+ if (batch > 0 && batch % 500000 == 0)
+ {
+ cerr << endl;
+ compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll);
+ cerr << "Current learning rate: " << current_learning_rate << endl;
+ }
+
data_size_t minibatch_start_index = minibatch_size * batch;
int current_minibatch_size = min(static_cast<data_size_t>(minibatch_size), training_data_size - minibatch_start_index);
@@ -731,53 +787,8 @@ int main(int argc, char** argv)
if (epoch % 1 == 0 && validation_data_size > 0)
{
- //////COMPUTING VALIDATION SET PERPLEXITY///////////////////////
- ////////////////////////////////////////////////////////////////
-
- double log_likelihood = 0.0;
-
- Matrix<double,Dynamic,Dynamic> scores(output_vocab_size, validation_minibatch_size);
- Matrix<double,Dynamic,Dynamic> output_probs(output_vocab_size, validation_minibatch_size);
- Matrix<int,Dynamic,Dynamic> minibatch(ngram_size, validation_minibatch_size);
-
- for (int validation_batch =0;validation_batch < num_validation_batches;validation_batch++)
- {
- int validation_minibatch_start_index = validation_minibatch_size * validation_batch;
- int current_minibatch_size = min(validation_minibatch_size,
- validation_data_size - validation_minibatch_start_index);
- minibatch.leftCols(current_minibatch_size) = validation_data.middleCols(validation_minibatch_start_index,
- current_minibatch_size);
- prop_validation.fProp(minibatch.topRows(ngram_size-1));
-
- // Do full forward prop through output word embedding layer
- start_timer(4);
- if (prop_validation.skip_hidden)
- prop_validation.output_layer_node.param->fProp(prop_validation.first_hidden_activation_node.fProp_matrix, scores);
- else
- prop_validation.output_layer_node.param->fProp(prop_validation.second_hidden_activation_node.fProp_matrix, scores);
- stop_timer(4);
-
- // And softmax and loss. Be careful of short minibatch
- double minibatch_log_likelihood;
- start_timer(5);
- SoftmaxLogLoss().fProp(scores.leftCols(current_minibatch_size),
- minibatch.row(ngram_size-1),
- output_probs,
- minibatch_log_likelihood);
- stop_timer(5);
- log_likelihood += minibatch_log_likelihood;
- }
-
- cerr << "Validation log-likelihood: "<< log_likelihood << endl;
- cerr << " perplexity: "<< exp(-log_likelihood/validation_data_size) << endl;
-
- // If the validation perplexity decreases, halve the learning rate.
- if (epoch > 0 && log_likelihood < current_validation_ll && myParam.parameter_update != "ADA")
- {
- current_learning_rate /= 2;
- }
- current_validation_ll = log_likelihood;
- }
+ compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll);
+ }
}
return 0;