Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTomasz Dwojak <t.dwojak@amu.edu.pl>2017-06-19 17:47:41 +0300
committerTomasz Dwojak <t.dwojak@amu.edu.pl>2017-06-19 17:47:41 +0300
commit3c62826dcd87ee0ae2de0c7f4f16bc27196713d4 (patch)
tree85ecef954046562925e39c759ea0dee6198a328d
parentfb067264cb662a07717963efddc8948c32dd5405 (diff)
Move EncoderDecoderState to its own header
-rw-r--r--src/amun/cpu/decoder/encoder_decoder.cpp12
-rw-r--r--src/amun/cpu/decoder/encoder_decoder.h59
-rw-r--r--src/amun/cpu/dl4mt/encoder_decoder_state.h30
3 files changed, 69 insertions, 32 deletions
diff --git a/src/amun/cpu/decoder/encoder_decoder.cpp b/src/amun/cpu/decoder/encoder_decoder.cpp
index d1714f59..df808dc5 100644
--- a/src/amun/cpu/decoder/encoder_decoder.cpp
+++ b/src/amun/cpu/decoder/encoder_decoder.cpp
@@ -128,16 +128,6 @@ void EncoderDecoder::Filter(const std::vector<size_t>& filterIds) {
}
-dl4mt::Encoder& EncoderDecoder::GetEncoder() {
- return *encoder_;
-}
-
-
-dl4mt::Decoder& EncoderDecoder::GetDecoder() {
- return *decoder_;
-}
-
-
BaseMatrix& EncoderDecoder::GetProbs() {
return decoder_->GetProbs();
}
@@ -154,7 +144,7 @@ void EncoderDecoderLoader::Load(const God&) {
weights_.emplace_back(new dl4mt::Weights(path, 0));
}
-ScorerPtr EncoderDecoderLoader::NewScorer(const God&, const DeviceInfo& deviceInfo) const {
+ScorerPtr EncoderDecoderLoader::NewScorer(const God&, const DeviceInfo&) const {
size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
return ScorerPtr(new EncoderDecoder(name_, config_,
tab, *weights_[0]));
diff --git a/src/amun/cpu/decoder/encoder_decoder.h b/src/amun/cpu/decoder/encoder_decoder.h
index 99062e10..0e722dd3 100644
--- a/src/amun/cpu/decoder/encoder_decoder.h
+++ b/src/amun/cpu/decoder/encoder_decoder.h
@@ -9,8 +9,11 @@
#include "common/scorer.h"
#include "cpu/dl4mt/dl4mt.h"
+#include "cpu/nematus/encoder.h"
+#include "cpu/nematus/decoder.h"
#include "cpu/mblas/matrix.h"
+#include "cpu/decoder/encoder_decoder_state.h"
namespace amunmt {
@@ -23,36 +26,54 @@ class Encoder;
class Decoder;
}
-class EncoderDecoderState : public State {
+class EncoderDecoder : public Scorer {
+ private:
+ using EDState = EncoderDecoderState;
+
public:
- EncoderDecoderState();
- EncoderDecoderState(const EncoderDecoderState&) = delete;
+ EncoderDecoder(const std::string& name,
+ const YAML::Node& config,
+ size_t tab,
+ const dl4mt::Weights& model);
+
+ virtual void Decode(const State& in, State& out, const std::vector<size_t>& beamSizes);
- virtual std::string Debug() const;
+ virtual State* NewState() const;
+
+ virtual void BeginSentenceState(State& state, size_t batchSize);
+
+ virtual void SetSource(const Sentences& sources);
+
+ virtual void AssembleBeamState(const State& in,
+ const Beam& beam,
+ State& out);
- CPU::mblas::Matrix& GetStates();
+ void GetAttention(mblas::Matrix& Attention);
+ mblas::Matrix& GetAttention();
- CPU::mblas::Matrix& GetEmbeddings();
+ size_t GetVocabSize() const;
- const CPU::mblas::Matrix& GetStates() const;
+ BaseMatrix& GetProbs();
- const CPU::mblas::Matrix& GetEmbeddings() const;
+ void Filter(const std::vector<size_t>& filterIds);
private:
- CPU::mblas::Matrix states_;
- CPU::mblas::Matrix embeddings_;
-};
+ const dl4mt::Weights& model_;
+ std::unique_ptr<dl4mt::Encoder> encoder_;
+ std::unique_ptr<dl4mt::Decoder> decoder_;
+ mblas::Matrix SourceContext_;
+};
-class EncoderDecoder : public Scorer {
+class NematusEncoderDecoder : public Scorer {
private:
using EDState = EncoderDecoderState;
public:
- EncoderDecoder(const std::string& name,
+ NematusEncoderDecoder(const std::string& name,
const YAML::Node& config,
size_t tab,
- const dl4mt::Weights& model);
+ const Nematus::Weights& model);
virtual void Decode(const State& in, State& out, const std::vector<size_t>& beamSizes);
@@ -75,14 +96,10 @@ class EncoderDecoder : public Scorer {
void Filter(const std::vector<size_t>& filterIds);
- dl4mt::Encoder& GetEncoder();
-
- dl4mt::Decoder& GetDecoder();
-
private:
- const dl4mt::Weights& model_;
- std::unique_ptr<dl4mt::Encoder> encoder_;
- std::unique_ptr<dl4mt::Decoder> decoder_;
+ const Nematus::Weights& model_;
+ std::unique_ptr<Nematus::Encoder> encoder_;
+ std::unique_ptr<Nematus::Decoder> decoder_;
mblas::Matrix SourceContext_;
};
diff --git a/src/amun/cpu/dl4mt/encoder_decoder_state.h b/src/amun/cpu/dl4mt/encoder_decoder_state.h
new file mode 100644
index 00000000..ae4fbe65
--- /dev/null
+++ b/src/amun/cpu/dl4mt/encoder_decoder_state.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <vector>
+
+#include "cpu/mblas/matrix.h"
+#include "common/scorer.h"
+
+namespace amunmt {
+namespace CPU {
+
+class EncoderDecoderState : public State {
+ public:
+ EncoderDecoderState();
+ EncoderDecoderState(const EncoderDecoderState&) = delete;
+
+ virtual std::string Debug() const;
+
+ CPU::mblas::Matrix& GetStates();
+ const CPU::mblas::Matrix& GetStates() const;
+
+ CPU::mblas::Matrix& GetEmbeddings();
+ const CPU::mblas::Matrix& GetEmbeddings() const;
+
+ private:
+ CPU::mblas::Matrix states_;
+ CPU::mblas::Matrix embeddings_;
+};
+
+} // namespace CPU
+} // namespace amunmt