github.com/marian-nmt/marian.git
author     Tomasz Dwojak <t.dwojak@amu.edu.pl>  2017-11-23 14:26:38 +0300
committer  Tomasz Dwojak <t.dwojak@amu.edu.pl>  2017-11-23 14:26:38 +0300
commit     537fccc3e81831fb0aa8baaaaef2858c96c346ec (patch)
tree       57a48f2e1ff7894289587255f503f3ab7744203f
parent     15947c6061c009679b0a00ca873d768c1159cd6f (diff)
parent     9023667939b0fdd645f971cdeb0ab4e764b07057 (diff)
Merge branch 'master' of https://github.com/marian-nmt/marian-dev into charS2S
-rw-r--r--  CHANGELOG.md                        10
-rw-r--r--  CMakeLists.txt                       2
-rw-r--r--  README.md                            1
-rw-r--r--  VERSION                              2
-rw-r--r--  src/common/config_parser.cpp         8
-rw-r--r--  src/common/shape.h                  16
-rw-r--r--  src/graph/expression_operators.cu   20
-rw-r--r--  src/graph/expression_operators.h     3
-rw-r--r--  src/graph/node_operators_unary.h    95
-rw-r--r--  src/kernels/tensor_operators.cu     29
-rw-r--r--  src/kernels/tensor_operators.h       4
-rw-r--r--  src/models/encdec.h                 32
-rw-r--r--  src/models/hardatt.h                 9
-rw-r--r--  src/models/s2s.h                     2
-rw-r--r--  src/models/states.h                  4
-rw-r--r--  src/models/transformer.h            39
-rw-r--r--  src/rnn/attention.cu                 9
-rw-r--r--  src/rnn/attention.h                  4
-rw-r--r--  src/rnn/types.h                     32
-rw-r--r--  src/tensors/allocator.h              1
-rw-r--r--  src/tests/CMakeLists.txt            22
-rw-r--r--  src/tests/README.md                 14
-rw-r--r--  src/tests/attention_tests.cpp        2
-rw-r--r--  src/tests/rnn_tests.cpp              2
-rw-r--r--  src/training/validator.h            25
-rw-r--r--  src/translator/beam_search.h       178
-rw-r--r--  src/translator/history.h             6
-rw-r--r--  src/translator/scorers.h            14
-rw-r--r--  src/translator/translator.h         42
29 files changed, 422 insertions(+), 205 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4997b61d..113a35de 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,9 +6,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
-- Added support for CUBLAS_TENSOR_OP_MATH mode for cublas in cuda 9.0
+
+## [1.1.0] - 2017-11-21
### Added
+- Batched translation for all model types, significant translation speed-up
+- Batched translation during validation with translation
+- `--maxi-batch-sort` option for `marian-decoder`
+- Support for CUBLAS_TENSOR_OP_MATH mode for cublas in cuda 9.0
+- The "marian-vocab" tool to create vocabularies
## [1.0.0] - 2017-11-13
@@ -29,7 +35,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Synchronous SGD training for multi-gpu (enable with `--sync-sgd`)
- Dynamic construction of complex models with different encoders and decoders,
currently only available through the C++ API
-- Option --quiet to suppress output to stderr
+- Option `--quiet` to suppress output to stderr
- Option to choose different variants of optimization criterion: mean
cross-entropy, perplexity, cross-entropy sum
- In-process translation for validation, uses the same memory as training
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9994b690..6c8f54de 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -27,7 +27,7 @@ set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
# Find packages
find_package(CUDA "8.0" REQUIRED)
if(CUDA_FOUND)
- set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY})
+ set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} tcmalloc_minimal)
endif(CUDA_FOUND)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
diff --git a/README.md b/README.md
index 8fefa775..44685be9 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@ Marian
[![Join the chat at https://gitter.im/marian-nmt](https://badges.gitter.im/amunmt/marian.svg)](https://gitter.im/marian-nmt?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Build Status](http://vali.inf.ed.ac.uk/jenkins/buildStatus/icon?job=marian-dev)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev/)
[![Tests Status](http://vali.inf.ed.ac.uk/jenkins/buildStatus/icon?job=marian-regression-tests)](http://vali.inf.ed.ac.uk/jenkins/job/marian-regression-tests/)
+[![Twitter](https://img.shields.io/twitter/follow/marian_nmt.svg?style=social&label=Follow)](https://twitter.com/intent/follow?screen_name=marian_nmt)
**Marian** is a C++ GPU-specific parallel automatic differentiation library
with operator overloading. It is the training framework used in the Marian
diff --git a/VERSION b/VERSION
index 0ec25f75..795460fc 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v1.0.0
+v1.1.0
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index c4812bb2..70ee4db7 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -163,6 +163,10 @@ void ConfigParser::validateOptions() const {
!modelDir.empty() && !boost::filesystem::is_directory(modelDir),
"Model directory does not exist");
+ UTIL_THROW_IF2(!(boost::filesystem::status(modelDir).permissions()
+ & boost::filesystem::owner_write),
+ "No write permission in model directory");
+
UTIL_THROW_IF2(
has("valid-sets")
&& get<std::vector<std::string>>("valid-sets").size()
@@ -518,6 +522,8 @@ void ConfigParser::addOptionsTranslate(po::options_description& desc) {
"Size of mini-batch used during update")
("maxi-batch", po::value<int>()->default_value(1),
"Number of batches to preload for length-based sorting")
+ ("maxi-batch-sort", po::value<std::string>()->default_value("none"),
+ "Sorting strategy for maxi-batch: none (default) src")
("n-best", po::value<bool>()->zero_tokens()->default_value(false),
"Display n-best list")
//("lexical-table", po::value<std::string>(),
@@ -802,7 +808,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
SET_OPTION("mini-batch", int);
SET_OPTION("maxi-batch", int);
- if(mode_ == ConfigMode::training)
+ if(mode_ == ConfigMode::training || mode_ == ConfigMode::translating)
SET_OPTION("maxi-batch-sort", std::string);
SET_OPTION("max-length", size_t);
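
The new `--maxi-batch-sort` option lets marian-decoder sort each preloaded maxi-batch by source-sentence length before cutting it into mini-batches, so the batched beam search below pads less. A minimal sketch of the idea, assuming illustrative names (`Sentence`, `sortMaxiBatch`) rather than Marian's actual data structures:

    // Sketch: sort a maxi-batch by source length; "none" keeps corpus order.
    #include <algorithm>
    #include <string>
    #include <vector>

    struct Sentence {
      size_t id;                       // original line number, used to restore order
      std::vector<std::string> words;  // source tokens
    };

    void sortMaxiBatch(std::vector<Sentence>& maxiBatch, const std::string& strategy) {
      if(strategy == "src")
        std::stable_sort(maxiBatch.begin(), maxiBatch.end(),
                         [](const Sentence& a, const Sentence& b) {
                           return a.words.size() < b.words.size();
                         });
    }

Output order is restored afterwards through the stored sentence ids, which is why the collectors further down key their writes on `history->GetLineNum()`.
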
diff --git a/src/common/shape.h b/src/common/shape.h
index 9cd08f5e..44e0da9f 100644
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -128,6 +128,12 @@ struct Shape {
return strm;
}
+ operator std::string() const {
+ std::stringstream ss;
+ ss << *this;
+ return ss.str();
+ }
+
int axis(int ax) {
if(ax < 0)
return size() + ax;
@@ -147,7 +153,9 @@ struct Shape {
for(auto& s : shapes) {
for(int i = 0; i < s.size(); ++i) {
ABORT_IF(shape[-i] != s[-i] && shape[-i] != 1 && s[-i] != 1,
- "Shapes cannot be broadcasted");
+ "Shapes {} and {} cannot be broadcasted",
+ (std::string)shape,
+ (std::string)s);
shape.set(-i, std::max(shape[-i], s[-i]));
}
}
@@ -170,10 +178,12 @@ struct Shape {
shape.resize(maxDims);
for(auto& node : nodes) {
- Shape shapen = node->shape();
+ const Shape& shapen = node->shape();
for(int i = 1; i <= shapen.size(); ++i) {
ABORT_IF(shape[-i] != shapen[-i] && shape[-i] != 1 && shapen[-i] != 1,
- "Shapes cannot be broadcasted");
+ "Shapes {} and {} cannot be broadcasted",
+ (std::string)shape,
+ (std::string)shapen);
shape.set(-i, std::max(shape[-i], shapen[-i]));
}
}
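
The new `operator std::string()` exists so these broadcast checks can name the offending shapes instead of failing anonymously. The rule they enforce is trailing-axis broadcasting: two dimensions are compatible if they are equal or one of them is 1. A self-contained sketch of that rule, independent of the `Shape` class:

    #include <algorithm>
    #include <stdexcept>
    #include <vector>

    // Broadcast two shapes, aligned at the trailing axis.
    std::vector<int> broadcast(std::vector<int> a, const std::vector<int>& b) {
      while(a.size() < b.size())
        a.insert(a.begin(), 1);                  // pad leading singleton axes
      for(size_t i = 0; i < b.size(); ++i) {
        int& x = a[a.size() - 1 - i];
        int y = b[b.size() - 1 - i];
        if(x != y && x != 1 && y != 1)
          throw std::runtime_error("Shapes cannot be broadcasted");
        x = std::max(x, y);
      }
      return a;                                  // {8,1,512} + {4,512} -> {8,4,512}
    }
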
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 4c4e0feb..d657ba74 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -121,6 +121,13 @@ Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax) {
return Expression<ConcatenateNodeOp>(concats, ax);
}
+Expr repeat(Expr a, size_t repeats, keywords::axis_k ax) {
+ if(repeats == 1)
+ return a;
+ return concatenate(std::vector<Expr>(repeats, a), ax);
+}
+
+
Expr reshape(Expr a, Shape shape) {
return Expression<ReshapeNodeOp>(a, shape);
}
@@ -137,6 +144,10 @@ Expr atleast_3d(Expr a, size_t dims) {
return atleast_nd(a, 3);
}
+Expr atleast_4d(Expr a) {
+ return atleast_nd(a, 4);
+}
+
Expr atleast_nd(Expr a, size_t dims) {
if(a->shape().size() >= dims)
return a;
@@ -154,6 +165,15 @@ Expr flatten(Expr a) {
return Expression<ReshapeNodeOp>(a, shape);
}
+Expr flatten_2d(Expr a) {
+ Shape shape = {
+ a->shape().elements() / a->shape()[-1],
+ a->shape()[-1]
+ };
+
+ return Expression<ReshapeNodeOp>(a, shape);
+}
+
Expr rows(Expr a, const std::vector<size_t>& indices) {
return Expression<RowsNodeOp>(a, indices);
}
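
The helpers added here are thin shape utilities: `repeat` tiles an expression along an axis by concatenating it with itself, `atleast_4d` pads leading singleton axes, and `flatten_2d` collapses everything but the last axis into rows. A small standalone check of the shape arithmetic (`Shape` here is just a vector of ints, not Marian's class):

    #include <cassert>
    #include <functional>
    #include <numeric>
    #include <vector>

    using Shape = std::vector<int>;

    Shape flatten2d(const Shape& s) {            // {..., d} -> {rows, d}
      int d = s.back();
      int n = std::accumulate(s.begin(), s.end(), 1, std::multiplies<int>());
      return {n / d, d};
    }

    int main() {
      // repeat(a, 3, axis = -4): {1, t, b, d} -> {3, t, b, d} via concatenate
      // atleast_4d(a):           {t, b, d}    -> {1, t, b, d}
      assert((flatten2d({6, 8, 4, 512}) == Shape{6 * 8 * 4, 512}));
      return 0;
    }

The `flatten_2d` + `rows` + `reshape` combination is the gather idiom the beam-search changes below rely on.
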
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 37e7c137..7302deab 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -73,15 +73,18 @@ Expr transpose(Expr a);
Expr transpose(Expr a, const std::vector<int>& axes);
Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax = 0);
+Expr repeat(Expr a, size_t repeats, keywords::axis_k ax = 0);
Expr reshape(Expr a, Shape shape);
Expr atleast_1d(Expr a);
Expr atleast_2d(Expr a);
Expr atleast_3d(Expr a);
+Expr atleast_4d(Expr a);
Expr atleast_nd(Expr a, size_t dims);
Expr flatten(Expr a);
+Expr flatten_2d(Expr a);
Expr rows(Expr a, const std::vector<size_t>& indices);
Expr cols(Expr a, const std::vector<size_t>& indices);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index faf21dee..8390a0c2 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -39,6 +39,25 @@ public:
}
const std::string type() { return "scalar_add"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, scalar_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
};
struct ScalarMultNodeOp : public UnaryNodeOp {
@@ -61,6 +80,25 @@ public:
}
const std::string type() { return "scalar_add"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, scalar_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ScalarMultNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
};
struct LogitNodeOp : public UnaryNodeOp {
@@ -256,6 +294,25 @@ struct PReLUNodeOp : public UnaryNodeOp {
const std::string type() { return "PReLU"; }
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, alpha_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<PReLUNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(alpha_ != cnode->alpha_)
+ return false;
+ return true;
+ }
+
private:
float alpha_{0.01};
};
@@ -546,8 +603,6 @@ struct SqrtNodeOp : public UnaryNodeOp {
};
struct SquareNodeOp : public UnaryNodeOp {
- float epsilon_;
-
template <typename... Args>
SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -586,16 +641,16 @@ struct RowsNodeOp : public UnaryNodeOp {
template <typename... Args>
RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
: UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {NodeOp(CopyRows(val_, child(0)->val(), indeces_))};
+ return {NodeOp(CopyRows(val_, child(0)->val(), indices_))};
}
NodeOps backwardOps() {
- return {NodeOp(PasteRows(child(0)->grad(), adj_, indeces_))};
+ return {NodeOp(PasteRows(child(0)->grad(), adj_, indices_))};
}
template <class... Args>
@@ -614,7 +669,7 @@ struct RowsNodeOp : public UnaryNodeOp {
virtual size_t hash() {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -627,28 +682,28 @@ struct RowsNodeOp : public UnaryNodeOp {
Ptr<RowsNodeOp> cnode = std::dynamic_pointer_cast<RowsNodeOp>(node);
if(!cnode)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
};
struct ColsNodeOp : public UnaryNodeOp {
template <typename... Args>
ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
: UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {NodeOp(CopyCols(val_, child(0)->val(), indeces_))};
+ return {NodeOp(CopyCols(val_, child(0)->val(), indices_))};
}
NodeOps backwardOps() {
- return {NodeOp(PasteCols(child(0)->grad(), adj_, indeces_))};
+ return {NodeOp(PasteCols(child(0)->grad(), adj_, indices_))};
}
template <class... Args>
@@ -665,7 +720,7 @@ struct ColsNodeOp : public UnaryNodeOp {
virtual size_t hash() {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -678,27 +733,27 @@ struct ColsNodeOp : public UnaryNodeOp {
Ptr<ColsNodeOp> cnode = std::dynamic_pointer_cast<ColsNodeOp>(node);
if(!cnode)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
};
struct SelectNodeOp : public UnaryNodeOp {
SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
: UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
return {NodeOp(
- Select(graph()->allocator(), val_, child(0)->val(), axis_, indeces_))};
+ Select(graph()->allocator(), val_, child(0)->val(), axis_, indices_))};
}
NodeOps backwardOps() {
return {NodeOp(
- Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indeces_))};
+ Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indices_))};
}
Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
@@ -716,7 +771,7 @@ struct SelectNodeOp : public UnaryNodeOp {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
boost::hash_combine(seed, axis_);
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -731,12 +786,12 @@ struct SelectNodeOp : public UnaryNodeOp {
return false;
if(axis_ != cnode->axis_)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
int axis_{0};
};
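
The `hash()`/`equal()` overrides added to `ScalarAddNodeOp`, `ScalarMultNodeOp` and `PReLUNodeOp` all follow one pattern: fold the node's scalar parameter into the cached base hash and compare it in `equal`, so the graph's node deduplication no longer conflates, say, `x + 1` with `x + 2`. A distilled sketch of that pattern, assuming Boost as the diff does (the `Node` base here is illustrative, not Marian's):

    #include <boost/functional/hash.hpp>
    #include <memory>
    #include <string>

    struct Node {
      virtual ~Node() {}
      virtual size_t hash() = 0;
      virtual bool equal(std::shared_ptr<Node> other) = 0;
    };

    struct ScalarAdd : Node {
      float scalar_;
      size_t hash_{0};
      explicit ScalarAdd(float s) : scalar_(s) {}

      size_t hash() override {                   // cached: type hash + parameter
        if(!hash_) {
          hash_ = boost::hash<std::string>()("scalar_add");
          boost::hash_combine(hash_, scalar_);
        }
        return hash_;
      }

      bool equal(std::shared_ptr<Node> other) override {
        auto o = std::dynamic_pointer_cast<ScalarAdd>(other);
        return o && o->scalar_ == scalar_;       // same node type, same scalar
      }
    };
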
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 3c017cb3..4088223a 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -1,4 +1,4 @@
-#include <thrust/transform_reduce.h>
+#include <thrust/transform_reduce.h>
#include "kernels/cuda_helpers.h"
#include "kernels/tensor_operators.h"
@@ -34,7 +34,6 @@ bool IsNan(Tensor in) {
void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
cudaSetDevice(out->getDevice());
-
int step = 1;
for(int i = 0; i < axis; ++i)
step *= out->shape()[i];
@@ -1309,21 +1308,20 @@ __global__ void gAtt(float* out,
const float* va,
const float* ctx,
const float* state,
- const float* cov,
int m, // total rows (batch x time x beam)
int k, // depth
int b, // batch size
- int t // time of ctx
+ int t // time of ctx
) {
int rows = m;
int cols = k;
- for(int bid = 0; bid < m; bid += gridDim.x) {
+
+ for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
const float* vaRow = va;
const float* ctxRow = ctx + (j % (b * t)) * cols;
- const float* stateRow = state + (j / (b * t) + j % b) * cols;
- const float* covRow = cov ? cov + (j % (b * t)) * cols : nullptr;
+ const float* stateRow = state + ((j / (b * t)) * b + j % b) * cols;
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
@@ -1333,8 +1331,6 @@ __global__ void gAtt(float* out,
int id = tid + threadIdx.x;
if(id < cols) {
float z = ctxRow[id] + stateRow[id];
- if(cov)
- z += covRow[id];
float ex = tanhf(z) * vaRow[id];
_sum[threadIdx.x] += ex;
}
@@ -1354,7 +1350,7 @@ __global__ void gAtt(float* out,
}
}
-void Att(Tensor out, Tensor va, Tensor context, Tensor state, Tensor coverage) {
+void Att(Tensor out, Tensor va, Tensor context, Tensor state) {
cudaSetDevice(out->getDevice());
size_t m = out->shape().elements() / out->shape().back();
@@ -1372,7 +1368,6 @@ void Att(Tensor out, Tensor va, Tensor context, Tensor state, Tensor coverage) {
va->data(),
context->data(),
state->data(),
- coverage ? coverage->data() : nullptr,
m,
k,
b,
@@ -1382,11 +1377,9 @@ void Att(Tensor out, Tensor va, Tensor context, Tensor state, Tensor coverage) {
__global__ void gAttBack(float* gVa,
float* gContext,
float* gState,
- float* gCoverage,
const float* va,
const float* context,
const float* state,
- const float* coverage,
const float* adj,
int m, // rows
int k, // cols
@@ -1399,26 +1392,20 @@ __global__ void gAttBack(float* gVa,
if(j < rows) {
float* gcRow = gContext + j * cols;
float* gsRow = gState + (j % n) * cols;
- float* gcovRow = gCoverage ? gCoverage + j * cols : nullptr;
const float* cRow = context + j * cols;
const float* sRow = state + (j % n) * cols;
- const float* covRow = coverage ? coverage + j * cols : nullptr;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
float z = cRow[id] + sRow[id];
- if(coverage)
- z += covRow[id];
float t = tanhf(z);
float r = va[id] * (1.f - t * t);
gcRow[id] += r * adj[j];
gsRow[id] += r * adj[j];
- if(gCoverage)
- gcovRow[id] += r * adj[j];
atomicAdd(gVa + id, t * adj[j]);
}
}
@@ -1429,11 +1416,9 @@ __global__ void gAttBack(float* gVa,
void AttBack(Tensor gVa,
Tensor gContext,
Tensor gState,
- Tensor gCoverage,
Tensor va,
Tensor context,
Tensor state,
- Tensor coverage,
Tensor adj) {
cudaSetDevice(adj->getDevice());
@@ -1449,12 +1434,10 @@ void AttBack(Tensor gVa,
gAttBack<<<blocks, threads>>>(gVa->data(),
gContext->data(),
gState->data(),
- gCoverage ? gCoverage->data() : nullptr,
va->data(),
context->data(),
state->data(),
- coverage ? coverage->data() : nullptr,
adj->data(),
m,
diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h
index 06cb188c..d317ec34 100644
--- a/src/kernels/tensor_operators.h
+++ b/src/kernels/tensor_operators.h
@@ -339,15 +339,13 @@ void GRUFastBackward(std::vector<Tensor> outputs,
Tensor adj,
bool final = false);
-void Att(Tensor out, Tensor va, Tensor context, Tensor state, Tensor coverage);
+void Att(Tensor out, Tensor va, Tensor context, Tensor state);
void AttBack(Tensor gva,
Tensor gContext,
Tensor gState,
- Tensor gCoverage,
Tensor va,
Tensor context,
Tensor state,
- Tensor coverage,
Tensor adj);
void LayerNormalization(Tensor out,
diff --git a/src/models/encdec.h b/src/models/encdec.h
index 8447a3a8..cfe304ae 100644
--- a/src/models/encdec.h
+++ b/src/models/encdec.h
@@ -124,15 +124,23 @@ public:
virtual void selectEmbeddings(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
- const std::vector<size_t>& embIdx) {
+ const std::vector<size_t>& embIdx,
+ int beamSize) {
using namespace keywords;
int dimTrgEmb = opt<int>("dim-emb");
int dimTrgVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
+ int dimBatch = 1;
+ if(state->getEncoderStates().size() > 0)
+ dimBatch = state->getEncoderStates()[0]->getContext()->shape()[-2];
+
+ int dimBeam = embIdx.size() / dimBatch;
+
Expr selectedEmbs;
if(embIdx.empty()) {
- selectedEmbs = graph->constant({1, 1, 1, dimTrgEmb}, init = inits::zeros);
+ selectedEmbs = graph->constant({1, 1, dimBatch, dimTrgEmb},
+ init = inits::zeros);
} else {
// embeddings are loaded from model during translation, no fixing required
auto yEmbFactory = embedding(graph) //
@@ -148,7 +156,7 @@ public:
selectedEmbs = rows(yEmb, embIdx);
selectedEmbs
- = reshape(selectedEmbs, {(int)embIdx.size(), 1, 1, dimTrgEmb});
+ = reshape(selectedEmbs, {dimBeam, 1, dimBatch, dimTrgEmb});
}
state->setTargetEmbeddings(selectedEmbs);
}
@@ -167,13 +175,15 @@ class EncoderDecoderBase : public models::ModelBase {
public:
virtual void selectEmbeddings(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
- const std::vector<size_t>&)
+ const std::vector<size_t>&,
+ int beamSize)
= 0;
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState>,
const std::vector<size_t>&,
- const std::vector<size_t>&)
+ const std::vector<size_t>&,
+ int beamSize)
= 0;
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph>, Ptr<DecoderState>) = 0;
@@ -300,9 +310,10 @@ public:
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
const std::vector<size_t>& hypIndices,
- const std::vector<size_t>& embIndices) {
- auto selectedState = hypIndices.empty() ? state : state->select(hypIndices);
- selectEmbeddings(graph, selectedState, embIndices);
+ const std::vector<size_t>& embIndices,
+ int beamSize) {
+ auto selectedState = hypIndices.empty() ? state : state->select(hypIndices, beamSize);
+ selectEmbeddings(graph, selectedState, embIndices, beamSize);
selectedState->setSingleStep(true);
auto nextState = step(graph, selectedState);
nextState->setProbs(logsoftmax(nextState->getProbs()));
@@ -311,8 +322,9 @@ public:
virtual void selectEmbeddings(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
- const std::vector<size_t>& embIdx) {
- decoders_[0]->selectEmbeddings(graph, state, embIdx);
+ const std::vector<size_t>& embIdx,
+ int beamSize) {
+ decoders_[0]->selectEmbeddings(graph, state, embIdx, beamSize);
}
virtual Expr build(Ptr<ExpressionGraph> graph,
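
The extra `beamSize` argument threads the beam dimension through embedding selection: `embIdx` now carries `beamSize * dimBatch` word ids, and the embeddings are reshaped so beam and batch get separate axes. A worked example with assumed sizes (beam 5, batch 2, embedding depth 512):

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<size_t> embIdx(5 * 2);              // one word id per (beam, batch) slot
      int dimBatch = 2;                               // from the encoder context shape
      int dimBeam = (int)embIdx.size() / dimBatch;    // as computed in the diff
      assert(dimBeam == 5);
      // rows(yEmb, embIdx) yields {10, 512}; reshaping to {dimBeam, 1, dimBatch, 512}
      // makes the layout beam-major with a single target position.
      return 0;
    }

On the very first step `embIdx` is empty, so a zero constant of shape `{1, 1, dimBatch, dimTrgEmb}` stands in for the missing start-symbol embeddings.
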
diff --git a/src/models/hardatt.h b/src/models/hardatt.h
index d8ca8961..e6383a43 100644
--- a/src/models/hardatt.h
+++ b/src/models/hardatt.h
@@ -16,13 +16,13 @@ public:
: DecoderState(states, probs, encStates),
attentionIndices_(attentionIndices) {}
- virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx) {
+ virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx, int beamSize) {
std::vector<size_t> selectedAttentionIndices;
for(auto i : selIdx)
selectedAttentionIndices.push_back(attentionIndices_[i]);
return New<DecoderStateHardAtt>(
- states_.select(selIdx), probs_, encStates_, selectedAttentionIndices);
+ states_.select(selIdx, beamSize), probs_, encStates_, selectedAttentionIndices);
}
virtual void setAttentionIndices(
@@ -259,8 +259,9 @@ public:
virtual void selectEmbeddings(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
- const std::vector<size_t>& embIdx) {
- DecoderBase::selectEmbeddings(graph, state, embIdx);
+ const std::vector<size_t>& embIdx,
+ int beamSize) {
+ DecoderBase::selectEmbeddings(graph, state, embIdx, beamSize);
auto stateHardAtt = std::dynamic_pointer_cast<DecoderStateHardAtt>(state);
diff --git a/src/models/s2s.h b/src/models/s2s.h
index 8da73bdd..ba738793 100644
--- a/src/models/s2s.h
+++ b/src/models/s2s.h
@@ -299,6 +299,8 @@ public:
// apply RNN to embeddings, initialized with encoder context mapped into
// decoder space
+
+ auto states = state->getStates();
auto decoderContext = rnn_->transduce(embeddings, state->getStates());
// retrieve the last state per layer. They are required during translation
diff --git a/src/models/states.h b/src/models/states.h
index e7cef04c..d3ad2dfa 100644
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -50,8 +50,8 @@ public:
virtual Expr getProbs() { return probs_; }
virtual void setProbs(Expr probs) { probs_ = probs; }
- virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx) {
- return New<DecoderState>(states_.select(selIdx), probs_, encStates_);
+ virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx, int beamSize) {
+ return New<DecoderState>(states_.select(selIdx, beamSize), probs_, encStates_);
}
virtual const rnn::States& getStates() { return states_; }
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 0c718db9..92030c83 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -50,7 +50,7 @@ public:
// convert 0/1 mask to transformer style -inf mask
auto ms = mask->shape();
mask = (1 - mask) * -99999999.f;
- return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]});
+ return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]}) ;
}
Expr SplitHeads(Expr input, int dimHeads) {
@@ -171,9 +171,10 @@ public:
// @TODO: do this better
int dimBeamQ = q->shape()[-4];
int dimBeamK = k->shape()[-4];
- if(dimBeamQ != dimBeamK) {
- k = concatenate(std::vector<Expr>(dimBeamQ, k), axis = -4);
- v = concatenate(std::vector<Expr>(dimBeamQ, v), axis = -4);
+ int dimBeam = dimBeamQ / dimBeamK;
+ if(dimBeam > 1) {
+ k = repeat(k, dimBeam, axis = -4);
+ v = repeat(v, dimBeam, axis = -4);
}
auto weights = softmax(bdot(q, k, false, true, scale) + mask);
@@ -237,6 +238,7 @@ public:
// apply multi-head attention to downscaled inputs
auto output
= Attention(graph, options, prefix, qh, kh, vh, masks[i], inference);
+
output = JoinHeads(output, q->shape()[-4]);
outputs.push_back(output);
@@ -456,12 +458,23 @@ public:
std::vector<Ptr<EncoderState>> &encStates)
: DecoderState(states, probs, encStates) {}
- virtual Ptr<DecoderState> select(const std::vector<size_t> &selIdx) {
+ virtual Ptr<DecoderState> select(const std::vector<size_t> &selIdx, int beamSize) {
rnn::States selectedStates;
- for(auto state : states_)
- selectedStates.push_back(
- {marian::select(state.output, -4, selIdx), nullptr});
+ int dimDepth = states_[0].output->shape()[-1];
+ int dimTime = states_[0].output->shape()[-2];
+ int dimBatch = selIdx.size() / beamSize;
+
+ std::vector<size_t> selIdx2;
+ for(auto i : selIdx)
+ for(int j = 0; j < dimTime; ++j)
+ selIdx2.push_back(i * dimTime + j);
+
+ for(auto state : states_) {
+ auto sel = rows(flatten_2d(state.output), selIdx2);
+ sel = reshape(sel, {beamSize, dimBatch, dimTime, dimDepth});
+ selectedStates.push_back({sel, nullptr});
+ }
return New<TransformerState>(selectedStates, probs_, encStates_);
}
@@ -497,6 +510,9 @@ public:
//************************************************************************//
int dimEmb = embeddings->shape()[-1];
+ int dimBeam = 1;
+ if(embeddings->shape().size() > 3)
+ dimBeam = embeddings->shape()[-4];
// according to paper embeddings are scaled by \sqrt(d_m)
auto scaledEmbeddings = std::sqrt(dimEmb) * embeddings;
@@ -528,6 +544,8 @@ public:
decoderMask = reshape(TransposeTimeBatch(decoderMask),
{1, dimBatch, 1, dimTrgWords});
selfMask = selfMask * decoderMask;
+ //if(dimBeam > 1)
+ // selfMask = repeat(selfMask, dimBeam, axis = -4);
}
selfMask = InverseMask(selfMask);
@@ -548,6 +566,8 @@ public:
encoderMask = reshape(TransposeTimeBatch(encoderMask),
{1, dimBatch, 1, dimSrcWords});
encoderMask = InverseMask(encoderMask);
+ if(dimBeam > 1)
+ encoderMask = repeat(encoderMask, dimBeam, axis = -4);
encoderContexts.push_back(encoderContext);
encoderMasks.push_back(encoderMask);
@@ -557,8 +577,7 @@ public:
for(int i = 1; i <= opt<int>("dec-depth"); ++i) {
auto values = query;
if(prevDecoderStates.size() > 0)
- values
- = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
+ values = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
decoderStates.push_back({values, nullptr});
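
`TransformerState::select` must gather whole per-hypothesis histories rather than single rows: each stored state holds `dimTime` positions per hypothesis, so every surviving index is expanded into `dimTime` consecutive row indices of the flattened state. A sketch mirroring that `selIdx2` loop (the function name is illustrative):

    #include <vector>

    // Expand per-hypothesis indices into per-(hypothesis, time-step) row indices.
    std::vector<size_t> expandIndices(const std::vector<size_t>& selIdx, int dimTime) {
      std::vector<size_t> selIdx2;
      for(auto i : selIdx)
        for(int j = 0; j < dimTime; ++j)
          selIdx2.push_back(i * dimTime + j);
      return selIdx2;  // e.g. selIdx {2, 0}, dimTime 3 -> {6, 7, 8, 0, 1, 2}
    }
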
diff --git a/src/rnn/attention.cu b/src/rnn/attention.cu
index 752692bc..232b99f3 100644
--- a/src/rnn/attention.cu
+++ b/src/rnn/attention.cu
@@ -25,8 +25,7 @@ struct AttentionNodeOp : public NaryNodeOp {
return {NodeOp(Att(val_,
child(0)->val(),
child(1)->val(),
- child(2)->val(),
- children_.size() == 4 ? child(3)->val() : nullptr))};
+ child(2)->val()))};
}
NodeOps backwardOps() {
@@ -34,11 +33,9 @@ struct AttentionNodeOp : public NaryNodeOp {
NodeOp(AttBack(child(0)->grad(),
child(1)->grad(),
child(2)->grad(),
- children_.size() == 4 ? child(3)->grad() : nullptr,
child(0)->val(),
child(1)->val(),
child(2)->val(),
- children_.size() == 4 ? child(3)->val() : nullptr,
adj_);)
};
}
@@ -54,10 +51,8 @@ struct AttentionNodeOp : public NaryNodeOp {
const std::string color() { return "yellow"; }
};
-Expr attOps(Expr va, Expr context, Expr state, Expr coverage) {
+Expr attOps(Expr va, Expr context, Expr state) {
std::vector<Expr> nodes{va, context, state};
- if(coverage)
- nodes.push_back(coverage);
int dimBatch = context->shape()[-2];
int dimWords = context->shape()[-3];
diff --git a/src/rnn/attention.h b/src/rnn/attention.h
index 6d182920..751d3a3c 100644
--- a/src/rnn/attention.h
+++ b/src/rnn/attention.h
@@ -11,7 +11,7 @@ namespace marian {
namespace rnn {
-Expr attOps(Expr va, Expr context, Expr state, Expr coverage = nullptr);
+Expr attOps(Expr va, Expr context, Expr state);
class GlobalAttention : public CellInput {
private:
@@ -147,7 +147,7 @@ public:
auto alignedSource
= scalar_product(encState_->getAttended(), e, axis = -3);
-
+
contexts_.push_back(alignedSource);
alignments_.push_back(e);
return alignedSource;
diff --git a/src/rnn/types.h b/src/rnn/types.h
index 7358685b..17fa9bab 100644
--- a/src/rnn/types.h
+++ b/src/rnn/types.h
@@ -14,24 +14,26 @@ struct State {
Expr output;
Expr cell;
- State select(const std::vector<size_t>& indices) {
- if(output->shape().size() < 4) {
- int dimState = output->shape()[-1];
- int dimBatch = output->shape()[-2];
- int dimTime = output->shape()[-3];
-
- output = reshape(output, {1, dimTime, dimBatch, dimState});
- if(cell)
- cell = reshape(cell, {1, dimTime, dimBatch, dimState});
- }
+ State select(const std::vector<size_t>& indices, int beamSize) {
+ output = atleast_4d(output);
+ if(cell)
+ cell = atleast_4d(cell);
+
+ int dimDepth = output->shape()[-1];
+ int dimTime = output->shape()[-3];
+
+ int dimBatch = indices.size() / beamSize;
if(cell) {
return State{
- marian::select(output, 0, indices),
- marian::select(cell, 0, indices)};
+ reshape(rows(flatten_2d(output), indices),
+ {beamSize, dimTime, dimBatch, dimDepth}),
+ reshape(rows(flatten_2d(cell), indices),
+ {beamSize, dimTime, dimBatch, dimDepth})};
} else {
return State{
- marian::select(output, 0, indices),
+ reshape(rows(flatten_2d(output), indices),
+ {beamSize, dimTime, dimBatch, dimDepth}),
nullptr};
}
}
@@ -72,10 +74,10 @@ public:
void push_back(const State& state) { states_.push_back(state); }
- States select(const std::vector<size_t>& indices) {
+ States select(const std::vector<size_t>& indices, int beamSize) {
States selected;
for(auto& state : states_)
- selected.push_back(state.select(indices));
+ selected.push_back(state.select(indices, beamSize));
return selected;
}
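
The rewritten `State::select` swaps the old axis-0 `marian::select` for the same gather idiom: view the state as a matrix with one row per (beam, time, batch) slot, pick the surviving hypotheses' rows, and reshape with the new beam size. A shape-level check under assumed decoding-time sizes (beam 5, batch 2, one time step, depth 1024):

    #include <cassert>
    #include <vector>

    int main() {
      // flatten_2d views the {5, 1, 2, 1024} state as a {10, 1024} matrix.
      std::vector<size_t> indices(5 * 2);               // one row per surviving hypothesis
      int beamSize = 5;
      int dimBatch = (int)indices.size() / beamSize;    // as in State::select
      assert(dimBatch == 2);
      // rows(flatten_2d(output), indices) : {10, 1024}
      // reshape(..., {5, 1, 2, 1024})     : back to {beam, time, batch, depth}
      return 0;
    }
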
diff --git a/src/tensors/allocator.h b/src/tensors/allocator.h
index 8a84e4d7..72d4805a 100644
--- a/src/tensors/allocator.h
+++ b/src/tensors/allocator.h
@@ -179,7 +179,6 @@ public:
auto ptr = gap.data();
auto mp = New<MemoryPiece>(ptr, bytes);
allocated_[ptr] = mp;
-
return mp;
}
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index 465d5478..f962594c 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -1,17 +1,19 @@
# Unit tests
-set(TEST_SOURCES
- graph_tests.cpp
- operator_tests.cpp
- rnn_tests.cpp
- attention_tests.cpp
+set(UNIT_TESTS
+ graph_tests
+ operator_tests
+ rnn_tests
+ attention_tests
)
-add_executable(run_tests run_tests.cpp ${TEST_SOURCES})
-target_link_libraries(run_tests marian ${EXT_LIBS} Catch)
-cuda_add_cublas_to_target(run_tests)
-set_target_properties(run_tests PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
+foreach(test ${UNIT_TESTS})
+ add_executable("run_${test}" run_tests.cpp "${test}.cpp")
+ target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
+ cuda_add_cublas_to_target("run_${test}")
+
+ add_test(NAME ${test} COMMAND "run_${test}")
+endforeach(test)
-add_test(NAME GraphTest COMMAND run_tests)
# Testing apps
add_executable(logger_test logger_test.cpp)
diff --git a/src/tests/README.md b/src/tests/README.md
index 31750313..77b23a6b 100644
--- a/src/tests/README.md
+++ b/src/tests/README.md
@@ -2,7 +2,19 @@ Marian tests
============
Unit tests and application tests are enabled with CMake option
-`-DCOMPILE_TESTS=ON`.
+`-DCOMPILE_TESTS=ON`, e.g.:
+
+ cd build
+ cmake .. -DCOMPILE_TESTS=ON
+ make -j8
+
+Running all unit tests:
+
+ make test
+
+Running a single unit test is also possible:
+
+ ./src/tests/run_graph_tests
We use [Catch framework](https://github.com/philsquared/Catch) for unit
testing.
diff --git a/src/tests/attention_tests.cpp b/src/tests/attention_tests.cpp
index 722b82a7..2bf5e1f2 100644
--- a/src/tests/attention_tests.cpp
+++ b/src/tests/attention_tests.cpp
@@ -5,7 +5,7 @@ using namespace marian;
TEST_CASE("Model components, Attention", "[attention]") {
- auto floatApprox = [](float x, float y) { return x == Approx(y); };
+ auto floatApprox = [](float x, float y) { return x == Approx(y).epsilon(0.01); };
std::vector<size_t> vWords = {
43, 2, 83, 78,
diff --git a/src/tests/rnn_tests.cpp b/src/tests/rnn_tests.cpp
index 1827a9bf..54e0e7b1 100644
--- a/src/tests/rnn_tests.cpp
+++ b/src/tests/rnn_tests.cpp
@@ -5,7 +5,7 @@ using namespace marian;
TEST_CASE("Model components, RNN etc.", "[model]") {
- auto floatApprox = [](float x, float y) { return x == Approx(y); };
+ auto floatApprox = [](float x, float y) { return x == Approx(y).epsilon(0.01); };
std::vector<size_t> vWords = {
43, 2, 83, 78,
diff --git a/src/training/validator.h b/src/training/validator.h
index 1c6233e9..fc42c390 100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -58,6 +58,7 @@ public:
opts->set("max-length", options_->get<size_t>("valid-max-length"));
if(options_->has("valid-mini-batch"))
opts->set("mini-batch", options_->get<size_t>("valid-mini-batch"));
+ opts->set("mini-batch-sort", "src");
// Create corpus
auto validPaths = options_->get<std::vector<std::string>>("valid-sets");
@@ -224,8 +225,8 @@ public:
// Temporary options for translation
auto opts = New<Config>(*options_);
- opts->set("mini-batch", 1);
- opts->set("maxi-batch", 1);
+ //opts->set("mini-batch", 1);
+ //opts->set("maxi-batch", 1);
opts->set("max-length", 1000);
// Create corpus
@@ -298,15 +299,17 @@ public:
}
auto search = New<BeamSearch>(options_, std::vector<Ptr<Scorer>>{scorer});
- auto history = search->search(graph, batch, id);
-
- std::stringstream best1;
- std::stringstream bestn;
- Printer(options_, vocabs_.back(), history, best1, bestn);
- collector->Write(history->GetLineNum(),
- best1.str(),
- bestn.str(),
- options_->get<bool>("n-best"));
+ auto histories = search->search(graph, batch);
+
+ for(auto history : histories) {
+ std::stringstream best1;
+ std::stringstream bestn;
+ Printer(options_, vocabs_.back(), history, best1, bestn);
+ collector->Write(history->GetLineNum(),
+ best1.str(),
+ bestn.str(),
+ options_->get<bool>("n-best"));
+ }
};
threadPool.enqueue(task, sentenceId);
diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h
index eb25d98f..178e9238 100644
--- a/src/translator/beam_search.h
+++ b/src/translator/beam_search.h
@@ -1,4 +1,5 @@
#pragma once
+#include <algorithm>
#include "marian.h"
#include "translator/history.h"
@@ -26,51 +27,92 @@ public:
? options_->get<size_t>("beam-size")
: 3) {}
- Beam toHyps(const std::vector<uint> keys,
- const std::vector<float> costs,
- size_t vocabSize,
- const Beam& beam,
- std::vector<Ptr<ScorerState>>& states) {
- Beam newBeam;
- for(int i = 0; i < keys.size(); ++i) {
- int embIdx = keys[i] % vocabSize;
- int hypIdx = keys[i] / vocabSize;
- float cost = costs[i];
-
- std::vector<float> breakDown(states.size(), 0);
- beam[hypIdx]->GetCostBreakdown().resize(states.size(), 0);
-
- for(int j = 0; j < states.size(); ++j)
- breakDown[j] = states[j]->breakDown(keys[i])
- + beam[hypIdx]->GetCostBreakdown()[j];
+ Beams toHyps(const std::vector<uint> keys,
+ const std::vector<float> costs,
+ size_t vocabSize,
+ const Beams& beams,
+ std::vector<Ptr<ScorerState>>& states,
+ size_t beamSize,
+ bool first) {
- auto hyp = New<Hypothesis>(beam[hypIdx], embIdx, hypIdx, cost);
- hyp->GetCostBreakdown() = breakDown;
- newBeam.push_back(hyp);
+ Beams newBeams(beams.size());
+ for(int i = 0; i < keys.size(); ++i) {
+ int embIdx = keys[i] % vocabSize;
+ int beamIdx = i / beamSize;
+
+ if(newBeams[beamIdx].size() < beams[beamIdx].size()) {
+ auto& beam = beams[beamIdx];
+ auto& newBeam = newBeams[beamIdx];
+
+ int hypIdx = keys[i] / vocabSize;
+ float cost = costs[i];
+
+ int hypIdxTrans = (hypIdx / beamSize) +
+ (hypIdx % beamSize) * beams.size();
+ if(first)
+ hypIdxTrans = hypIdx;
+
+ int beamHypIdx = hypIdx % beamSize;
+ if(beamHypIdx >= beam.size())
+ beamHypIdx = beamHypIdx % beam.size();
+
+ if(first)
+ beamHypIdx = 0;
+
+ auto hyp = New<Hypothesis>(beam[beamHypIdx], embIdx, hypIdxTrans, cost);
+ if(options_->get<bool>("n-best")) {
+ std::vector<float> breakDown(states.size(), 0);
+ beam[beamHypIdx]->GetCostBreakdown().resize(states.size(), 0);
+ for(int j = 0; j < states.size(); ++j) {
+ int key = embIdx + hypIdxTrans * vocabSize;
+ breakDown[j] = states[j]->breakDown(key)
+ + beam[beamHypIdx]->GetCostBreakdown()[j];
+ }
+ hyp->GetCostBreakdown() = breakDown;
+ }
+ newBeam.push_back(hyp);
+ }
}
- return newBeam;
+ return newBeams;
}
- Beam pruneBeam(const Beam& beam) {
- Beam newBeam;
- for(auto hyp : beam) {
- if(hyp->GetWord() > 0) {
- newBeam.push_back(hyp);
+ Beams pruneBeam(const Beams& beams) {
+ Beams newBeams;
+ for(auto beam: beams) {
+ Beam newBeam;
+ for(auto hyp : beam) {
+ if(hyp->GetWord() > 0) {
+ newBeam.push_back(hyp);
+ }
}
+ newBeams.push_back(newBeam);
}
- return newBeam;
+ return newBeams;
}
- Ptr<History> search(Ptr<ExpressionGraph> graph,
- Ptr<data::CorpusBatch> batch,
- size_t sentenceId = 0) {
- auto history = New<History>(sentenceId, options_->get<float>("normalize"));
- Beam beam(1, New<Hypothesis>());
+ Histories search(Ptr<ExpressionGraph> graph,
+ Ptr<data::CorpusBatch> batch) {
+
+ int dimBatch = batch->size();
+ Histories histories;
+ for(int i = 0; i < dimBatch; ++i) {
+ size_t sentId = batch->getSentenceIds()[i];
+ auto history = New<History>(sentId, options_->get<float>("normalize"));
+ histories.push_back(history);
+ }
+
+ size_t localBeamSize = beamSize_;
+ auto nth = New<NthElement>(localBeamSize, dimBatch);
+
+ Beams beams(dimBatch);
+ for(auto& beam : beams)
+ beam.resize(localBeamSize, New<Hypothesis>());
+
bool first = true;
bool final = false;
- std::vector<size_t> beamSizes(1, beamSize_);
- auto nth = New<NthElement>(beamSize_, batch->size());
- history->Add(beam);
+
+ for(int i = 0; i < dimBatch; ++i)
+ histories[i]->Add(beams[i]);
std::vector<Ptr<ScorerState>> states;
@@ -94,13 +136,28 @@ public:
keywords::init = inits::from_value(0));
} else {
std::vector<float> beamCosts;
- for(auto hyp : beam) {
- hypIndices.push_back(hyp->GetPrevStateIndex());
- embIndices.push_back(hyp->GetWord());
- beamCosts.push_back(hyp->GetCost());
+
+ int dimBatch = batch->size();
+
+ for(int i = 0; i < localBeamSize; ++i) {
+ for(int j = 0; j < beams.size(); ++j) {
+ auto& beam = beams[j];
+ if(i < beam.size()) {
+ auto hyp = beam[i];
+ hypIndices.push_back(hyp->GetPrevStateIndex());
+ embIndices.push_back(hyp->GetWord());
+ beamCosts.push_back(hyp->GetCost());
+ }
+ else {
+ hypIndices.push_back(0);
+ embIndices.push_back(0);
+ beamCosts.push_back(-9999);
+ }
+ }
}
+
prevCosts
- = graph->constant({(int)beamCosts.size(), 1, 1, 1},
+ = graph->constant({(int)localBeamSize, 1, dimBatch, 1},
keywords::init = inits::from_vector(beamCosts));
}
@@ -109,11 +166,18 @@ public:
auto totalCosts = prevCosts;
for(int i = 0; i < scorers_.size(); ++i) {
- states[i] = scorers_[i]->step(graph, states[i], hypIndices, embIndices);
- totalCosts
- = totalCosts + scorers_[i]->getWeight() * states[i]->getProbs();
+ states[i] = scorers_[i]->step(graph, states[i], hypIndices, embIndices, localBeamSize);
+
+ if(scorers_[i]->getWeight() != 1.f)
+ totalCosts = totalCosts + scorers_[i]->getWeight() * states[i]->getProbs();
+ else
+ totalCosts = totalCosts + states[i]->getProbs();
}
+ // make beams continuous
+ if(dimBatch > 1 && localBeamSize > 1)
+ totalCosts = transpose(totalCosts, {2, 1, 0, 3});
+
if(first)
graph->forward();
else
@@ -131,21 +195,33 @@ public:
std::vector<unsigned> outKeys;
std::vector<float> outCosts;
- beamSizes[0] = first ? beamSize_ : beam.size();
+ std::vector<size_t> beamSizes(dimBatch, localBeamSize);
nth->getNBestList(beamSizes, totalCosts->val(), outCosts, outKeys, first);
int dimTrgVoc = totalCosts->shape()[-1];
- beam = toHyps(outKeys, outCosts, dimTrgVoc, beam, states);
-
- final = history->size() >= 3 * batch->words();
- history->Add(beam, final);
- beam = pruneBeam(beam);
+ beams = toHyps(outKeys, outCosts, dimTrgVoc, beams, states, localBeamSize, first);
+ auto prunedBeams = pruneBeam(beams);
+ for(int i = 0; i < dimBatch; ++i) {
+ if(!beams[i].empty()) {
+ final = final || histories[i]->size() >= 3 * batch->front()->batchWidth();
+ histories[i]->Add(beams[i], prunedBeams[i].empty() || final);
+ }
+ }
+ beams = prunedBeams;
+
+ if(!first) {
+ size_t maxBeam = 0;
+ for(auto& beam : beams)
+ if(beam.size() > maxBeam)
+ maxBeam = beam.size();
+ localBeamSize = maxBeam;
+ }
first = false;
- } while(!beam.empty() && !final);
+ } while(localBeamSize != 0 && !final);
- return history;
+ return histories;
}
};
}
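
The trickiest arithmetic above is `hypIdxTrans`: after `transpose(totalCosts, {2, 1, 0, 3})` the scores are batch-major (`{batch, 1, beam, vocab}`), while the `hypIndices` fed back to the scorers use the beam-major `{beam, batch}` layout, so `toHyps` converts a flat batch-major hypothesis index into its beam-major counterpart. A worked example with assumed sizes (beam 5, batch 2):

    #include <cassert>

    int main() {
      int beamSize = 5, dimBatch = 2;               // dimBatch == beams.size()
      int hypIdx = 7;                               // batch-major: batch 1, beam slot 2
      int hypIdxTrans = (hypIdx / beamSize) + (hypIdx % beamSize) * dimBatch;
      assert(hypIdxTrans == 5);                     // beam-major: 2 * dimBatch + 1
      return 0;
    }

The `first` flag short-circuits this, since the initial step starts from a single hypothesis per sentence and no transposed layout exists yet.
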
diff --git a/src/translator/history.h b/src/translator/history.h
index dbdf18a5..0070b723 100644
--- a/src/translator/history.h
+++ b/src/translator/history.h
@@ -27,6 +27,7 @@ public:
if(beam[j]->GetWord() == 0 || last) {
float cost = beam[j]->GetCost() / LengthPenalty(history_.size());
topHyps_.push({history_.size(), j, cost});
+ //std::cerr << "Add " << history_.size() << " " << j << " " << cost << std::endl;
}
}
history_.push_back(beam);
@@ -43,11 +44,14 @@ public:
size_t start = bestHypCoord.i;
size_t j = bestHypCoord.j;
+ //float c = bestHypCoord.cost;
+ //std::cerr << "h: " << start << " " << j << " " << c << std::endl;
Words targetWords;
Ptr<Hypothesis> bestHyp = history_[start][j];
while(bestHyp->GetPrevHyp() != nullptr) {
targetWords.push_back(bestHyp->GetWord());
+ //std::cerr << bestHyp->GetWord() << " " << bestHyp << std::endl;
bestHyp = bestHyp->GetPrevHyp();
}
@@ -70,5 +74,5 @@ private:
float alpha_;
};
-typedef std::vector<History> Histories;
+typedef std::vector<Ptr<History>> Histories;
}
diff --git a/src/translator/scorers.h b/src/translator/scorers.h
index 171abccd..e963efb0 100644
--- a/src/translator/scorers.h
+++ b/src/translator/scorers.h
@@ -33,7 +33,8 @@ public:
virtual Ptr<ScorerState> step(Ptr<ExpressionGraph>,
Ptr<ScorerState>,
const std::vector<size_t>&,
- const std::vector<size_t>&)
+ const std::vector<size_t>&,
+ int beamSize)
= 0;
virtual void init(Ptr<ExpressionGraph> graph) {}
@@ -88,12 +89,13 @@ public:
virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
Ptr<ScorerState> state,
const std::vector<size_t>& hypIndices,
- const std::vector<size_t>& embIndices) {
+ const std::vector<size_t>& embIndices,
+ int beamSize) {
graph->switchParams(getName());
auto wrappedState
= std::dynamic_pointer_cast<ScorerWrapperState>(state)->getState();
return New<ScorerWrapperState>(
- encdec_->step(graph, wrappedState, hypIndices, embIndices));
+ encdec_->step(graph, wrappedState, hypIndices, embIndices, beamSize));
}
};
@@ -138,7 +140,8 @@ public:
virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
Ptr<ScorerState> state,
const std::vector<size_t>& hypIndices,
- const std::vector<size_t>& embIndices) {
+ const std::vector<size_t>& embIndices,
+ int beamSize) {
return state;
}
};
@@ -173,7 +176,8 @@ public:
virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
Ptr<ScorerState> state,
const std::vector<size_t>& hypIndices,
- const std::vector<size_t>& embIndices) {
+ const std::vector<size_t>& embIndices,
+ int beamSize) {
return state;
}
};
diff --git a/src/translator/translator.h b/src/translator/translator.h
index bc96803d..598aee2b 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -62,7 +62,7 @@ public:
auto devices = options_->get<std::vector<int>>("devices");
ThreadPool threadPool(devices.size(), devices.size());
- size_t sentenceId = 0;
+ size_t batchId = 0;
auto collector = New<OutputCollector>();
if(options_->get<bool>("quiet-translation"))
collector->setPrintingStrategy(New<QuietPrinting>());
@@ -83,18 +83,20 @@ public:
}
auto search = New<Search>(options_, scorers);
- auto history = search->search(graph, batch, id);
-
- std::stringstream best1;
- std::stringstream bestn;
- Printer(options_, trgVocab_, history, best1, bestn);
- collector->Write(history->GetLineNum(),
- best1.str(),
- bestn.str(),
- options_->get<bool>("n-best"));
+ auto histories = search->search(graph, batch);
+
+ for(auto history : histories) {
+ std::stringstream best1;
+ std::stringstream bestn;
+ Printer(options_, trgVocab_, history, best1, bestn);
+ collector->Write(history->GetLineNum(),
+ best1.str(),
+ bestn.str(),
+ options_->get<bool>("n-best"));
+ }
};
- threadPool.enqueue(task, sentenceId++);
+ threadPool.enqueue(task, batchId++);
}
}
};
@@ -150,7 +152,7 @@ public:
data::BatchGenerator<data::TextInput> bg(corpus_, options_);
auto collector = New<StringCollector>();
- size_t sentenceId = 0;
+ size_t batchId = 0;
bg.prepare(false);
@@ -171,16 +173,18 @@ public:
}
auto search = New<Search>(options_, scorers);
- auto history = search->search(graph, batch, id);
+ auto histories = search->search(graph, batch);
- std::stringstream best1;
- std::stringstream bestn;
- Printer(options_, trgVocab_, history, best1, bestn);
- collector->add(history->GetLineNum(), best1.str(), bestn.str());
+ for(auto history : histories) {
+ std::stringstream best1;
+ std::stringstream bestn;
+ Printer(options_, trgVocab_, history, best1, bestn);
+ collector->add(history->GetLineNum(), best1.str(), bestn.str());
+ }
};
- threadPool_.enqueue(task, sentenceId);
- sentenceId++;
+ threadPool_.enqueue(task, batchId);
+ batchId++;
}
}