diff options
author | Rico Sennrich <rico.sennrich@gmx.ch> | 2014-10-16 16:59:52 +0400 |
---|---|---|
committer | Rico Sennrich <rico.sennrich@gmx.ch> | 2014-10-16 17:22:34 +0400 |
commit | 84b1e84e87626dfb1382f0bafcdbe55733f87922 (patch) | |
tree | 02ced1f1dad91dc156fd7f69328c43c393b4f5fe | |
parent | f097addf6e2c23b044d54c2ce4e0283faf19ed10 (diff) |
New parameter init_model to initialize the network parameters from an existing model.
(use cases: continue training with more epochs, or save memory by using different subsets of data for each epoch)
-rw-r--r-- | src/param.h | 1 | ||||
-rw-r--r-- | src/trainNeuralNetwork.cpp | 18 |
2 files changed, 16 insertions, 3 deletions
diff --git a/src/param.h b/src/param.h index b303514..8e42853 100644 --- a/src/param.h +++ b/src/param.h @@ -18,6 +18,7 @@ struct param std::string input_words_file; std::string output_words_file; std::string model_prefix; + std::string init_model; int ngram_size; int vocab_size; diff --git a/src/trainNeuralNetwork.cpp b/src/trainNeuralNetwork.cpp index fa27627..a088e16 100644 --- a/src/trainNeuralNetwork.cpp +++ b/src/trainNeuralNetwork.cpp @@ -94,6 +94,7 @@ int main(int argc, char** argv) ValueArg<string> output_words_file("", "output_words_file", "Vocabulary." , false, "", "string", cmd); ValueArg<string> validation_file("", "validation_file", "Validation data (one numberized example per line)." , false, "", "string", cmd); ValueArg<string> train_file("", "train_file", "Training data (one numberized example per line)." , true, "", "string", cmd); + ValueArg<string> init_model("", "init_model", "Initialize parameters from existing model (to continue interrupted training)", false, "", "string", cmd); cmd.parse(argc, argv); @@ -139,6 +140,7 @@ int main(int argc, char** argv) myParam.L2_reg = L2_reg.getValue(); myParam.init_normal= init_normal.getValue(); myParam.init_range = init_range.getValue(); + myParam.init_model = init_model.getValue(); myParam.normalization_init = normalization_init.getValue(); cerr << "Command line: " << endl; @@ -187,8 +189,13 @@ int main(int argc, char** argv) } cerr << loss_function.getDescription() << sep << loss_function.getValue() << endl; - cerr << init_normal.getDescription() << sep << init_normal.getValue() << endl; - cerr << init_range.getDescription() << sep << init_range.getValue() << endl; + if (init_model.getValue() != "") { + cerr << init_model.getDescription() << sep << init_model.getValue() << endl; + } + else { + cerr << init_normal.getDescription() << sep << init_normal.getValue() << endl; + cerr << init_range.getDescription() << sep << init_range.getValue() << endl; + } cerr << num_epochs.getDescription() << sep << 
num_epochs.getValue() << endl; cerr << minibatch_size.getDescription() << sep << minibatch_size.getValue() << endl; @@ -311,7 +318,12 @@ int main(int argc, char** argv) myParam.output_embedding_dimension, myParam.share_embeddings); - nn.initialize(rng, myParam.init_normal, myParam.init_range, -log(myParam.output_vocab_size)); + if (myParam.init_model != "") { + nn.read(myParam.init_model); + } + else { + nn.initialize(rng, myParam.init_normal, myParam.init_range, -log(myParam.output_vocab_size)); + } nn.set_activation_function(string_to_activation_function(myParam.activation_function)); loss_function_type loss_function = string_to_loss_function(myParam.loss_function); |