Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-04-14 21:49:38 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-04-14 21:49:38 +0300
commit19e71f8061c26c517aa751cd4e7b66a1d75f0bcd (patch)
tree3c12cecb6fdd4755dc1a88c5b878b0cab62dc517
parent436f0bd52ee584dd8698dd6cc6c8e3cdb0eb728c (diff)
towards n-best-lists
-rw-r--r--CMakeLists.txt4
-rw-r--r--src/common/vocab.h39
-rw-r--r--src/decoder/decoder_main.cu43
-rw-r--r--src/decoder/search.h123
-rw-r--r--src/rescorer/nbest.cpp10
-rw-r--r--src/rescorer/nbest.h2
6 files changed, 134 insertions, 87 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5fa647eb..76e90b45 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.1.0)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(amunn CXX)
-SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -fPIC -funroll-loops -Wno-unused-result -Wno-deprecated")
+SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
#SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
-SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -fPIC -arch=sm_35 -lineinfo --use_fast_math")
+SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -arch=sm_35 -lineinfo --use_fast_math")
#SET(CUDA_VERBOSE_BUILD ON)
include_directories(${amunn_SOURCE_DIR})
diff --git a/src/common/vocab.h b/src/common/vocab.h
index f70c3069..f8997041 100644
--- a/src/common/vocab.h
+++ b/src/common/vocab.h
@@ -4,6 +4,10 @@
#include <string>
#include <vector>
#include <fstream>
+#include <sstream>
+
+#include "types.h"
+#include "utils.h"
class Vocab {
public:
@@ -15,8 +19,6 @@ class Vocab {
str2id_[line] = c++;
id2str_.push_back(line);
}
- //str2id_["</s>"] = c;
- //id2str_.push_back("</s>");
}
size_t operator[](const std::string& word) const {
@@ -27,15 +29,32 @@ class Vocab {
return 1;
}
- inline std::vector<size_t> Encode(const std::vector<std::string>& sentence, bool addEOS=false) const {
- std::vector<size_t> indexes;
- for (auto& word: sentence) {
- indexes.push_back((*this)[word]);
- }
- if (addEOS) {
- indexes.push_back((*this)["</s>"]);
+ Sentence operator()(const std::vector<std::string>& lineTokens, bool addEOS = true) const {
+ Sentence words(lineTokens.size());
+ std::transform(lineTokens.begin(), lineTokens.end(), words.begin(),
+ [&](const std::string& w) { return (*this)[w]; });
+ if(addEOS)
+ words.push_back(EOS);
+ return words;
+ }
+
+ Sentence operator()(const std::string& line, bool addEOS = true) const {
+ std::vector<std::string> lineTokens;
+ Split(line, lineTokens, " ");
+ return (*this)(lineTokens, addEOS);
+ }
+
+ std::string operator()(const Sentence& sentence, bool ignoreEOS = true) const {
+ std::stringstream line;
+ for(size_t i = 0; i < sentence.size(); ++i) {
+ if(sentence[i] != EOS || !ignoreEOS) {
+ if(i > 0) {
+ line << " ";
+ }
+ line << (*this)[sentence[i]];
+ }
}
- return indexes;
+ return line.str();
}
diff --git a/src/decoder/decoder_main.cu b/src/decoder/decoder_main.cu
index 2dfe0d2f..292021f8 100644
--- a/src/decoder/decoder_main.cu
+++ b/src/decoder/decoder_main.cu
@@ -57,6 +57,33 @@ void ProgramOptions(int argc, char *argv[],
}
}
+class BPE {
+ public:
+ BPE(const std::string& sep = "@@ ")
+ : sep_(sep) {}
+
+ std::string split(const std::string& line) {
+ return line;
+ }
+
+ std::string unsplit(const std::string& line) {
+ std::string joined = line;
+ size_t pos = joined.find(sep_);
+ while(pos != std::string::npos) {
+ joined.erase(pos, sep_.size());
+ pos = joined.find(sep_, pos);
+ }
+ return joined;
+ }
+
+ operator bool() const {
+ return true;
+ }
+
+ private:
+ std::string sep_;
+};
+
int main(int argc, char* argv[]) {
std::string modelPath, srcVocabPath, trgVocabPath;
size_t device = 0;
@@ -70,17 +97,23 @@ int main(int argc, char* argv[]) {
Vocab trgVocab(trgVocabPath);
std::cerr << "done." << std::endl;
- Search search(model, srcVocab, trgVocab);
+ Search search(model);
std::cerr << "Translating...\n";
std::ios_base::sync_with_stdio(false);
- std::string line;
+ BPE bpe;
+
boost::timer::cpu_timer timer;
- while(std::getline(std::cin, line)) {
- auto result = search.Decode(line, beamSize);
- std::cout << result << std::endl;
+ std::string in;
+ while(std::getline(std::cin, in)) {
+ Sentence sentence = bpe ? srcVocab(bpe.split(in)) : srcVocab(in);
+ History history = search.Decode(sentence, beamSize);
+ std::string out = trgVocab(history.Top().first);
+ if(bpe)
+ out = bpe.unsplit(out);
+ std::cout << out << std::endl;
}
std::cerr << timer.format() << std::endl;
return 0;
diff --git a/src/decoder/search.h b/src/decoder/search.h
index 4b73675e..e674183d 100644
--- a/src/decoder/search.h
+++ b/src/decoder/search.h
@@ -6,6 +6,7 @@
#include <algorithm>
#include <limits>
#include <sstream>
+#include <queue>
#include <boost/timer/timer.hpp>
#include <thrust/functional.h>
@@ -16,24 +17,70 @@
#include <thrust/sort.h>
#include <thrust/sequence.h>
+#include "types.h"
#include "matrix.h"
#include "dl4mt.h"
-#include "vocab.h"
#include "hypothesis.h"
#include "utils.h"
-#define EOL "</s>"
+typedef std::vector<Hypothesis> Beam;
+typedef std::pair<Sentence, Hypothesis> Result;
+typedef std::vector<Result> NBestList;
+
+class History {
+ private:
+ struct HypothesisCoord {
+ bool operator<(const HypothesisCoord& hc) const {
+ return cost < hc.cost;
+ }
+
+ size_t i;
+ size_t j;
+ float cost;
+ };
+
+ public:
+ void Add(const Beam& beam, bool last = false) {
+ for(size_t j = 0; j < beam.size(); ++j)
+ if(beam[j].GetWord() == EOS || last)
+ topHyps_.push({ history_.size(), j, beam[j].GetCost() });
+ history_.push_back(beam);
+ }
+
+ size_t size() const {
+ return history_.size();
+ }
+
+ NBestList NBest(size_t n) {
+
+ }
+
+ Result Top() const {
+ Sentence targetWords;
+ auto bestHypCoord = topHyps_.top();
+ size_t start = bestHypCoord.i;
+ size_t j = bestHypCoord.j;
+ for(int i = start; i >= 0; i--) {
+ auto& bestHyp = history_[i][j];
+ targetWords.push_back(bestHyp.GetWord());
+ j = bestHyp.GetPrevStateIndex();
+ }
+
+ std::reverse(targetWords.begin(), targetWords.end());
+ return Result(targetWords, history_[bestHypCoord.i][bestHypCoord.j]);
+ }
+
+ private:
+ std::vector<Beam> history_;
+ std::priority_queue<HypothesisCoord> topHyps_;
+
+};
class Search {
- typedef std::vector<Hypothesis> Beam;
- typedef std::vector<Beam> History;
-
private:
const Weights& model_;
Encoder encoder_;
Decoder decoder_;
- const Vocab svcb_;
- const Vocab tvcb_;
mblas::Matrix State_, NextState_, BeamState_;
mblas::Matrix Embeddings_, NextEmbeddings_;
@@ -41,24 +88,14 @@ class Search {
mblas::Matrix SourceContext_;
public:
- Search(const Weights& model, const Vocab& svcb, const Vocab tvcb)
+ Search(const Weights& model)
: model_(model),
encoder_(model_),
- decoder_(model_),
- svcb_(svcb), tvcb_(tvcb)
+ decoder_(model_)
{}
- std::string Decode(const std::string& source, size_t beamSize = 12) {
- // this should happen somewhere else
- std::vector<std::string> sourceSplit;
- Split(source, sourceSplit, " ");
- std::vector<size_t> sourceWords(sourceSplit.size());
- std::transform(sourceSplit.begin(), sourceSplit.end(), sourceWords.begin(),
- [&](const std::string& w) { return svcb_[w]; });
- sourceWords.push_back(svcb_[EOL]);
-
+ History Decode(const Sentence sourceWords, size_t beamSize = 12) {
encoder_.GetContext(sourceWords, SourceContext_);
-
decoder_.EmptyState(State_, SourceContext_, 1);
decoder_.EmptyEmbedding(Embeddings_, 1);
@@ -72,13 +109,13 @@ class Search {
Beam hyps;
BestHyps(hyps, prevHyps, Probs_, beamSize);
- history.push_back(hyps);
+ history.Add(hyps, history.size() + 1 == sourceWords.size() * 3);
Beam survivors;
std::vector<size_t> beamWords;
std::vector<size_t> beamStateIds;
for(auto& h : hyps) {
- if(h.GetWord() != tvcb_[EOL]) {
+ if(h.GetWord() != EOS) {
survivors.push_back(h);
beamWords.push_back(h.GetWord());
beamStateIds.push_back(h.GetPrevStateIndex());
@@ -98,7 +135,7 @@ class Search {
} while(history.size() < sourceWords.size() * 3);
- return FindBest(history);
+ return history;
}
void BestHyps(Beam& bestHyps, const Beam& prevHyps, mblas::Matrix& Probs, const size_t beamSize) {
@@ -131,44 +168,4 @@ class Search {
bestHyps.emplace_back(wordIndex, hypIndex, cost);
}
}
-
- std::string FindBest(const History& history) {
- std::vector<size_t> targetWords;
-
- size_t best = 0;
- size_t beamSize = 0;
- float bestCost = std::numeric_limits<float>::lowest();
-
- for(auto b = history.rbegin(); b != history.rend(); b++) {
- if(b->size() > beamSize) {
- beamSize = b->size();
- for(size_t i = 0; i < beamSize; ++i) {
- if(b == history.rbegin() || (*b)[i].GetWord() == tvcb_[EOL]) {
- if((*b)[i].GetCost() > bestCost) {
- best = i;
- bestCost = (*b)[i].GetCost();
- targetWords.clear();
- }
- }
- }
- }
-
- auto& bestHyp = (*b)[best];
- targetWords.push_back(bestHyp.GetWord());
- best = bestHyp.GetPrevStateIndex();
- }
-
- std::reverse(targetWords.begin(), targetWords.end());
- std::stringstream translation;
- for(size_t i = 0; i < targetWords.size(); ++i) {
- if(tvcb_[targetWords[i]] != EOL) {
- if(i > 0) {
- translation << " ";
- }
- translation << tvcb_[targetWords[i]];
- }
- }
- return translation.str();
- }
-
}; \ No newline at end of file
diff --git a/src/rescorer/nbest.cpp b/src/rescorer/nbest.cpp
index e7135fa6..e3185092 100644
--- a/src/rescorer/nbest.cpp
+++ b/src/rescorer/nbest.cpp
@@ -39,9 +39,7 @@ std::vector<std::string> NBest::GetTokens(const size_t index) const {
}
std::vector<size_t> NBest::GetEncodedTokens(const size_t index) const {
- std::vector<std::string> tokens;
- Split(srcSentences_[index], tokens);
- return srcVocab_->Encode(tokens, true);
+ return (*srcVocab_)(srcSentences_[index]);
}
void NBest::Parse_(const std::string& path) {
@@ -76,10 +74,10 @@ inline std::vector< std::vector< std::string > > NBest::SplitBatch(std::vector<s
return splittedBatch;
}
-inline Batch NBest::EncodeBatch(const std::vector<std::vector<std::string>>& batch) const {
+inline Batch NBest::EncodeBatch(const std::vector<std::string>& batch) const {
Batch encodedBatch;
for (auto& sentence: batch) {
- encodedBatch.push_back(trgVocab_->Encode(sentence, true));
+ encodedBatch.push_back((*trgVocab_)(sentence));
}
return encodedBatch;
}
@@ -103,7 +101,7 @@ inline Batch NBest::MaskAndTransposeBatch(const Batch& batch) const {
Batch NBest::ProcessBatch(std::vector<std::string>& batch) const {
- return MaskAndTransposeBatch(EncodeBatch(SplitBatch(batch)));
+ return MaskAndTransposeBatch(EncodeBatch(batch));
}
std::vector<Batch> NBest::GetBatches(const size_t index) const {
diff --git a/src/rescorer/nbest.h b/src/rescorer/nbest.h
index 7c91aeec..d5ad6eac 100644
--- a/src/rescorer/nbest.h
+++ b/src/rescorer/nbest.h
@@ -41,7 +41,7 @@ class NBest {
std::vector<std::vector<std::string>> SplitBatch(std::vector<std::string>& batch) const;
void ParseInputFile(const std::string& path);
- Batch EncodeBatch(const std::vector<std::vector<std::string>>& batch) const;
+ Batch EncodeBatch(const std::vector<std::string>& batch) const;
Batch MaskAndTransposeBatch(const Batch& batch) const;