diff options
author | Rico Sennrich <rico.sennrich@gmx.ch> | 2015-04-08 12:08:47 +0300 |
---|---|---|
committer | Rico Sennrich <rico.sennrich@gmx.ch> | 2015-04-08 12:08:47 +0300 |
commit | 28bdadf328c63ee086e8aa5de23cfe0c11728c5b (patch) | |
tree | 376bab92734d4d9e3a37deb32e45cb8324a5620b | |
parent | 3dc380d71ab1355ff45de1dad63c3ed00cbf9f0b (diff) |
refactor validation perplexity test; check perplexity every 500000 minibatches.
-rw-r--r-- | src/trainNeuralNetwork.cpp | 105 |
1 files changed, 58 insertions, 47 deletions
diff --git a/src/trainNeuralNetwork.cpp b/src/trainNeuralNetwork.cpp
index 5e871ec..97af03b 100644
--- a/src/trainNeuralNetwork.cpp
+++ b/src/trainNeuralNetwork.cpp
@@ -51,6 +51,55 @@ typedef ip::allocator<vec, ip::managed_mapped_file::segment_manager> vecAllocato
 
 typedef long long int data_size_t; // training data can easily exceed 2G instances
 
+
+void compute_validation_perplexity(int ngram_size, int output_vocab_size, int validation_minibatch_size, int validation_data_size, int num_validation_batches, param & myParam, propagator & prop_validation, Map< Matrix<int,Dynamic,Dynamic> > & validation_data, double & current_learning_rate, double & current_validation_ll)
+{
+    double log_likelihood = 0.0;
+
+    Matrix<double,Dynamic,Dynamic> scores(output_vocab_size, validation_minibatch_size);
+    Matrix<double,Dynamic,Dynamic> output_probs(output_vocab_size, validation_minibatch_size);
+    Matrix<int,Dynamic,Dynamic> minibatch(ngram_size, validation_minibatch_size);
+
+    for (int validation_batch =0;validation_batch < num_validation_batches;validation_batch++)
+    {
+        int validation_minibatch_start_index = validation_minibatch_size * validation_batch;
+        int current_minibatch_size = min(validation_minibatch_size,
+                                         validation_data_size - validation_minibatch_start_index);
+        minibatch.leftCols(current_minibatch_size) = validation_data.middleCols(validation_minibatch_start_index,
+                                                                                current_minibatch_size);
+        prop_validation.fProp(minibatch.topRows(ngram_size-1));
+
+        // Do full forward prop through output word embedding layer
+        start_timer(4);
+        if (prop_validation.skip_hidden)
+            prop_validation.output_layer_node.param->fProp(prop_validation.first_hidden_activation_node.fProp_matrix, scores);
+        else
+            prop_validation.output_layer_node.param->fProp(prop_validation.second_hidden_activation_node.fProp_matrix, scores);
+        stop_timer(4);
+
+        // And softmax and loss. Be careful of short minibatch
+        double minibatch_log_likelihood;
+        start_timer(5);
+        SoftmaxLogLoss().fProp(scores.leftCols(current_minibatch_size),
+                               minibatch.row(ngram_size-1),
+                               output_probs,
+                               minibatch_log_likelihood);
+        stop_timer(5);
+        log_likelihood += minibatch_log_likelihood;
+    }
+
+    cerr << "Validation log-likelihood: "<< log_likelihood << endl;
+    cerr << " perplexity: "<< exp(-log_likelihood/validation_data_size) << endl;
+
+    // If the validation perplexity decreases, halve the learning rate.
+    if (current_validation_ll != 0.0 && log_likelihood < current_validation_ll && myParam.parameter_update != "ADA")
+    {
+        current_learning_rate /= 2;
+    }
+    current_validation_ll = log_likelihood;
+}
+
+
 int main(int argc, char** argv)
 {
     ios::sync_with_stdio(false);
@@ -550,6 +599,13 @@ int main(int argc, char** argv)
                 cerr << batch <<"...";
             }
 
+            if (batch > 0 && batch % 500000 == 0)
+            {
+                cerr << endl;
+                compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll);
+                cerr << "Current learning rate: " << current_learning_rate << endl;
+            }
+
             data_size_t minibatch_start_index = minibatch_size * batch;
             int current_minibatch_size = min(static_cast<data_size_t>(minibatch_size),
                                              training_data_size - minibatch_start_index);
@@ -731,53 +787,8 @@ int main(int argc, char** argv)
 
         if (epoch % 1 == 0 && validation_data_size > 0)
         {
-            //////COMPUTING VALIDATION SET PERPLEXITY///////////////////////
-            ////////////////////////////////////////////////////////////////
-
-            double log_likelihood = 0.0;
-
-            Matrix<double,Dynamic,Dynamic> scores(output_vocab_size, validation_minibatch_size);
-            Matrix<double,Dynamic,Dynamic> output_probs(output_vocab_size, validation_minibatch_size);
-            Matrix<int,Dynamic,Dynamic> minibatch(ngram_size, validation_minibatch_size);
-
-            for (int validation_batch =0;validation_batch < num_validation_batches;validation_batch++)
-            {
-                int validation_minibatch_start_index = validation_minibatch_size * validation_batch;
-                int current_minibatch_size = min(validation_minibatch_size,
-                                                 validation_data_size - validation_minibatch_start_index);
-                minibatch.leftCols(current_minibatch_size) = validation_data.middleCols(validation_minibatch_start_index,
-                                                                                        current_minibatch_size);
-                prop_validation.fProp(minibatch.topRows(ngram_size-1));
-
-                // Do full forward prop through output word embedding layer
-                start_timer(4);
-                if (prop_validation.skip_hidden)
-                    prop_validation.output_layer_node.param->fProp(prop_validation.first_hidden_activation_node.fProp_matrix, scores);
-                else
-                    prop_validation.output_layer_node.param->fProp(prop_validation.second_hidden_activation_node.fProp_matrix, scores);
-                stop_timer(4);
-
-                // And softmax and loss. Be careful of short minibatch
-                double minibatch_log_likelihood;
-                start_timer(5);
-                SoftmaxLogLoss().fProp(scores.leftCols(current_minibatch_size),
-                                       minibatch.row(ngram_size-1),
-                                       output_probs,
-                                       minibatch_log_likelihood);
-                stop_timer(5);
-                log_likelihood += minibatch_log_likelihood;
-            }
-
-            cerr << "Validation log-likelihood: "<< log_likelihood << endl;
-            cerr << " perplexity: "<< exp(-log_likelihood/validation_data_size) << endl;
-
-            // If the validation perplexity decreases, halve the learning rate.
-            if (epoch > 0 && log_likelihood < current_validation_ll && myParam.parameter_update != "ADA")
-            {
-                current_learning_rate /= 2;
-            }
-            current_validation_ll = log_likelihood;
-        }
+            compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll);
+        }
     }
     return 0;