Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt34
-rw-r--r--src/CMakeLists.txt1
-rw-r--r--src/common/threadpool.h121
-rw-r--r--src/common/types.h7
-rw-r--r--src/common/vocab.h2
-rw-r--r--src/decoder/decoder_main.cu74
-rw-r--r--src/decoder/history.h24
-rw-r--r--src/decoder/hypothesis.h31
-rw-r--r--src/decoder/kenlm.cpp104
-rw-r--r--src/decoder/kenlm.h64
-rw-r--r--src/decoder/search.h122
-rw-r--r--src/dl4mt/decoder.h42
-rw-r--r--src/dl4mt/gru.h10
13 files changed, 552 insertions, 84 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f799ba6a..0ca5a6c5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,7 +3,8 @@ set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(amunn CXX)
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
-LIST(APPEND CUDA_NVCC_FLAGS -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math)
+LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math;)
+add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
include_directories(${amunn_SOURCE_DIR})
@@ -17,5 +18,36 @@ else(Boost_FOUND)
message(SEND_ERROR "Cannot find Boost libraries. Terminating." )
endif(Boost_FOUND)
+### KenLM stuff - BEGIN ###
+set(EXT_LIBS ${EXT_LIBS} /home/marcinj/Badania/kenlm/build/lib/libkenlm.a)
+set(EXT_LIBS ${EXT_LIBS} /home/marcinj/Badania/kenlm/build/lib/libkenlm_util.a)
+include_directories(/home/marcinj/Badania/kenlm)
+add_definitions(-DKENLM_MAX_ORDER=9)
+
+#find_package (OpenMP)
+#if (OPENMP_FOUND)
+# LIST(APPEND CUDA_NVCC_FLAGS -Xcompiler -fopenmp)
+# SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
+#endif (OPENMP_FOUND)
+
+find_package (BZip2)
+if (BZIP2_FOUND)
+ include_directories(${BZIP2_INCLUDE_DIRS})
+ set(EXT_LIBS ${EXT_LIBS} ${BZIP2_LIBRARIES})
+endif (BZIP2_FOUND)
+
+find_package (ZLIB)
+if (ZLIB_FOUND)
+ include_directories(${ZLIB_INCLUDE_DIRS})
+ set(EXT_LIBS ${EXT_LIBS} ${ZLIB_LIBRARIES})
+endif (ZLIB_FOUND)
+
+find_package (LibLZMA)
+if (LIBLZMA_FOUND)
+ include_directories(${LIBLZMA_INCLUDE_DIRS})
+ set(EXT_LIBS ${EXT_LIBS} ${LIBLZMA_LIBRARIES})
+endif (LIBLZMA_FOUND)
+### KenLM stuff - END ###
+
include_directories($amunn_SOURCE_DIR}/src)
add_subdirectory(src)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 5c1fd4cc..0a294ded 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -9,6 +9,7 @@ add_library(libamunn OBJECT
cnpy/cnpy.cpp
rescorer/nbest.cpp
common/utils.cpp
+ decoder/kenlm.cpp
)
cuda_add_executable(
diff --git a/src/common/threadpool.h b/src/common/threadpool.h
new file mode 100644
index 00000000..33c28e56
--- /dev/null
+++ b/src/common/threadpool.h
@@ -0,0 +1,121 @@
+/*
+
+Copyright (c) 2012 Jakob Progsch, Václav Zeman
+
+This software is provided 'as-is', without any express or implied
+warranty. In no event will the authors be held liable for any damages
+arising from the use of this software.
+
+Permission is granted to anyone to use this software for any purpose,
+including commercial applications, and to alter it and redistribute it
+freely, subject to the following restrictions:
+
+ 1. The origin of this software must not be misrepresented; you must not
+ claim that you wrote the original software. If you use this software
+ in a product, an acknowledgment in the product documentation would be
+ appreciated but is not required.
+
+ 2. Altered source versions must be plainly marked as such, and must not be
+ misrepresented as being the original software.
+
+ 3. This notice may not be removed or altered from any source
+ distribution.
+
+*/
+
+#pragma once
+
+#include <vector>
+#include <queue>
+#include <memory>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <future>
+#include <functional>
+#include <stdexcept>
+
+class ThreadPool {
+public:
+ ThreadPool(size_t);
+ template<class F, class... Args>
+ auto enqueue(F&& f, Args&&... args)
+ -> std::future<typename std::result_of<F(Args...)>::type>;
+ ~ThreadPool();
+
+private:
+ // need to keep track of threads so we can join them
+ std::vector< std::thread > workers;
+ // the task queue
+ std::queue< std::function<void()> > tasks;
+
+ // synchronization
+ std::mutex queue_mutex;
+ std::condition_variable condition;
+ bool stop;
+};
+
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads)
+ : stop(false)
+{
+ for(size_t i = 0;i<threads;++i)
+ workers.emplace_back(
+ [this]
+ {
+ for(;;)
+ {
+ std::function<void()> task;
+
+ {
+ std::unique_lock<std::mutex> lock(this->queue_mutex);
+ this->condition.wait(lock,
+ [this]{ return this->stop || !this->tasks.empty(); });
+ if(this->stop && this->tasks.empty())
+ return;
+ task = std::move(this->tasks.front());
+ this->tasks.pop();
+ }
+
+ task();
+ }
+ }
+ );
+}
+
+// add new work item to the pool
+template<class F, class... Args>
+auto ThreadPool::enqueue(F&& f, Args&&... args)
+ -> std::future<typename std::result_of<F(Args...)>::type>
+{
+ using return_type = typename std::result_of<F(Args...)>::type;
+
+ auto task = std::make_shared< std::packaged_task<return_type()> >(
+ std::bind(std::forward<F>(f), std::forward<Args>(args)...)
+ );
+
+ std::future<return_type> res = task->get_future();
+ {
+ std::unique_lock<std::mutex> lock(queue_mutex);
+
+ // don't allow enqueueing after stopping the pool
+ if(stop)
+ throw std::runtime_error("enqueue on stopped ThreadPool");
+
+ tasks.emplace([task](){ (*task)(); });
+ }
+ condition.notify_one();
+ return res;
+}
+
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool()
+{
+ {
+ std::unique_lock<std::mutex> lock(queue_mutex);
+ stop = true;
+ }
+ condition.notify_all();
+ for(std::thread &worker: workers)
+ worker.join();
+}
diff --git a/src/common/types.h b/src/common/types.h
index 15086edd..3cb0c2a6 100644
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -2,8 +2,9 @@
#include <vector>
-#define EOS 0
-#define UNK 1
-
typedef size_t Word;
+
+const Word EOS = 0;
+const Word UNK = 1;
+
typedef std::vector<Word> Sentence;
diff --git a/src/common/vocab.h b/src/common/vocab.h
index f8997041..cfa05545 100644
--- a/src/common/vocab.h
+++ b/src/common/vocab.h
@@ -62,7 +62,7 @@ class Vocab {
return id2str_[id];
}
- size_t size() {
+ size_t size() const {
return id2str_.size();
}
diff --git a/src/decoder/decoder_main.cu b/src/decoder/decoder_main.cu
index bded9099..ca9ea8fa 100644
--- a/src/decoder/decoder_main.cu
+++ b/src/decoder/decoder_main.cu
@@ -12,6 +12,7 @@
#include "dl4mt.h"
#include "vocab.h"
#include "search.h"
+#include "threadpool.h"
class BPE {
public:
@@ -43,9 +44,12 @@ class BPE {
int main(int argc, char* argv[]) {
std::string srcVocabPath, trgVocabPath;
std::vector<std::string> modelPaths;
- size_t device = 0;
+ std::vector<std::string> lmPaths;
+ std::vector<float> lmWeights;
+ std::vector<size_t> devices;
size_t nbest = 0;
size_t beamSize = 12;
+ size_t threads = 1;
bool help = false;
namespace po = boost::program_options;
@@ -53,12 +57,18 @@ int main(int argc, char* argv[]) {
cmdline_options.add_options()
("beamsize,b", po::value(&beamSize)->default_value(12),
"Beam size")
+ ("threads", po::value(&threads)->default_value(1),
+ "Number of threads")
("n-best-list", po::value(&nbest)->default_value(0),
"N-best list")
- ("device,d", po::value(&device)->default_value(0),
+ ("device(s),d", po::value(&devices)->multitoken(),
"CUDA Device")
("model(s),m", po::value(&modelPaths)->multitoken()->required(),
"Path to a model")
+ ("lms(s),l", po::value(&lmPaths)->multitoken(),
+ "Path to a kenlm language model")
+ ("lw(s)", po::value(&lmWeights)->multitoken(),
+ "Language Model weights")
("source,s", po::value(&srcVocabPath)->required(),
"Path to a source vocab file.")
("target,t", po::value(&trgVocabPath)->required(),
@@ -85,46 +95,82 @@ int main(int argc, char* argv[]) {
exit(0);
}
- std::cerr << "Using device GPU" << device << std::endl;;
- cudaSetDevice(device);
+ if(devices.empty())
+ devices.push_back(0);
+
Vocab srcVocab(srcVocabPath);
Vocab trgVocab(trgVocabPath);
std::vector<std::unique_ptr<Weights>> models;
for(auto& modelPath : modelPaths) {
std::cerr << "Loading model " << modelPath << std::endl;
- models.emplace_back(new Weights(modelPath));
+ models.emplace_back(new Weights(modelPath, devices[0]));
+ }
+
+ std::vector<LM> lms;
+ if(lmWeights.size() < lmPaths.size())
+ lmWeights.resize(lmPaths.size(), 0.2);
+ for(auto& lmPath : lmPaths) {
+ std::cerr << "Loading lm " << lmPath << std::endl;
+ size_t index = lms.size();
+ float weight = lmWeights[index];
+ lms.emplace_back(lmPath, trgVocab, index, weight);
}
std::cerr << "done." << std::endl;
- Search search(models, nbest > 0);
-
std::cerr << "Translating...\n";
std::ios_base::sync_with_stdio(false);
BPE bpe;
+
boost::timer::cpu_timer timer;
+
+ //ThreadPool pool(threads);
+ std::vector<std::future<History>> results;
+
std::string in;
+
size_t lineCounter = 0;
while(std::getline(std::cin, in)) {
- Sentence sentence = bpe ? srcVocab(bpe.split(in)) : srcVocab(in);
- History history = search.Decode(sentence, beamSize);
+ // auto call = [in, beamSize, nbest, &models, &lms, &bpe, &srcVocab] {
+ // thread_local Search search(models, lms, nbest > 0);
+ //
+ // Sentence sentence = bpe ? srcVocab(bpe.split(in)) : srcVocab(in);
+ // return search.Decode(sentence, beamSize);
+ // }
+
+
+ // results.emplace_back(
+ // pool.enqueue(call)
+ // );
+ //}
+ //
+ //for(auto&& result : results) {
+ auto call = [in, beamSize, nbest, &models, &lms, &bpe, &srcVocab] {
+ thread_local Search search(models, lms, nbest > 0);
+
+ Sentence sentence = bpe ? srcVocab(bpe.split(in)) : srcVocab(in);
+ return search.Decode(sentence, beamSize);
+ };
+ History history = call();
+
+ //History history = result.get();
std::string out = trgVocab(history.Top().first);
if(bpe)
out = bpe.unsplit(out);
std::cout << out << std::endl;
if(nbest > 0) {
- NBestList nbl = history.NBest(beamSize);
+ NBestList nbl = history.NBest(nbest);
for(size_t i = 0; i < nbl.size(); ++i) {
auto& r = nbl[i];
- std::cout << lineCounter << " ||| " << bpe.unsplit(trgVocab(r.first)) << " |||";
- for(size_t j = 0; j < r.second.GetCostBreakdown().size(); ++j) {
- std::cout << " F" << j << "=" << r.second.GetCostBreakdown()[j];
+ std::cout << lineCounter << " ||| " << (bpe ? bpe.unsplit(trgVocab(r.first)) : trgVocab(r.first, false)) << " |||";
+ for(size_t j = 0; j < r.second->GetCostBreakdown().size(); ++j) {
+ std::cout << " F" << j << "=" << r.second->GetCostBreakdown()[j];
}
- std::cout << " ||| " << r.second.GetCost() << std::endl;
+ std::cout << " ||| " << r.second->GetCost() << std::endl;
}
}
lineCounter++;
diff --git a/src/decoder/history.h b/src/decoder/history.h
index 390f62a8..c883aa39 100644
--- a/src/decoder/history.h
+++ b/src/decoder/history.h
@@ -15,13 +15,21 @@ class History {
};
public:
+ ~History() {
+ for(auto& b : history_)
+ for(auto h : b)
+ delete h;
+ }
+
void Add(const Beam& beam, bool last = false) {
- for(size_t j = 0; j < beam.size(); ++j)
- if(beam[j].GetWord() == EOS || last)
- topHyps_.push({ history_.size(), j, beam[j].GetCost() });
+ if(beam.back()->GetPrevHyp() != nullptr) {
+ for(size_t j = 0; j < beam.size(); ++j)
+ if(beam[j]->GetWord() == EOS || last)
+ topHyps_.push({ history_.size(), j, beam[j]->GetCost() });
+ }
history_.push_back(beam);
}
-
+
size_t size() const {
return history_.size();
}
@@ -37,10 +45,10 @@ class History {
size_t j = bestHypCoord.j;
Sentence targetWords;
- for(int i = start; i >= 0; i--) {
- auto& bestHyp = history_[i][j];
- targetWords.push_back(bestHyp.GetWord());
- j = bestHyp.GetPrevStateIndex();
+ const Hypothesis* bestHyp = history_[start][j];
+ while(bestHyp->GetPrevHyp() != nullptr) {
+ targetWords.push_back(bestHyp->GetWord());
+ bestHyp = bestHyp->GetPrevHyp();
}
std::reverse(targetWords.begin(), targetWords.end());
diff --git a/src/decoder/hypothesis.h b/src/decoder/hypothesis.h
index ce357414..bcec9772 100644
--- a/src/decoder/hypothesis.h
+++ b/src/decoder/hypothesis.h
@@ -1,21 +1,28 @@
#pragma once
#include "types.h"
+#include "kenlm.h"
class Hypothesis {
public:
- Hypothesis(size_t word, size_t prev, float cost)
- : prev_(prev),
+ Hypothesis(const Hypothesis* prevHyp, size_t word, size_t prevIndex, float cost)
+ : prevHyp_(prevHyp),
+ prevIndex_(prevIndex),
word_(word),
- cost_(cost) {
+ cost_(cost)
+ {}
+
+ const Hypothesis* GetPrevHyp() const {
+ return prevHyp_;
}
+
size_t GetWord() const {
return word_;
}
size_t GetPrevStateIndex() const {
- return prev_;
+ return prevIndex_;
}
float GetCost() const {
@@ -25,15 +32,25 @@ class Hypothesis {
std::vector<float>& GetCostBreakdown() {
return costBreakdown_;
}
+
+ void AddLMState(const KenlmState& state) {
+ lmStates_.push_back(state);
+ }
+
+ const std::vector<KenlmState>& GetLMStates() const {
+ return lmStates_;
+ }
private:
- const size_t prev_;
+ const Hypothesis* prevHyp_;
+ const size_t prevIndex_;
const size_t word_;
const float cost_;
+ std::vector<KenlmState> lmStates_;
std::vector<float> costBreakdown_;
};
-typedef std::vector<Hypothesis> Beam;
-typedef std::pair<Sentence, Hypothesis> Result;
+typedef std::vector<Hypothesis*> Beam;
+typedef std::pair<Sentence, Hypothesis*> Result;
typedef std::vector<Result> NBestList;
diff --git a/src/decoder/kenlm.cpp b/src/decoder/kenlm.cpp
new file mode 100644
index 00000000..203ece1f
--- /dev/null
+++ b/src/decoder/kenlm.cpp
@@ -0,0 +1,104 @@
+#include "kenlm.h"
+#include "lm/model.hh"
+
+class VocabGetter : public lm::EnumerateVocab {
+ public:
+ VocabGetter(WordPairs& vm, const Vocab& vocab)
+ : vm_(vm), vocab_(vocab)
+ {
+ vm_.emplace_back(2, EOS); // is there a constant for "</s>" = 2?
+ vm_.emplace_back(1, UNK); // is there a constant for "<s>" = 1?
+ vm_.emplace_back(lm::kUNK, UNK);
+ }
+
+ void Add(lm::WordIndex index, const StringPiece &str) {
+ size_t word = vocab_[str.as_string()];
+ if(word > 2)
+ vm_.emplace_back(index, word);
+ }
+
+ private:
+ WordPairs& vm_;
+ const Vocab& vocab_;
+};
+
+KenlmState::KenlmState()
+: state_(new lm::ngram::State())
+{}
+
+KenlmState::KenlmState(const KenlmState& s)
+: state_(new lm::ngram::State())
+{
+ *state_ = *s.state_;
+}
+
+KenlmState& KenlmState::operator=(const KenlmState &s) {
+ *state_ = *s.state_;
+ return *this;
+
+}
+
+KenlmState::~KenlmState() {
+ delete state_;
+}
+
+lm::ngram::State& KenlmState::operator*() {
+ return *state_;
+}
+
+lm::ngram::State& KenlmState::operator*() const {
+ return *state_;
+}
+
+bool KenlmState::operator==(const KenlmState& o) {
+ return *state_ == *o.state_;
+}
+
+uint64_t hash_value(const KenlmState& s) {
+ //for(size_t i = 0; i < s.state_->length; i++)
+ // std::cerr << s.state_->words[i] << " ";
+ return lm::ngram::hash_value(*s.state_);
+}
+
+LM::LM(const std::string& path, const Vocab& vocab, size_t index, float weight)
+ : index_(index), weight_(weight) {
+ lm::ngram::Config config;
+ VocabGetter* vg = new VocabGetter(vm_, vocab);
+ config.enumerate_vocab = vg;
+ lm_.reset(new KenlmModel(path.c_str(), config));
+ delete vg;
+}
+
+LM::~LM() {}
+
+LM::LM(LM&& lm)
+ : lm_(std::move(lm.lm_)), vm_(std::move(lm.vm_)), index_(lm.index_), weight_(lm.weight_)
+{}
+
+float LM::Score(const KenlmState& in, lm::WordIndex index, KenlmState& out) const {
+ lm::ngram::State lout;
+ float cost = lm_->FullScore(*in, index, lout).prob;
+ *out = lout;
+ return cost;
+}
+
+void LM::BeginSentenceState(KenlmState &b) const {
+ *b = lm_->BeginSentenceState();
+}
+
+WordPairs::const_iterator LM::begin() const {
+ return vm_.begin();
+}
+
+WordPairs::const_iterator LM::end() const {
+ return vm_.end();
+}
+
+size_t LM::GetIndex () const {
+ return index_;
+}
+
+float LM::GetWeight() const {
+ return weight_;
+}
+
diff --git a/src/decoder/kenlm.h b/src/decoder/kenlm.h
new file mode 100644
index 00000000..ae352c88
--- /dev/null
+++ b/src/decoder/kenlm.h
@@ -0,0 +1,64 @@
+#pragma once
+
+#include <string>
+#include <vector>
+#include <memory>
+
+#include "vocab.h"
+
+namespace lm {
+ namespace ngram {
+ class ProbingModel;
+ class State;
+ }
+
+ typedef unsigned int WordIndex;
+}
+
+class KenlmState {
+ private:
+ lm::ngram::State* state_;
+
+ public:
+ KenlmState();
+ KenlmState(const KenlmState&);
+ KenlmState& operator=(const KenlmState&);
+ ~KenlmState();
+
+ KenlmState(KenlmState&&) = delete;
+
+
+ lm::ngram::State& operator*();
+ lm::ngram::State& operator*() const;
+
+ bool operator==(const KenlmState&);
+
+ friend uint64_t hash_value(const KenlmState&);
+
+};
+
+typedef std::pair<lm::WordIndex, Word> WordPair;
+typedef std::vector<WordPair> WordPairs;
+
+class LM {
+ private:
+ typedef lm::ngram::ProbingModel KenlmModel;
+
+ public:
+ LM(const std::string& path, const Vocab& vocab, size_t index, float weight);
+ LM(LM&& lm);
+ ~LM();
+
+ float Score(const KenlmState& in, lm::WordIndex index, KenlmState& out) const;
+ void BeginSentenceState(KenlmState&) const;
+ WordPairs::const_iterator begin() const;
+ WordPairs::const_iterator end() const;
+ size_t GetIndex() const;
+ float GetWeight() const;
+
+ private:
+ std::unique_ptr<KenlmModel> lm_;
+ WordPairs vm_;
+ size_t index_;
+ float weight_;
+};
diff --git a/src/decoder/search.h b/src/decoder/search.h
index 80dfd114..26e00bfa 100644
--- a/src/decoder/search.h
+++ b/src/decoder/search.h
@@ -7,7 +7,9 @@
#include <limits>
#include <sstream>
#include <queue>
+#include <set>
#include <boost/timer/timer.hpp>
+#include <thread>
#include <thrust/functional.h>
#include <thrust/device_vector.h>
@@ -22,7 +24,9 @@
#include "dl4mt.h"
#include "hypothesis.h"
#include "history.h"
-
+#include "kenlm.h"
+#include "threadpool.h"
+
class Search {
private:
@@ -40,23 +44,39 @@ class Search {
};
typedef std::unique_ptr<EncoderDecoder> EncoderDecoderPtr;
+
std::vector<EncoderDecoderPtr> encDecs_;
+ const std::vector<LM>& lms_;
bool doBreakdown_;
-
+ std::vector<mblas::Matrix> LmProbs_;
+
public:
- Search(const std::vector<std::unique_ptr<Weights>>& models, bool doBreakdown = false)
- : doBreakdown_(doBreakdown) {
+ Search(const std::vector<std::unique_ptr<Weights>>& models,
+ const std::vector<LM>& lms,
+ bool doBreakdown = false)
+ : lms_(lms),
+ doBreakdown_(doBreakdown)
+ {
+ cudaSetDevice(models[0]->GetDevice());
for(auto& m : models)
encDecs_.emplace_back(new EncoderDecoder(*m));
+ LmProbs_.resize(lms.size());
}
History Decode(const Sentence sourceWords, size_t beamSize = 12) {
using namespace mblas;
History history;
- Beam prevHyps;
- prevHyps.emplace_back(0, 0, 0.0);
- prevHyps.back().GetCostBreakdown().resize(encDecs_.size(), 0.0);
+
+ Hypothesis* bos = new Hypothesis(nullptr, 0, 0, 0.0);
+ bos->GetCostBreakdown().resize(encDecs_.size() + lms_.size(), 0.0);
+ for(auto& lm : lms_) {
+ KenlmState state;
+ lm.BeginSentenceState(state);
+ bos->AddLMState(state);
+ }
+ Beam prevHyps = { bos };
+ history.Add(prevHyps);
for(auto& encDec : encDecs_) {
encDec->encoder_.GetContext(sourceWords, encDec->SourceContext_);
@@ -64,6 +84,7 @@ class Search {
encDec->decoder_.EmptyEmbedding(encDec->Embeddings_, 1);
}
+ const size_t maxLength = sourceWords.size() * 3;
do {
std::vector<Matrix*> Probs;
for(auto& encDec : encDecs_) {
@@ -75,16 +96,16 @@ class Search {
Beam hyps;
BestHyps(hyps, prevHyps, Probs, beamSize);
- history.Add(hyps, history.size() + 1 == sourceWords.size() * 3);
+ history.Add(hyps, history.size() == maxLength);
Beam survivors;
std::vector<size_t> beamWords;
std::vector<size_t> beamStateIds;
- for(auto& h : hyps) {
- if(h.GetWord() != EOS) {
+ for(auto h : hyps) {
+ if(h->GetWord() != EOS) {
survivors.push_back(h);
- beamWords.push_back(h.GetWord());
- beamStateIds.push_back(h.GetPrevStateIndex());
+ beamWords.push_back(h->GetWord());
+ beamStateIds.push_back(h->GetPrevStateIndex());
}
}
beamSize = survivors.size();
@@ -101,11 +122,39 @@ class Search {
prevHyps.swap(survivors);
- } while(history.size() < sourceWords.size() * 3);
+ } while(history.size() <= maxLength);
return history;
}
+ void CalcLMProbs(mblas::Matrix& LmProbs, std::vector<KenlmState>& states,
+ const Beam& prevHyps, const LM& lm) {
+
+ size_t rows = LmProbs.Rows();
+ size_t cols = LmProbs.Cols();
+
+ std::vector<float> costs(rows * cols);
+ states.resize(rows * cols);
+
+ {
+ ThreadPool pool(4);
+ for(size_t i = 0; i < prevHyps.size(); i++) {
+ auto call = [i, cols, &prevHyps, &lm, &costs, &states] {
+ const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()];
+ KenlmState stateUnk;
+ float costUnk = lm.Score(state, 0, stateUnk);
+ std::fill(costs.begin() + i * cols, costs.begin() + i * cols + cols, costUnk);
+ std::fill(states.begin() + i * cols, states.begin() + i * cols + cols, stateUnk);
+ for(auto& wp : lm) {
+ costs[i * cols + wp.second] = lm.Score(state, wp.first, states[i * cols + wp.second]);
+ }
+ };
+ pool.enqueue(call);
+ }
+ }
+ thrust::copy(costs.begin(), costs.end(), LmProbs.begin());
+ }
+
void BestHyps(Beam& bestHyps, const Beam& prevHyps,
std::vector<mblas::Matrix*>& ProbsEnsemble,
const size_t beamSize) {
@@ -115,19 +164,29 @@ class Search {
Matrix Costs(Probs.Rows(), 1);
thrust::host_vector<float> vCosts;
- for(const Hypothesis& h : prevHyps)
- vCosts.push_back(h.GetCost());
+ for(const Hypothesis* h : prevHyps)
+ vCosts.push_back(h->GetCost());
thrust::copy(vCosts.begin(), vCosts.end(), Costs.begin());
BroadcastVecColumn(Log(_1) + _2, Probs, Costs);
for(size_t i = 1; i < ProbsEnsemble.size(); ++i)
Element(_1 + Log(_2), Probs, *ProbsEnsemble[i]);
-
+
+ std::vector<std::vector<KenlmState>> states(lms_.size());
+ if(!lms_.empty()) {
+ for(auto& lm : lms_) {
+ size_t index = lm.GetIndex();
+ LmProbs_[index].Resize(Probs.Rows(), Probs.Cols());
+ CalcLMProbs(LmProbs_[index], states[lm.GetIndex()], prevHyps, lm);
+ Element(_1 + lm.GetWeight() * _2, Probs, LmProbs_[index]);
+ }
+ }
+
thrust::device_vector<unsigned> keys(Probs.size());
thrust::host_vector<unsigned> bestKeys(beamSize);
thrust::host_vector<float> bestCosts(beamSize);
- // @TODO: Here it we need to have a partial sort
+ // @TODO: Here we need to have a partial sort
if(beamSize < 10) {
for(size_t i = 0; i < beamSize; ++i) {
thrust::device_vector<float>::iterator iter =
@@ -156,26 +215,41 @@ class Search {
thrust::copy(it, it + beamSize, modelCosts.begin());
breakDowns.push_back(modelCosts);
}
+ for(size_t i = 0; i < lms_.size(); ++i) {
+ thrust::host_vector<float> modelCosts(beamSize);
+ auto it = thrust::make_permutation_iterator(LmProbs_[i].begin(), keys.begin());
+ thrust::copy(it, it + beamSize, modelCosts.begin());
+ breakDowns.push_back(modelCosts);
+ }
}
for(size_t i = 0; i < beamSize; i++) {
size_t wordIndex = bestKeys[i] % Probs.Cols();
size_t hypIndex = bestKeys[i] / Probs.Cols();
float cost = bestCosts[i];
+ Hypothesis* hyp = new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost);
+ for(auto& lm : lms_)
+ hyp->AddLMState(states[lm.GetIndex()][bestKeys[i]]);
- Hypothesis hyp(wordIndex, hypIndex, cost);
if(doBreakdown_) {
float sum = 0;
- for(size_t j = 0; j < ProbsEnsemble.size(); ++j) {
+ for(size_t j = 0; j < ProbsEnsemble.size() + lms_.size(); ++j) {
if(j == 0)
- hyp.GetCostBreakdown().push_back(breakDowns[j][i]);
+ hyp->GetCostBreakdown().push_back(breakDowns[j][i]);
else {
- float cost = log(breakDowns[j][i]) + const_cast<Hypothesis&>(prevHyps[hypIndex]).GetCostBreakdown()[j];
- sum += cost;
- hyp.GetCostBreakdown().push_back(cost);
+ float cost = 0;
+ if(j < ProbsEnsemble.size()) {
+ cost = log(breakDowns[j][i]) + const_cast<Hypothesis*>(prevHyps[hypIndex])->GetCostBreakdown()[j];
+ sum += cost;
+ }
+ else {
+ cost = breakDowns[j][i] + const_cast<Hypothesis*>(prevHyps[hypIndex])->GetCostBreakdown()[j];
+ sum += lms_[j - ProbsEnsemble.size()].GetWeight() * cost;
+ }
+ hyp->GetCostBreakdown().push_back(cost);
}
}
- hyp.GetCostBreakdown()[0] -= sum;
+ hyp->GetCostBreakdown()[0] -= sum;
}
bestHyps.push_back(hyp);
}
diff --git a/src/dl4mt/decoder.h b/src/dl4mt/decoder.h
index 8f2f972e..3b9021fb 100644
--- a/src/dl4mt/decoder.h
+++ b/src/dl4mt/decoder.h
@@ -81,11 +81,11 @@ class Decoder {
Alignment(const Weights& model)
: w_(model)
{
- for(int i = 0; i < 2; ++i) {
- cudaStreamCreate(&s_[i]);
- cublasCreate(&h_[i]);
- cublasSetStream(h_[i], s_[i]);
- }
+ //for(int i = 0; i < 2; ++i) {
+ // cudaStreamCreate(&s_[i]);
+ // cublasCreate(&h_[i]);
+ // cublasSetStream(h_[i], s_[i]);
+ //}
}
void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext,
@@ -93,11 +93,11 @@ class Decoder {
const mblas::Matrix& SourceContext) {
using namespace mblas;
- Prod(h_[0], Temp1_, SourceContext, w_.U_);
- Prod(h_[1], Temp2_, HiddenState, w_.W_);
- BroadcastVec(_1 + _2, Temp2_, w_.B_, s_[1]);
+ Prod(/*h_[0],*/ Temp1_, SourceContext, w_.U_);
+ Prod(/*h_[1],*/ Temp2_, HiddenState, w_.W_);
+ BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/);
- cudaDeviceSynchronize();
+ //cudaDeviceSynchronize();
Broadcast(Tanh(_1 + _2), Temp1_, Temp2_);
@@ -132,11 +132,11 @@ class Decoder {
Softmax(const Weights& model)
: w_(model), filtered_(false)
{
- for(int i = 0; i < 3; ++i) {
- cudaStreamCreate(&s_[i]);
- cublasCreate(&h_[i]);
- cublasSetStream(h_[i], s_[i]);
- }
+ //for(int i = 0; i < 3; ++i) {
+ // cudaStreamCreate(&s_[i]);
+ // cublasCreate(&h_[i]);
+ // cublasSetStream(h_[i], s_[i]);
+ //}
}
void GetProbs(mblas::Matrix& Probs,
@@ -145,15 +145,15 @@ class Decoder {
const mblas::Matrix& AlignedSourceContext) {
using namespace mblas;
- Prod(h_[0], T1_, State, w_.W1_);
- Prod(h_[1], T2_, Embedding, w_.W2_);
- Prod(h_[2], T3_, AlignedSourceContext, w_.W3_);
+ Prod(/*h_[0],*/ T1_, State, w_.W1_);
+ Prod(/*h_[1],*/ T2_, Embedding, w_.W2_);
+ Prod(/*h_[2],*/ T3_, AlignedSourceContext, w_.W3_);
- BroadcastVec(_1 + _2, T1_, w_.B1_, s_[0]);
- BroadcastVec(_1 + _2, T2_, w_.B2_, s_[1]);
- BroadcastVec(_1 + _2, T3_, w_.B3_, s_[2]);
+ BroadcastVec(_1 + _2, T1_, w_.B1_ /*,s_[0]*/);
+ BroadcastVec(_1 + _2, T2_, w_.B2_ /*,s_[1]*/);
+ BroadcastVec(_1 + _2, T3_, w_.B3_ /*,s_[2]*/);
- cudaDeviceSynchronize();
+ //cudaDeviceSynchronize();
Element(Tanh(_1 + _2 + _3), T1_, T2_, T3_);
diff --git a/src/dl4mt/gru.h b/src/dl4mt/gru.h
index e09d9327..948f4520 100644
--- a/src/dl4mt/gru.h
+++ b/src/dl4mt/gru.h
@@ -121,15 +121,15 @@ class FastGRU {
// @TODO: Optimization
// @TODO: Launch streams to perform GEMMs in parallel
// @TODO: Join matrices and perform single GEMM --------
- Prod(h_[0], RU_, Context, w_.W_);
- Prod(h_[1], H_, Context, w_.Wx_);
+ Prod(/*h_[0],*/ RU_, Context, w_.W_);
+ Prod(/*h_[1],*/ H_, Context, w_.Wx_);
// -----------------------------------------------------
// @TODO: Join matrices and perform single GEMM --------
- Prod(h_[2], Temp1_, State, w_.U_);
- Prod(h_[3], Temp2_, State, w_.Ux_);
+ Prod(/*h_[2],*/ Temp1_, State, w_.U_);
+ Prod(/*h_[3],*/ Temp2_, State, w_.Ux_);
// -----------------------------------------------------
- cudaDeviceSynchronize();
+ //cudaDeviceSynchronize();
ElementwiseOps(NextState, State, RU_, H_, Temp1_, Temp2_);
}