Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-03-04 07:35:00 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-03-04 07:35:00 +0300
commitb88c3fcb71ad06326f60cc930fd283d77cd41a6c (patch)
treedd59a1cceb678e3c077beba3cf84b00d32e529db /src
parent42406cc715a301db2c3087d0def6ff70332d1074 (diff)
costs.cpp
Diffstat (limited to 'src')
-rw-r--r--src/CMakeLists.txt1
-rw-r--r--src/common/timer.cpp0
-rw-r--r--src/models/costs.cpp16
-rw-r--r--src/models/costs.h7
4 files changed, 18 insertions, 6 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index a3155d68..c59d8bf6 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -87,6 +87,7 @@ set(MARIAN_SOURCES
models/model_factory.cpp
models/encoder_decoder.cpp
models/transformer_stub.cpp
+ models/costs.cpp
rescorer/score_collector.cpp
embedder/vector_collector.cpp
diff --git a/src/common/timer.cpp b/src/common/timer.cpp
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/common/timer.cpp
diff --git a/src/models/costs.cpp b/src/models/costs.cpp
new file mode 100644
index 00000000..5105f590
--- /dev/null
+++ b/src/models/costs.cpp
@@ -0,0 +1,16 @@
+#include "costs.h"
+
+namespace marian {
+namespace models {
+
+Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
+// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
+state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
+// @TODO: This is becoming more and more opaque ^^. Can we simplify this?
+return state;
+}
+
+
+}
+}
+
diff --git a/src/models/costs.h b/src/models/costs.h
index 3d8f2c51..2d34c53a 100644
--- a/src/models/costs.h
+++ b/src/models/costs.h
@@ -282,12 +282,7 @@ public:
class LogSoftmaxStep : public ILogProbStep {
public:
virtual ~LogSoftmaxStep() {}
- virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
- // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
- state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
- // @TODO: This is becoming more and more opaque ^^. Can we simplify this?
- return state;
- }
+ virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// Gumbel-max noising for sampling during beam-search