Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-04-20 01:34:59 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-04-20 01:34:59 +0300
commit40dc3ed6762740ab6686ed8fcbc52d40c5c9b31e (patch)
treefbf47341ab0ad9731d3d6336f9f8a7e1466f828b /src
parente57894e5773f5891c6499ae53fce55ddfcfa4e79 (diff)
fixed concurrency bug
Diffstat (limited to 'src')
-rw-r--r--src/decoder/decoder_main.cu3
-rw-r--r--src/decoder/god.cu8
-rw-r--r--src/decoder/god.h2
3 files changed, 12 insertions, 1 deletions
diff --git a/src/decoder/decoder_main.cu b/src/decoder/decoder_main.cu
index 2e2e0093..72fcee3f 100644
--- a/src/decoder/decoder_main.cu
+++ b/src/decoder/decoder_main.cu
@@ -23,7 +23,7 @@ int main(int argc, char* argv[]) {
while(std::getline(std::cin, in)) {
Sentence sentence = God::GetSourceVocab()(in);
- auto translationTask = [&, taskCounter]{
+ auto translationTask = [sentence, taskCounter] {
thread_local std::unique_ptr<Search> search;
if(!search)
search.reset(new Search(taskCounter));
@@ -40,5 +40,6 @@ int main(int argc, char* argv[]) {
Printer(result.get(), lineCounter++, std::cout);
std::cerr << timer.format() << std::endl;
+ God::ClearModels();
return 0;
}
diff --git a/src/decoder/god.cu b/src/decoder/god.cu
index 50ea1769..a91e44d2 100644
--- a/src/decoder/god.cu
+++ b/src/decoder/god.cu
@@ -152,3 +152,11 @@ std::vector<float>& God::GetScorerWeights() {
return Summon().weights_;
}
+// clean up cuda vectors before cuda context goes out of scope
+void God::ClearModels() {
+ for(auto& models : Summon().modelsPerDevice_)
+ for(auto& m : models)
+ m.reset(nullptr);
+}
+
+
diff --git a/src/decoder/god.h b/src/decoder/god.h
index 8dcbd5d3..62435969 100644
--- a/src/decoder/god.h
+++ b/src/decoder/god.h
@@ -33,6 +33,8 @@ class God {
static std::vector<ScorerPtr> GetScorers(size_t);
static std::vector<float>& GetScorerWeights();
+ static void ClearModels();
+
private:
God& NonStaticInit(int argc, char** argv);