diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-03-04 07:35:00 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-03-04 07:35:00 +0300 |
commit | b88c3fcb71ad06326f60cc930fd283d77cd41a6c (patch) | |
tree | dd59a1cceb678e3c077beba3cf84b00d32e529db /src | |
parent | 42406cc715a301db2c3087d0def6ff70332d1074 (diff) |
costs.cpp
Diffstat (limited to 'src')
-rw-r--r-- | src/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/common/timer.cpp | 0 | ||||
-rw-r--r-- | src/models/costs.cpp | 16 | ||||
-rw-r--r-- | src/models/costs.h | 7 |
4 files changed, 18 insertions, 6 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a3155d68..c59d8bf6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -87,6 +87,7 @@ set(MARIAN_SOURCES models/model_factory.cpp models/encoder_decoder.cpp models/transformer_stub.cpp + models/costs.cpp rescorer/score_collector.cpp embedder/vector_collector.cpp diff --git a/src/common/timer.cpp b/src/common/timer.cpp new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/common/timer.cpp diff --git a/src/models/costs.cpp b/src/models/costs.cpp new file mode 100644 index 00000000..5105f590 --- /dev/null +++ b/src/models/costs.cpp @@ -0,0 +1,16 @@ +#include "costs.h" + +namespace marian { +namespace models { + +Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) { +// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) +state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); +// @TODO: This is becoming more and more opaque ^^. Can we simplify this? +return state; +} + + +} +} + diff --git a/src/models/costs.h b/src/models/costs.h index 3d8f2c51..2d34c53a 100644 --- a/src/models/costs.h +++ b/src/models/costs.h @@ -282,12 +282,7 @@ public: class LogSoftmaxStep : public ILogProbStep { public: virtual ~LogSoftmaxStep() {} - virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override { - // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) - state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); - // @TODO: This is becoming more and more opaque ^^. Can we simplify this? - return state; - } + virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override; }; // Gumbel-max noising for sampling during beam-search |