Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/nplm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRico Sennrich <rico.sennrich@gmx.ch>2015-08-27 16:12:52 +0300
committerRico Sennrich <rico.sennrich@gmx.ch>2015-08-27 16:12:52 +0300
commit12fd985180a7c9a7aee402c2f05a1f9651566de3 (patch)
treeb9b17a1a0dfe03b06ab4c3befe98a755bc4a3db5
parent9dea3fe1329ce392b9e420c89663f7bddabb068a (diff)
(optionally) only save best model (based on validation perplexity)
-rw-r--r--src/param.h2
-rw-r--r--src/trainNeuralNetwork.cpp45
2 files changed, 36 insertions, 11 deletions
diff --git a/src/param.h b/src/param.h
index fe1b6d6..e0f8437 100644
--- a/src/param.h
+++ b/src/param.h
@@ -59,6 +59,8 @@ struct param
double input_dropout;
int null_index;
+ bool save_best;
+
bool normalization;
user_data_t normalization_init;
diff --git a/src/trainNeuralNetwork.cpp b/src/trainNeuralNetwork.cpp
index 632391a..7233ff3 100644
--- a/src/trainNeuralNetwork.cpp
+++ b/src/trainNeuralNetwork.cpp
@@ -130,6 +130,7 @@ int main(int argc, char** argv)
ValueArg<double> input_dropout("", "input_dropout", "Probability of retaining input word. Values between 0 (all input is ignored) to 1 (no dropout). Default: 1.", false, 1, "user_data_t", cmd);
ValueArg<int> null_index("", "null_index", "Index of null word. Used as special (dropped out) token for input dropout.", false, 0, "int", cmd);
+ ValueArg<bool> save_best("", "save_best", "Save only best model (based on devset perplexity). 1 = yes, 0 = no. Default: 0.", false, 0, "bool", cmd);
ValueArg<user_data_t> learning_rate("", "learning_rate", "Learning rate for stochastic gradient ascent. Default: 1.", false, 1., "user_data_t", cmd);
@@ -224,6 +225,7 @@ int main(int argc, char** argv)
myParam.L2_reg = L2_reg.getValue();
myParam.input_dropout = input_dropout.getValue();
myParam.null_index = null_index.getValue();
+ myParam.save_best = save_best.getValue();
myParam.init_normal= init_normal.getValue();
myParam.init_range = init_range.getValue();
myParam.normalization_init = normalization_init.getValue();
@@ -557,6 +559,7 @@ int main(int argc, char** argv)
user_data_t momentum_delta = (myParam.final_momentum - myParam.initial_momentum)/(myParam.num_epochs-1);
user_data_t current_learning_rate = myParam.learning_rate;
user_data_t current_validation_ll = 0.0;
+ user_data_t best_validation_ll = 0.0;
int ngram_size = myParam.ngram_size;
int input_vocab_size = myParam.input_vocab_size;
@@ -609,11 +612,21 @@ int main(int argc, char** argv)
cerr << batch <<"...";
}
- if (batch > 0 && batch % 500000 == 0)
+ if (batch > 0 && batch % 500000 == 0 && validation_data_size > 0)
{
cerr << endl;
compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll);
cerr << "Current learning rate: " << current_learning_rate << endl;
+ if ((best_validation_ll == 0.0 || best_validation_ll < current_validation_ll) && myParam.model_prefix != "" && myParam.save_best)
+ {
+ best_validation_ll = current_validation_ll;
+ cerr << "New best perplexity; Writing model" << endl;
+ if (myParam.input_words_file != "")
+ nn.write(myParam.model_prefix + ".best", input_words, output_words);
+ else
+ nn.write(myParam.model_prefix + ".best");
+ }
+
}
data_size_t minibatch_start_index = minibatch_size * batch;
@@ -799,18 +812,28 @@ int main(int argc, char** argv)
nn.input_layer.zero(myParam.null_index);
}
- cerr << "Writing model" << endl;
- if (myParam.input_words_file != "")
- nn.write(myParam.model_prefix + "." + lexical_cast<string>(epoch+1), input_words, output_words);
- else
- nn.write(myParam.model_prefix + "." + lexical_cast<string>(epoch+1));
+ if (!myParam.save_best || validation_data_size == 0) {
+ cerr << "Writing model" << endl;
+ if (myParam.input_words_file != "")
+ nn.write(myParam.model_prefix + "." + lexical_cast<string>(epoch+1), input_words, output_words);
+ else
+ nn.write(myParam.model_prefix + "." + lexical_cast<string>(epoch+1));
+ }
}
- if (epoch % 1 == 0 && validation_data_size > 0)
- {
- compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll);
- }
-
+ if (epoch % 1 == 0 && validation_data_size > 0)
+ {
+ compute_validation_perplexity(ngram_size, output_vocab_size, validation_minibatch_size, validation_data_size, num_validation_batches, myParam, prop_validation, validation_data, current_learning_rate, current_validation_ll);
+ if ((best_validation_ll == 0.0 || best_validation_ll < current_validation_ll) && myParam.save_best)
+ {
+ best_validation_ll = current_validation_ll;
+ cerr << "New best perplexity; Writing model" << endl;
+ if (myParam.input_words_file != "")
+ nn.write(myParam.model_prefix + ".best", input_words, output_words);
+ else
+ nn.write(myParam.model_prefix + ".best");
+ }
+ }
}
return 0;
}