Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/nplm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    Rico Sennrich <rico.sennrich@gmx.ch>  2014-10-16 16:59:52 +0400
committer Rico Sennrich <rico.sennrich@gmx.ch>  2014-10-16 17:22:34 +0400
commit    84b1e84e87626dfb1382f0bafcdbe55733f87922 (patch)
tree      02ced1f1dad91dc156fd7f69328c43c393b4f5fe
parent    f097addf6e2c23b044d54c2ce4e0283faf19ed10 (diff)
new parameter init_model to initialize the network parameters from an existing model.
(use cases: continue training with more epochs, or save memory by using different subsets of data for each epoch)
-rw-r--r--  src/param.h              |  1
-rw-r--r--  src/trainNeuralNetwork.cpp | 18
2 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/src/param.h b/src/param.h
index b303514..8e42853 100644
--- a/src/param.h
+++ b/src/param.h
@@ -18,6 +18,7 @@ struct param
std::string input_words_file;
std::string output_words_file;
std::string model_prefix;
+ std::string init_model;
int ngram_size;
int vocab_size;
diff --git a/src/trainNeuralNetwork.cpp b/src/trainNeuralNetwork.cpp
index fa27627..a088e16 100644
--- a/src/trainNeuralNetwork.cpp
+++ b/src/trainNeuralNetwork.cpp
@@ -94,6 +94,7 @@ int main(int argc, char** argv)
ValueArg<string> output_words_file("", "output_words_file", "Vocabulary." , false, "", "string", cmd);
ValueArg<string> validation_file("", "validation_file", "Validation data (one numberized example per line)." , false, "", "string", cmd);
ValueArg<string> train_file("", "train_file", "Training data (one numberized example per line)." , true, "", "string", cmd);
+ ValueArg<string> init_model("", "init_model", "Initialize parameters from existing model (to continue interrupted training)", false, "", "string", cmd);
cmd.parse(argc, argv);
@@ -139,6 +140,7 @@ int main(int argc, char** argv)
myParam.L2_reg = L2_reg.getValue();
myParam.init_normal= init_normal.getValue();
myParam.init_range = init_range.getValue();
+ myParam.init_model = init_model.getValue();
myParam.normalization_init = normalization_init.getValue();
cerr << "Command line: " << endl;
@@ -187,8 +189,13 @@ int main(int argc, char** argv)
}
cerr << loss_function.getDescription() << sep << loss_function.getValue() << endl;
- cerr << init_normal.getDescription() << sep << init_normal.getValue() << endl;
- cerr << init_range.getDescription() << sep << init_range.getValue() << endl;
+ if (init_model.getValue() != "") {
+ cerr << init_model.getDescription() << sep << init_model.getValue() << endl;
+ }
+ else {
+ cerr << init_normal.getDescription() << sep << init_normal.getValue() << endl;
+ cerr << init_range.getDescription() << sep << init_range.getValue() << endl;
+ }
cerr << num_epochs.getDescription() << sep << num_epochs.getValue() << endl;
cerr << minibatch_size.getDescription() << sep << minibatch_size.getValue() << endl;
@@ -311,7 +318,12 @@ int main(int argc, char** argv)
myParam.output_embedding_dimension,
myParam.share_embeddings);
- nn.initialize(rng, myParam.init_normal, myParam.init_range, -log(myParam.output_vocab_size));
+ if (myParam.init_model != "") {
+ nn.read(myParam.init_model);
+ }
+ else {
+ nn.initialize(rng, myParam.init_normal, myParam.init_range, -log(myParam.output_vocab_size));
+ }
nn.set_activation_function(string_to_activation_function(myParam.activation_function));
loss_function_type loss_function = string_to_loss_function(myParam.loss_function);