Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt18
-rw-r--r--scripts/idf.py9
-rw-r--r--src/CMakeLists.txt2
-rw-r--r--src/common/exception.cpp10
-rw-r--r--src/common/exception.h6
-rw-r--r--src/common/file_stream.h42
-rw-r--r--src/common/vocab.h8
-rw-r--r--src/decoder/ape_penalty.h59
-rw-r--r--src/decoder/config.cpp3
-rw-r--r--src/decoder/encoder_decoder.h12
-rw-r--r--src/decoder/god.cu3
-rw-r--r--src/decoder/god.h3
-rw-r--r--src/decoder/scorer.h11
13 files changed, 136 insertions, 50 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8b939813..9cd42916 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -10,7 +10,7 @@ SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
include_directories(${amunn_SOURCE_DIR})
find_package(CUDA REQUIRED)
-find_package(Boost COMPONENTS system filesystem program_options timer)
+find_package(Boost COMPONENTS system filesystem program_options timer iostreams)
if(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS})
set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})
@@ -24,15 +24,15 @@ if (YAMLCPP_FOUND)
set(EXT_LIBS ${EXT_LIBS} ${YAMLCPP_LIBRARY})
endif (YAMLCPP_FOUND)
-set(KENLM CACHE STRING "Path to compiled kenlm directory")
-if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a")
- message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a")
-endif()
+#set(KENLM CACHE STRING "Path to compiled kenlm directory")
+#if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a")
+# message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a")
+#endif()
-set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm.a)
-set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm_util.a)
-include_directories(${KENLM})
-add_definitions(-DKENLM_MAX_ORDER=6)
+#set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm.a)
+#set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm_util.a)
+#include_directories(${KENLM})
+#add_definitions(-DKENLM_MAX_ORDER=6)
find_package (BZip2)
if (BZIP2_FOUND)
diff --git a/scripts/idf.py b/scripts/idf.py
index efd05e9a..60905d80 100644
--- a/scripts/idf.py
+++ b/scripts/idf.py
@@ -1,5 +1,6 @@
import sys
import math
+import yaml
from collections import Counter
c = Counter()
@@ -10,7 +11,9 @@ for line in sys.stdin:
c[word] += 1
N += 1
-keys = sorted([k for k in c])
-for word in keys:
+out = dict()
+for word in c:
idf = math.log(float(N) / float(c[word])) / math.log(N)
- print word, ":", idf
+ out[word] = idf
+
+yaml.safe_dump(out, sys.stdout, default_flow_style=False, allow_unicode=True) \ No newline at end of file
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index a42b9028..d2deb379 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -17,7 +17,7 @@ add_library(librescorer OBJECT
add_library(libamunn OBJECT
decoder/config.cpp
- decoder/kenlm.cpp
+# decoder/kenlm.cpp
)
cuda_add_executable(
diff --git a/src/common/exception.cpp b/src/common/exception.cpp
index 01ff9a67..453fcf66 100644
--- a/src/common/exception.cpp
+++ b/src/common/exception.cpp
@@ -1,4 +1,4 @@
-#include "util/exception.hh"
+#include "exception.h"
#ifdef __GXX_RTTI
#include <typeinfo>
@@ -17,14 +17,18 @@ namespace util {
Exception::Exception() throw() {}
Exception::~Exception() throw() {}
+Exception::Exception(const Exception& o) throw() {
+ what_.str(o.what_.str());
+}
+
void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) {
/* The child class might have set some text, but we want this to come first.
* Another option would be passing this information to the constructor, but
* then child classes would have to accept constructor arguments and pass
* them down.
*/
- std::string old_text;
- what_.swap(old_text);
+ std::string old_text = what_.str();
+ what_.str(std::string());
what_ << file << ':' << line;
if (func) what_ << " in " << func << " threw ";
if (child_name) {
diff --git a/src/common/exception.h b/src/common/exception.h
index 0bfd73a6..85827d8c 100644
--- a/src/common/exception.h
+++ b/src/common/exception.h
@@ -1,7 +1,6 @@
#pragma once
-#include "util/string_stream.hh"
-
+#include <sstream>
#include <exception>
#include <limits>
#include <string>
@@ -15,6 +14,7 @@ class Exception : public std::exception {
public:
Exception() throw();
virtual ~Exception() throw();
+ Exception(const Exception& o) throw();
const char *what() const throw() { return what_.str().c_str(); }
@@ -34,7 +34,7 @@ class Exception : public std::exception {
typedef T Identity;
};
- StringStream what_;
+ std::stringstream what_;
};
/* This implements the normal operator<< for Exception and all its children.
diff --git a/src/common/file_stream.h b/src/common/file_stream.h
new file mode 100644
index 00000000..bff84355
--- /dev/null
+++ b/src/common/file_stream.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include <boost/filesystem.hpp>
+#include <boost/filesystem/fstream.hpp>
+#include <boost/iostreams/filtering_stream.hpp>
+#include <boost/iostreams/filter/gzip.hpp>
+#include <iostream>
+
+#include "exception.h"
+
+class InputFileStream {
+ public:
+ InputFileStream(const std::string& file)
+ : file_(file), ifstream_(file_)
+ {
+ UTIL_THROW_IF2(!boost::filesystem::exists(file_),
+ "File " << file << " does not exist");
+
+ if(file_.extension() == ".gz")
+ istream_.push(boost::iostreams::gzip_decompressor());
+ istream_.push(ifstream_);
+ }
+
+ operator std::istream& () {
+ return istream_;
+ }
+
+ operator bool () {
+ return istream_;
+ }
+
+ template <typename T>
+ friend InputFileStream& operator>>(InputFileStream& stream, T& t) {
+ stream.istream_ >> t;
+ return stream;
+ }
+
+ private:
+ boost::filesystem::path file_;
+ boost::filesystem::ifstream ifstream_;
+ boost::iostreams::filtering_istream istream_;
+};
diff --git a/src/common/vocab.h b/src/common/vocab.h
index 5994f2f9..04b014a0 100644
--- a/src/common/vocab.h
+++ b/src/common/vocab.h
@@ -3,19 +3,19 @@
#include <map>
#include <string>
#include <vector>
-#include <fstream>
#include <sstream>
#include "types.h"
#include "utils.h"
+#include "file_stream.h"
class Vocab {
public:
- Vocab(const std::string& txt) {
- std::ifstream in(txt.c_str());
+ Vocab(const std::string& path) {
+ InputFileStream in(path);
size_t c = 0;
std::string line;
- while(std::getline(in, line)) {
+ while(std::getline((std::istream&)in, line)) {
str2id_[line] = c++;
id2str_.push_back(line);
}
diff --git a/src/decoder/ape_penalty.h b/src/decoder/ape_penalty.h
index c64a31c9..8723d1a3 100644
--- a/src/decoder/ape_penalty.h
+++ b/src/decoder/ape_penalty.h
@@ -3,36 +3,48 @@
#include <vector>
#include "types.h"
+#include "file_stream.h"
#include "scorer.h"
#include "matrix.h"
+typedef std::vector<Word> SrcTrgMap;
+typedef std::vector<float> Penalties;
+
class ApePenaltyState : public State {
// Dummy, this scorer is stateless
};
class ApePenalty : public Scorer {
+ private:
+ const SrcTrgMap& srcTrgMap_;
+ const Penalties& penalties_;
public:
- ApePenalty(size_t sourceIndex)
- : Scorer(sourceIndex)
+ ApePenalty(
+ const SrcTrgMap& srcTrgMap,
+ const Penalties& penalties,
+ const YAML::Node& config,
+ size_t tab)
+ : Scorer(config, tab), srcTrgMap_(srcTrgMap),
+ penalties_(penalties)
{ }
// @TODO: make this work on GPU
virtual void SetSource(const Sentence& source) {
- const Words& words = source.GetWords(sourceIndex_);
- const Vocab& svcb = God::GetSourceVocab(sourceIndex_);
- const Vocab& tvcb = God::GetTargetVocab();
+ const Words& words = source.GetWords(tab_);
costs_.clear();
- costs_.resize(tvcb.size(), -1.0);
- for(auto& s : words) {
- const std::string& sstr = svcb[s];
- Word t = tvcb[sstr];
+ costs_.resize(penalties_.size());
+ algo::copy(penalties_.begin(), penalties_.end(), costs_.begin());
+
+ for(auto&& s : words) {
+ Word t = srcTrgMap_[s];
if(t != UNK && t < costs_.size())
costs_[t] = 0.0;
}
}
+ // @TODO: make this work on GPU
virtual void Score(const State& in,
Prob& prob,
State& out) {
@@ -65,11 +77,34 @@ class ApePenaltyLoader : public Loader {
: Loader(config) {}
virtual void Load() {
- // @TODO: IDF weights
+ size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
+ const Vocab& svcb = God::GetSourceVocab(tab);
+ const Vocab& tvcb = God::GetTargetVocab();
+
+ srcTrgMap_.resize(svcb.size(), UNK);
+ for(Word s = 0; s < svcb.size(); ++s)
+ srcTrgMap_[s] = tvcb[svcb[s]];
+
+ penalties_.resize(tvcb.size(), -1.0);
+
+ if(Has("path")) {
+ LOG(info) << "Loading APE penalties from " << Get<std::string>("path");
+ YAML::Node penalties = YAML::Load(InputFileStream(Get<std::string>("path")));
+ for(auto&& pair : penalties) {
+ std::string entry = pair.first.as<std::string>();
+ float penalty = pair.second.as<float>();
+ penalties_[tvcb[entry]] = -penalty;
+ }
+ }
}
virtual ScorerPtr NewScorer(size_t taskId) {
size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
- return ScorerPtr(new ApePenalty(tab));
+ return ScorerPtr(new ApePenalty(srcTrgMap_, penalties_,
+ config_, tab));
}
-}; \ No newline at end of file
+
+ private:
+ SrcTrgMap srcTrgMap_;
+ Penalties penalties_;
+};
diff --git a/src/decoder/config.cpp b/src/decoder/config.cpp
index 6fdebe5a..1278dba7 100644
--- a/src/decoder/config.cpp
+++ b/src/decoder/config.cpp
@@ -1,6 +1,7 @@
#include <set>
#include "config.h"
+#include "file_stream.h"
#include "exception.h"
#define SET_OPTION(key, type) \
@@ -79,7 +80,7 @@ void Config::AddOptions(size_t argc, char** argv) {
exit(0);
}
- config_ = YAML::LoadFile(configPath);
+ config_ = YAML::Load(InputFileStream(configPath));
SET_OPTION("n-best", bool)
SET_OPTION("normalize", bool)
diff --git a/src/decoder/encoder_decoder.h b/src/decoder/encoder_decoder.h
index 92d4985f..1dac4e4d 100644
--- a/src/decoder/encoder_decoder.h
+++ b/src/decoder/encoder_decoder.h
@@ -38,8 +38,10 @@ class EncoderDecoder : public Scorer {
typedef EncoderDecoderState EDState;
public:
- EncoderDecoder(const Weights& model, size_t tabIndex)
- : Scorer(tabIndex), model_(model),
+ EncoderDecoder(const Weights& model,
+ const YAML::Node& config,
+ size_t tab)
+ : Scorer(config, tab), model_(model),
encoder_(new Encoder(model_)), decoder_(new Decoder(model_))
{}
@@ -65,7 +67,7 @@ class EncoderDecoder : public Scorer {
}
virtual void SetSource(const Sentence& source) {
- encoder_->GetContext(source.GetWords(sourceIndex_),
+ encoder_->GetContext(source.GetWords(tab_),
SourceContext_);
}
@@ -107,7 +109,7 @@ class EncoderDecoderLoader : public Loader {
public:
EncoderDecoderLoader(const YAML::Node& config)
: Loader(config) {}
-
+
virtual void Load() {
std::string path = Get<std::string>("path");
auto devices = God::Get<std::vector<size_t>>("devices");
@@ -126,7 +128,7 @@ class EncoderDecoderLoader : public Loader {
size_t d = weights_[i]->GetDevice();
cudaSetDevice(d);
size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
- return ScorerPtr(new EncoderDecoder(*weights_[i], tab));
+ return ScorerPtr(new EncoderDecoder(*weights_[i], config_, tab));
}
private:
diff --git a/src/decoder/god.cu b/src/decoder/god.cu
index eb10dce3..d2138459 100644
--- a/src/decoder/god.cu
+++ b/src/decoder/god.cu
@@ -7,6 +7,7 @@
#include "config.h"
#include "scorer.h"
#include "threadpool.h"
+#include "file_stream.h"
#include "loader_factory.h"
God God::instance_;
@@ -76,7 +77,7 @@ void God::CleanUp() {
void God::LoadWeights(const std::string& path) {
LOG(info) << "Reading weights from " << path;
- std::ifstream fweights(path.c_str());
+ InputFileStream fweights(path);
std::string name;
float weight;
size_t i = 0;
diff --git a/src/decoder/god.h b/src/decoder/god.h
index 6110f2a7..70c15c9a 100644
--- a/src/decoder/god.h
+++ b/src/decoder/god.h
@@ -7,9 +7,6 @@
#include "loader.h"
#include "scorer.h"
#include "logging.h"
-
-// this should not be here
-#include "kenlm.h"
class Weights;
diff --git a/src/decoder/scorer.h b/src/decoder/scorer.h
index ef6671a2..ea139938 100644
--- a/src/decoder/scorer.h
+++ b/src/decoder/scorer.h
@@ -28,8 +28,8 @@ typedef std::vector<Prob> Probs;
class Scorer {
public:
- Scorer() : sourceIndex_(0) {}
- Scorer(size_t sourceIndex) : sourceIndex_(sourceIndex) {}
+ Scorer(const YAML::Node& config, size_t tab)
+ : config_(config), tab_(tab) {}
virtual ~Scorer() {}
@@ -52,13 +52,14 @@ class Scorer {
virtual void CleanUpAfterSentence() {}
protected:
- size_t sourceIndex_;
+ const YAML::Node& config_;
+ size_t tab_;
};
class SourceIndependentScorer : public Scorer {
public:
- SourceIndependentScorer() : Scorer(0) {}
- SourceIndependentScorer(size_t) : Scorer(0) {}
+ SourceIndependentScorer(const YAML::Node& config, size_t)
+ : Scorer(config, 0) {}
virtual ~SourceIndependentScorer() {}