Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/common/god.cpp')
-rw-r--r--src/common/god.cpp47
1 files changed, 22 insertions, 25 deletions
diff --git a/src/common/god.cpp b/src/common/god.cpp
index 61df15ff..de54bbc3 100644
--- a/src/common/god.cpp
+++ b/src/common/god.cpp
@@ -22,20 +22,22 @@ using namespace std;
namespace amunmt {
God::God()
-:threadIncr_(0)
+ : threadIncr_(0)
{
-
}
-God::~God() {}
+God::~God()
+{
+}
God& God::Init(const std::string& options) {
std::vector<std::string> args = boost::program_options::split_unix(options);
int argc = args.size() + 1;
char* argv[argc];
argv[0] = const_cast<char*>("bogus");
- for(int i = 1; i < argc; i++)
+ for (int i = 1; i < argc; ++i) {
argv[i] = const_cast<char*>(args[i-1].c_str());
+ }
return Init(argc, argv);
}
@@ -49,35 +51,35 @@ God& God::Init(int argc, char** argv) {
config_.AddOptions(argc, argv);
config_.LogOptions();
- if(Get("source-vocab").IsSequence()) {
- for(auto sourceVocabPath : Get<std::vector<std::string>>("source-vocab"))
- sourceVocabs_.emplace_back(new Vocab(sourceVocabPath));
- }
- else {
- sourceVocabs_.emplace_back(new Vocab(Get<std::string>("source-vocab")));
+ if (Get("source-vocab").IsSequence()) {
+ for (auto sourceVocabPath : Get<std::vector<std::string>>("source-vocab")) {
+ sourceVocabs_.emplace_back(new Vocab(sourceVocabPath));
+ }
+ } else {
+ sourceVocabs_.emplace_back(new Vocab(Get<std::string>("source-vocab")));
}
targetVocab_.reset(new Vocab(Get<std::string>("target-vocab")));
weights_ = Get<std::map<std::string, float>>("weights");
if(Get<bool>("show-weights")) {
- LOG(info) << "Outputting weights and exiting";
- for(auto && pair : weights_) {
- std::cout << pair.first << "= " << pair.second << std::endl;
- }
- exit(0);
+ LOG(info) << "Outputting weights and exiting";
+ for(auto && pair : weights_) {
+ std::cout << pair.first << "= " << pair.second << std::endl;
+ }
+ exit(0);
}
LoadScorers();
LoadFiltering();
if (Has("input-file")) {
- LOG(info) << "Reading from " << Get<std::string>("input-file");
- inputStream_.reset(new InputFileStream(Get<std::string>("input-file")));
+ LOG(info) << "Reading from " << Get<std::string>("input-file");
+ inputStream_.reset(new InputFileStream(Get<std::string>("input-file")));
}
else {
- LOG(info) << "Reading from stdin";
- inputStream_.reset(new InputFileStream(std::cin));
+ LOG(info) << "Reading from stdin";
+ inputStream_.reset(new InputFileStream(std::cin));
}
LoadPrePostProcessing();
@@ -184,11 +186,9 @@ std::vector<ScorerPtr> God::GetScorers(const DeviceInfo &deviceInfo) const {
std::vector<ScorerPtr> scorers;
if (deviceInfo.deviceType == CPUDevice) {
- //cerr << "CPU GetScorers" << endl;
for (auto&& loader : cpuLoaders_ | boost::adaptors::map_values)
scorers.emplace_back(loader->NewScorer(*this, deviceInfo));
} else {
- //cerr << "GPU GetScorers" << endl;
for (auto&& loader : gpuLoaders_ | boost::adaptors::map_values)
scorers.emplace_back(loader->NewScorer(*this, deviceInfo));
}
@@ -233,14 +233,12 @@ std::vector<std::string> God::Postprocess(const std::vector<std::string>& input)
}
return processed;
}
-// clean up cuda vectors before cuda context goes out of scope
+
void God::CleanUp() {
for (Loaders::value_type& loader : cpuLoaders_) {
- //cerr << "cpu loader=" << loader.first << endl;
loader.second.reset(nullptr);
}
for (Loaders::value_type& loader : gpuLoaders_) {
- //cerr << "gpu loader=" << loader.first << endl;
loader.second.reset(nullptr);
}
}
@@ -274,7 +272,6 @@ DeviceInfo God::GetNextDevice() const
++threadIncr_;
- //cerr << "GetNextDevice=" << ret << endl;
return ret;
}