diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-06-19 17:47:41 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-06-19 17:47:41 +0300 |
commit | 3c62826dcd87ee0ae2de0c7f4f16bc27196713d4 (patch) | |
tree | 85ecef954046562925e39c759ea0dee6198a328d | |
parent | fb067264cb662a07717963efddc8948c32dd5405 (diff) |
Move EncoderDecoderState to its own header
-rw-r--r-- | src/amun/cpu/decoder/encoder_decoder.cpp | 12 | ||||
-rw-r--r-- | src/amun/cpu/decoder/encoder_decoder.h | 59 | ||||
-rw-r--r-- | src/amun/cpu/dl4mt/encoder_decoder_state.h | 30 |
3 files changed, 69 insertions, 32 deletions
diff --git a/src/amun/cpu/decoder/encoder_decoder.cpp b/src/amun/cpu/decoder/encoder_decoder.cpp index d1714f59..df808dc5 100644 --- a/src/amun/cpu/decoder/encoder_decoder.cpp +++ b/src/amun/cpu/decoder/encoder_decoder.cpp @@ -128,16 +128,6 @@ void EncoderDecoder::Filter(const std::vector<size_t>& filterIds) { } -dl4mt::Encoder& EncoderDecoder::GetEncoder() { - return *encoder_; -} - - -dl4mt::Decoder& EncoderDecoder::GetDecoder() { - return *decoder_; -} - - BaseMatrix& EncoderDecoder::GetProbs() { return decoder_->GetProbs(); } @@ -154,7 +144,7 @@ void EncoderDecoderLoader::Load(const God&) { weights_.emplace_back(new dl4mt::Weights(path, 0)); } -ScorerPtr EncoderDecoderLoader::NewScorer(const God&, const DeviceInfo& deviceInfo) const { +ScorerPtr EncoderDecoderLoader::NewScorer(const God&, const DeviceInfo&) const { size_t tab = Has("tab") ? Get<size_t>("tab") : 0; return ScorerPtr(new EncoderDecoder(name_, config_, tab, *weights_[0])); diff --git a/src/amun/cpu/decoder/encoder_decoder.h b/src/amun/cpu/decoder/encoder_decoder.h index 99062e10..0e722dd3 100644 --- a/src/amun/cpu/decoder/encoder_decoder.h +++ b/src/amun/cpu/decoder/encoder_decoder.h @@ -9,8 +9,11 @@ #include "common/scorer.h" #include "cpu/dl4mt/dl4mt.h" +#include "cpu/nematus/encoder.h" +#include "cpu/nematus/decoder.h" #include "cpu/mblas/matrix.h" +#include "cpu/decoder/encoder_decoder_state.h" namespace amunmt { @@ -23,36 +26,54 @@ class Encoder; class Decoder; } -class EncoderDecoderState : public State { +class EncoderDecoder : public Scorer { + private: + using EDState = EncoderDecoderState; + public: - EncoderDecoderState(); - EncoderDecoderState(const EncoderDecoderState&) = delete; + EncoderDecoder(const std::string& name, + const YAML::Node& config, + size_t tab, + const dl4mt::Weights& model); + + virtual void Decode(const State& in, State& out, const std::vector<size_t>& beamSizes); - virtual std::string Debug() const; + virtual State* NewState() const; + + virtual void BeginSentenceState(State& state, size_t batchSize); + + virtual void SetSource(const Sentences& sources); + + virtual void AssembleBeamState(const State& in, + const Beam& beam, + State& out); - CPU::mblas::Matrix& GetStates(); + void GetAttention(mblas::Matrix& Attention); + mblas::Matrix& GetAttention(); - CPU::mblas::Matrix& GetEmbeddings(); + size_t GetVocabSize() const; - const CPU::mblas::Matrix& GetStates() const; + BaseMatrix& GetProbs(); - const CPU::mblas::Matrix& GetEmbeddings() const; + void Filter(const std::vector<size_t>& filterIds); private: - CPU::mblas::Matrix states_; - CPU::mblas::Matrix embeddings_; -}; + const dl4mt::Weights& model_; + std::unique_ptr<dl4mt::Encoder> encoder_; + std::unique_ptr<dl4mt::Decoder> decoder_; + mblas::Matrix SourceContext_; +}; -class EncoderDecoder : public Scorer { +class NematusEncoderDecoder : public Scorer { private: using EDState = EncoderDecoderState; public: - EncoderDecoder(const std::string& name, + NematusEncoderDecoder(const std::string& name, const YAML::Node& config, size_t tab, - const dl4mt::Weights& model); + const Nematus::Weights& model); virtual void Decode(const State& in, State& out, const std::vector<size_t>& beamSizes); @@ -75,14 +96,10 @@ class EncoderDecoder : public Scorer { void Filter(const std::vector<size_t>& filterIds); - dl4mt::Encoder& GetEncoder(); - - dl4mt::Decoder& GetDecoder(); - private: - const dl4mt::Weights& model_; - std::unique_ptr<dl4mt::Encoder> encoder_; - std::unique_ptr<dl4mt::Decoder> decoder_; + const Nematus::Weights& model_; + std::unique_ptr<Nematus::Encoder> encoder_; + std::unique_ptr<Nematus::Decoder> decoder_; mblas::Matrix SourceContext_; }; diff --git a/src/amun/cpu/dl4mt/encoder_decoder_state.h b/src/amun/cpu/dl4mt/encoder_decoder_state.h new file mode 100644 index 00000000..ae4fbe65 --- /dev/null +++ b/src/amun/cpu/dl4mt/encoder_decoder_state.h @@ -0,0 +1,30 @@ +#pragma once + +#include <vector> + +#include "cpu/mblas/matrix.h" +#include "common/scorer.h" + +namespace amunmt { +namespace CPU { + +class EncoderDecoderState : public State { + public: + EncoderDecoderState(); + EncoderDecoderState(const EncoderDecoderState&) = delete; + + virtual std::string Debug() const; + + CPU::mblas::Matrix& GetStates(); + const CPU::mblas::Matrix& GetStates() const; + + CPU::mblas::Matrix& GetEmbeddings(); + const CPU::mblas::Matrix& GetEmbeddings() const; + + private: + CPU::mblas::Matrix states_; + CPU::mblas::Matrix embeddings_; +}; + +} // namespace CPU +} // namespace amunmt |