Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-04-27 13:47:47 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-04-27 13:47:47 +0300
commit2bf3026097033773f14368e5219346ffc812c37f (patch)
tree6ae1181df774086f375193362c453056443f9e27
parentc2059610f5ae37d7c2c943a4762776579195c40e (diff)
APE penalty
-rw-r--r--CMakeLists.txt4
-rw-r--r--src/decoder/ape_penalty.h59
-rw-r--r--src/decoder/god.cu11
3 files changed, 72 insertions, 2 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6009bcc4..1effd778 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(amunn CXX)
-SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
-LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math;)
+SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated")
+LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math;)
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
diff --git a/src/decoder/ape_penalty.h b/src/decoder/ape_penalty.h
new file mode 100644
index 00000000..adbdd1eb
--- /dev/null
+++ b/src/decoder/ape_penalty.h
@@ -0,0 +1,59 @@
+#pragma once
+
+#include <vector>
+
+#include "types.h"
+#include "scorer.h"
+#include "matrix.h"
+
+class ApePenaltyState : public State {
+ // Dummy
+};
+
+class ApePenalty : public Scorer {
+
+ public:
+ ApePenalty(size_t sourceIndex)
+ : Scorer(sourceIndex)
+ { }
+
+ virtual void SetSource(const Sentence& source) {
+ const Words& words = source.GetWords(sourceIndex_);
+ const Vocab& svcb = God::GetSourceVocab(sourceIndex_);
+ const Vocab& tvcb = God::GetTargetVocab();
+
+ costs_.clear();
+ costs_.resize(tvcb.size(), -1.0);
+ for(auto& s : words) {
+ const std::string& sstr = svcb[s];
+ Word t = tvcb[sstr];
+ if(t != UNK && t < costs_.size())
+ costs_[t] = 0.0;
+ }
+ }
+
+ virtual void Score(const State& in,
+ Prob& prob,
+ State& out) {
+ size_t cols = prob.Cols();
+ for(size_t i = 0; i < prob.Rows(); ++i)
+ algo::copy(costs_.begin(), costs_.begin() + cols, prob.begin() + i * cols);
+ }
+
+ virtual State* NewState() {
+ return new ApePenaltyState();
+ }
+
+ virtual void BeginSentenceState(State& state) { }
+
+ virtual void AssembleBeamState(const State& in,
+ const Beam& beam,
+ State& out) { }
+
+ virtual size_t GetVocabSize() const {
+ return 0;
+ }
+
+ private:
+ std::vector<float> costs_;
+};
diff --git a/src/decoder/god.cu b/src/decoder/god.cu
index eb168d4c..28b49e0e 100644
--- a/src/decoder/god.cu
+++ b/src/decoder/god.cu
@@ -6,6 +6,7 @@
#include "threadpool.h"
#include "encoder_decoder.h"
#include "language_model.h"
+#include "ape_penalty.h"
God God::instance_;
@@ -45,6 +46,8 @@ God& God::NonStaticInit(int argc, char** argv) {
"Path to source vocabulary file.")
("target,t", po::value(&targetVocabPath)->required(),
"Path to target vocabulary file.")
+ ("ape", po::value<bool>()->zero_tokens()->default_value(false),
+ "Add APE-penalty")
("lm,l", po::value(&lmPaths)->multitoken(),
"Path to KenLM language model(s)")
("tab-map", po::value(&tabMap_)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
@@ -124,11 +127,17 @@ God& God::NonStaticInit(int argc, char** argv) {
tabMap_.resize(modelPaths.size(), 0);
}
+ // @TODO: handle this better!
if(weights_.size() < modelPaths.size()) {
// this should be a warning
LOG(info) << "More neural models than weights, setting weights to 1.0";
weights_.resize(modelPaths.size(), 1.0);
}
+
+ if(Get<bool>("ape") && weights_.size() < modelPaths.size() + 1) {
+ LOG(info) << "Adding weight for APE-penalty: " << 1.0;
+ weights_.resize(modelPaths.size(), 1.0);
+ }
if(weights_.size() < modelPaths.size() + lmPaths.size()) {
// this should be a warning
@@ -186,6 +195,8 @@ std::vector<ScorerPtr> God::GetScorers(size_t threadId) {
size_t i = 0;
for(auto& m : Summon().modelsPerDevice_[deviceId])
scorers.emplace_back(new EncoderDecoder(*m, Summon().tabMap_[i++]));
+ if(God::Get<bool>("ape"))
+ scorers.emplace_back(new ApePenalty(Summon().tabMap_[i++]));
for(auto& lm : Summon().lms_)
scorers.emplace_back(new LanguageModel(lm));
return scorers;