diff options
author | Rico Sennrich <rico.sennrich@gmx.ch> | 2015-08-27 16:12:52 +0300 |
---|---|---|
committer | Rico Sennrich <rico.sennrich@gmx.ch> | 2015-08-27 16:12:52 +0300 |
commit | 12fd985180a7c9a7aee402c2f05a1f9651566de3 (patch) | |
tree | b9b17a1a0dfe03b06ab4c3befe98a755bc4a3db5 | |
parent | 9dea3fe1329ce392b9e420c89663f7bddabb068a (diff) |
(optionally) only save best model (based on validation perplexity)
-rw-r--r-- | src/param.h | 2 |
-rw-r--r-- | src/trainNeuralNetwork.cpp | 45 |
2 files changed, 36 insertions, 11 deletions
diff --git a/src/param.h b/src/param.h index fe1b6d6..e0f8437 100644 --- a/src/param.h +++ b/src/param.h @@ -59,6 +59,8 @@ struct param double input_dropout; int null_index; + bool save_best; + bool normalization; user_data_t normalization_init; diff --git a/src/trainNeuralNetwork.cpp b/src/trainNeuralNetwork.cpp index 632391a..7233ff3 100644 --- a/src/trainNeuralNetwork.cpp +++ b/src/trainNeuralNetwork.cpp @@ -130,6 +130,7 @@ int main(int argc, char** argv) ValueArg<double> input_dropout("", "input_dropout", "Probability of retaining input word. Values between 0 (all input is ignored) to 1 (no dropout). Default: 1.", false, 1, "user_data_t", cmd); ValueArg<int> null_index("", "null_index", "Index of null word. Used as special (dropped out) token for input dropout.", false, 0, "int", cmd); + ValueArg<bool> save_best("", "save_best", "Save only best model (based on devset perplexity). 1 = yes, 0 = no. Default: 0.", false, 0, "bool", cmd); ValueArg<user_data_t> learning_rate("", "learning_rate", "Learning rate for stochastic gradient ascent. 
Default: 1.", false, 1., "user_data_t", cmd); @@ -224,6 +225,7 @@ int main(int argc, char** argv) myParam.L2_reg = L2_reg.getValue(); myParam.input_dropout = input_dropout.getValue(); myParam.null_index = null_index.getValue(); + myParam.save_best = save_best.getValue(); myParam.init_normal= init_normal.getValue(); myParam.init_range = init_range.getValue(); myParam.normalization_init = normalization_init.getValue(); @@ -557,6 +559,7 @@ int main(int argc, char** argv) user_data_t momentum_delta = (myParam.final_momentum - myParam.initial_momentum)/(myParam.num_epochs-1); user_data_t current_learning_rate = myParam.learning_rate; user_data_t current_validation_ll = 0.0; + user_data_t best_validation_ll = 0.0; int ngram_size = myParam.ngram_size; int input_vocab_size = myParam.input_vocab_size; @@ -609,11 +612,21 @@ int main(int argc, char** argv) cerr << batch <<"..."; } - if (batch > 0 && batch % 500000 == 0) + if (batch > 0 && batch % 500000 == 0 && validation_data_size > 0) { cerr << endl; compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll); cerr << "Current learning rate: " << current_learning_rate << endl; + if ((best_validation_ll == 0.0 || best_validation_ll < current_validation_ll) && myParam.model_prefix != "" && myParam.save_best) + { + best_validation_ll = current_validation_ll; + cerr << "New best perplexity; Writing model" << endl; + if (myParam.input_words_file != "") + nn.write(myParam.model_prefix + ".best", input_words, output_words); + else + nn.write(myParam.model_prefix + ".best"); + } + } data_size_t minibatch_start_index = minibatch_size * batch; @@ -799,18 +812,28 @@ int main(int argc, char** argv) nn.input_layer.zero(myParam.null_index); } - cerr << "Writing model" << endl; - if (myParam.input_words_file != "") - nn.write(myParam.model_prefix + "." 
+ lexical_cast<string>(epoch+1), input_words, output_words); - else - nn.write(myParam.model_prefix + "." + lexical_cast<string>(epoch+1)); + if (!myParam.save_best || validation_data_size == 0) { + cerr << "Writing model" << endl; + if (myParam.input_words_file != "") + nn.write(myParam.model_prefix + "." + lexical_cast<string>(epoch+1), input_words, output_words); + else + nn.write(myParam.model_prefix + "." + lexical_cast<string>(epoch+1)); + } } - if (epoch % 1 == 0 && validation_data_size > 0) - { - compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll); - } - + if (epoch % 1 == 0 && validation_data_size > 0) + { + compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll); + if ((best_validation_ll == 0.0 || best_validation_ll < current_validation_ll) && myParam.save_best) + { + best_validation_ll = current_validation_ll; + cerr << "New best perplexity; Writing model" << endl; + if (myParam.input_words_file != "") + nn.write(myParam.model_prefix + ".best", input_words, output_words); + else + nn.write(myParam.model_prefix + ".best"); + } + } } return 0; } |