author     Tomasz Dwojak <t.dwojak@amu.edu.pl>    2017-11-23 14:26:38 +0300
committer  Tomasz Dwojak <t.dwojak@amu.edu.pl>    2017-11-23 14:26:38 +0300
commit     537fccc3e81831fb0aa8baaaaef2858c96c346ec (patch)
tree       57a48f2e1ff7894289587255f503f3ab7744203f
parent     15947c6061c009679b0a00ca873d768c1159cd6f (diff)
parent     9023667939b0fdd645f971cdeb0ab4e764b07057 (diff)
Merge branch 'master' of https://github.com/marian-nmt/marian-dev into charS2S
29 files changed, 422 insertions, 205 deletions
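Most of this merge is the batched-translation rework, and three small graph helpers introduced below (`repeat`, `atleast_4d`, `flatten_2d`) carry most of it. As a reading aid, here is a minimal sketch of their shape semantics, restated on plain integer vectors; the `Dims` alias and the `main` driver are illustrative stand-ins, not marian's `Expr`/`Shape` API.

```cpp
// Illustrative only: the shape arithmetic of repeat / atleast_4d / flatten_2d
// from the hunks below, on plain std::vector<int> instead of marian types.
#include <cassert>
#include <vector>

using Dims = std::vector<int>;  // e.g. {beam, time, batch, depth}

// atleast_4d: pad with leading 1s until rank 4 (cf. atleast_nd in the diff).
Dims atleast_4d(Dims s) {
  while(s.size() < 4)
    s.insert(s.begin(), 1);
  return s;
}

// flatten_2d: collapse everything but the last axis into the row dimension.
Dims flatten_2d(const Dims& s) {
  int rows = 1;
  for(size_t i = 0; i + 1 < s.size(); ++i)
    rows *= s[i];
  return {rows, s.back()};
}

// repeat: concatenate `repeats` copies along a (negative) axis.
Dims repeat(Dims s, int repeats, int ax) {
  s[s.size() + ax] *= repeats;
  return s;
}

int main() {
  Dims state = {3, 1, 2, 512};  // {beam, time, batch, depth}
  assert(atleast_4d(Dims{2, 512}).size() == 4);
  assert(flatten_2d(state) == (Dims{6, 512}));      // rows = beam*time*batch
  assert(repeat(state, 4, -4) == (Dims{12, 1, 2, 512}));
}
```

With rank-4 states normalized this way, beam-search reordering reduces to `reshape(rows(flatten_2d(output), indices), ...)`, which is exactly the pattern used in `rnn/types.h` and `transformer.h` below.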
```diff
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4997b61d..113a35de 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,9 +6,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
 and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
-- Added support for CUBLAS_TENSOR_OP_MATH mode for cublas in cuda 9.0
+
+## [1.1.0] - 2017-11-21
 
 ### Added
+- Batched translation for all model types, significant translation speed-up
+- Batched translation during validation with translation
+- `--maxi-batch-sort` option for `marian-decoder`
+- Support for CUBLAS_TENSOR_OP_MATH mode for cublas in cuda 9.0
+- The "marian-vocab" tool to create vocabularies
 
 ## [1.0.0] - 2017-11-13
 
@@ -29,7 +35,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - Synchronous SGD training for multi-gpu (enable with `--sync-sgd`)
 - Dynamic construction of complex models with different encoders and decoders,
   currently only available through the C++ API
-- Option --quiet to suppress output to stderr
+- Option `--quiet` to suppress output to stderr
 - Option to choose different variants of optimization criterion: mean
   cross-entropy, perplexity, cross-entopry sum
 - In-process translation for validation, uses the same memory as training
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9994b690..6c8f54de 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -27,7 +27,7 @@ set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
 # Find packages
 find_package(CUDA "8.0" REQUIRED)
 if(CUDA_FOUND)
-  set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY})
+  set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} tcmalloc_minimal)
 endif(CUDA_FOUND)
 
 if (CMAKE_BUILD_TYPE STREQUAL "Debug")
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@ Marian
 [![Join the chat at https://gitter.im/marian-nmt](https://badges.gitter.im/amunmt/marian.svg)](https://gitter.im/marian-nmt?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
 [![Build Status](http://vali.inf.ed.ac.uk/jenkins/buildStatus/icon?job=marian-dev)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev/)
 [![Tests Status](http://vali.inf.ed.ac.uk/jenkins/buildStatus/icon?job=marian-regression-tests)](http://vali.inf.ed.ac.uk/jenkins/job/marian-regression-tests/)
+[![Twitter](https://img.shields.io/twitter/follow/marian_nmt.svg?style=social&label=Follow)](https://twitter.com/intent/follow?screen_name=marian_nmt)
 
 **Marian** is a C++ GPU-specific parallel automatic differentiation library
 with operator overloading.
 It is the training framework used in the Marian
diff --git a/VERSION b/VERSION
@@ -1 +1 @@
-v1.0.0
+v1.1.0
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index c4812bb2..70ee4db7 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -163,6 +163,10 @@ void ConfigParser::validateOptions() const {
       !modelDir.empty() && !boost::filesystem::is_directory(modelDir),
       "Model directory does not exist");
 
+  UTIL_THROW_IF2(!(boost::filesystem::status(modelDir).permissions()
+                   & boost::filesystem::owner_write),
+                 "No write permission in model directory");
+
   UTIL_THROW_IF2(
       has("valid-sets")
           && get<std::vector<std::string>>("valid-sets").size()
@@ -518,6 +522,8 @@ void ConfigParser::addOptionsTranslate(po::options_description& desc) {
      "Size of mini-batch used during update")
     ("maxi-batch", po::value<int>()->default_value(1),
      "Number of batches to preload for length-based sorting")
+    ("maxi-batch-sort", po::value<std::string>()->default_value("none"),
+     "Sorting strategy for maxi-batch: none (default) src")
     ("n-best", po::value<bool>()->zero_tokens()->default_value(false),
      "Display n-best list")
     //("lexical-table", po::value<std::string>(),
@@ -802,7 +808,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
   SET_OPTION("mini-batch", int);
   SET_OPTION("maxi-batch", int);
 
-  if(mode_ == ConfigMode::training)
+  if(mode_ == ConfigMode::training || mode_ == ConfigMode::translating)
     SET_OPTION("maxi-batch-sort", std::string);
 
   SET_OPTION("max-length", size_t);
diff --git a/src/common/shape.h b/src/common/shape.h
index 9cd08f5e..44e0da9f 100644
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -128,6 +128,12 @@ struct Shape {
     return strm;
   }
 
+  operator std::string() const {
+    std::stringstream ss;
+    ss << *this;
+    return ss.str();
+  }
+
   int axis(int ax) {
     if(ax < 0)
       return size() + ax;
@@ -147,7 +153,9 @@ struct Shape {
     for(auto& s : shapes) {
       for(int i = 0; i < s.size(); ++i) {
         ABORT_IF(shape[-i] != s[-i] && shape[-i] != 1 && s[-i] != 1,
-                 "Shapes cannot be broadcasted");
+                 "Shapes {} and {} cannot be broadcasted",
+                 (std::string)shape,
+                 (std::string)s);
         shape.set(-i, std::max(shape[-i], s[-i]));
       }
     }
@@ -170,10 +178,12 @@ struct Shape {
     shape.resize(maxDims);
 
     for(auto& node : nodes) {
-      Shape shapen = node->shape();
+      const Shape& shapen = node->shape();
       for(int i = 1; i <= shapen.size(); ++i) {
         ABORT_IF(shape[-i] != shapen[-i] && shape[-i] != 1 && shapen[-i] != 1,
-                 "Shapes cannot be broadcasted");
+                 "Shapes {} and {} cannot be broadcasted",
+                 (std::string)shape,
+                 (std::string)shapen);
         shape.set(-i, std::max(shape[-i], shapen[-i]));
       }
     }
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 4c4e0feb..d657ba74 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -121,6 +121,13 @@ Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax) {
   return Expression<ConcatenateNodeOp>(concats, ax);
 }
 
+Expr repeat(Expr a, size_t repeats, keywords::axis_k ax) {
+  if(repeats == 1)
+    return a;
+  return concatenate(std::vector<Expr>(repeats, a), ax);
+}
+
+
 Expr reshape(Expr a, Shape shape) {
   return Expression<ReshapeNodeOp>(a, shape);
 }
@@ -137,6 +144,10 @@ Expr atleast_3d(Expr a, size_t dims) {
   return atleast_nd(a, 3);
 }
 
+Expr atleast_4d(Expr a) {
+  return atleast_nd(a, 4);
+}
+
 Expr atleast_nd(Expr a, size_t dims) {
   if(a->shape().size() >= dims)
     return a;
@@ -154,6 +165,15 @@ Expr flatten(Expr a) {
   return Expression<ReshapeNodeOp>(a, shape);
 }
 
+Expr flatten_2d(Expr a) {
+  Shape shape = {
+    a->shape().elements() / a->shape()[-1],
+    a->shape()[-1]
+  };
+
+  return Expression<ReshapeNodeOp>(a, shape);
+}
+
 Expr rows(Expr a, const std::vector<size_t>& indices) {
   return Expression<RowsNodeOp>(a, indices);
 }
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 37e7c137..7302deab 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -73,15 +73,18 @@ Expr transpose(Expr a);
 Expr transpose(Expr a, const std::vector<int>& axes);
 
 Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax = 0);
+Expr repeat(Expr a, size_t repeats, keywords::axis_k ax = 0);
 
 Expr reshape(Expr a, Shape shape);
 
 Expr atleast_1d(Expr a);
 Expr atleast_2d(Expr a);
 Expr atleast_3d(Expr a);
+Expr atleast_4d(Expr a);
 Expr atleast_nd(Expr a, size_t dims);
 
 Expr flatten(Expr a);
+Expr flatten_2d(Expr a);
 
 Expr rows(Expr a, const std::vector<size_t>& indices);
 Expr cols(Expr a, const std::vector<size_t>& indices);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index faf21dee..8390a0c2 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -39,6 +39,25 @@ public:
   }
 
   const std::string type() { return "scalar_add"; }
+
+  virtual size_t hash() {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      boost::hash_combine(hash_, scalar_);
+    }
+    return hash_;
+  }
+
+  virtual bool equal(Expr node) {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(scalar_ != cnode->scalar_)
+      return false;
+    return true;
+  }
 };
 
 struct ScalarMultNodeOp : public UnaryNodeOp {
@@ -61,6 +80,25 @@ public:
   }
 
   const std::string type() { return "scalar_add"; }
+
+  virtual size_t hash() {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      boost::hash_combine(hash_, scalar_);
+    }
+    return hash_;
+  }
+
+  virtual bool equal(Expr node) {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<ScalarMultNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(scalar_ != cnode->scalar_)
+      return false;
+    return true;
+  }
 };
 
 struct LogitNodeOp : public UnaryNodeOp {
@@ -256,6 +294,25 @@ struct PReLUNodeOp : public UnaryNodeOp {
 
   const std::string type() { return "PReLU"; }
 
+  virtual size_t hash() {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      boost::hash_combine(hash_, alpha_);
+    }
+    return hash_;
+  }
+
+  virtual bool equal(Expr node) {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<PReLUNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(alpha_ != cnode->alpha_)
+      return false;
+    return true;
+  }
+
 private:
   float alpha_{0.01};
 };
@@ -546,8 +603,6 @@ struct SqrtNodeOp : public UnaryNodeOp {
 };
 
 struct SquareNodeOp : public UnaryNodeOp {
-  float epsilon_;
-
   template <typename... Args>
   SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
 
@@ -586,16 +641,16 @@ struct RowsNodeOp : public UnaryNodeOp {
   template <typename... Args>
   RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
       : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
-        indeces_(indeces) {}
+        indices_(indeces) {}
 
   NodeOps forwardOps() {
     // @TODO: solve this with a tensor!
-    return {NodeOp(CopyRows(val_, child(0)->val(), indeces_))};
+    return {NodeOp(CopyRows(val_, child(0)->val(), indices_))};
   }
 
   NodeOps backwardOps() {
-    return {NodeOp(PasteRows(child(0)->grad(), adj_, indeces_))};
+    return {NodeOp(PasteRows(child(0)->grad(), adj_, indices_))};
   }
 
   template <class... Args>
@@ -614,7 +669,7 @@ struct RowsNodeOp : public UnaryNodeOp {
   virtual size_t hash() {
     if(!hash_) {
       size_t seed = NaryNodeOp::hash();
-      for(auto i : indeces_)
+      for(auto i : indices_)
         boost::hash_combine(seed, i);
       hash_ = seed;
     }
@@ -627,28 +682,28 @@ struct RowsNodeOp : public UnaryNodeOp {
     Ptr<RowsNodeOp> cnode = std::dynamic_pointer_cast<RowsNodeOp>(node);
     if(!cnode)
       return false;
-    if(indeces_ != cnode->indeces_)
+    if(indices_ != cnode->indices_)
       return false;
     return true;
   }
 
-  std::vector<size_t> indeces_;
+  std::vector<size_t> indices_;
 };
 
 struct ColsNodeOp : public UnaryNodeOp {
   template <typename... Args>
   ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
       : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
-        indeces_(indeces) {}
+        indices_(indeces) {}
 
   NodeOps forwardOps() {
     // @TODO: solve this with a tensor!
-    return {NodeOp(CopyCols(val_, child(0)->val(), indeces_))};
+    return {NodeOp(CopyCols(val_, child(0)->val(), indices_))};
   }
 
   NodeOps backwardOps() {
-    return {NodeOp(PasteCols(child(0)->grad(), adj_, indeces_))};
+    return {NodeOp(PasteCols(child(0)->grad(), adj_, indices_))};
   }
 
   template <class... Args>
@@ -665,7 +720,7 @@ struct ColsNodeOp : public UnaryNodeOp {
   virtual size_t hash() {
     if(!hash_) {
       size_t seed = NaryNodeOp::hash();
-      for(auto i : indeces_)
+      for(auto i : indices_)
         boost::hash_combine(seed, i);
       hash_ = seed;
     }
@@ -678,27 +733,27 @@ struct ColsNodeOp : public UnaryNodeOp {
     Ptr<ColsNodeOp> cnode = std::dynamic_pointer_cast<ColsNodeOp>(node);
     if(!cnode)
       return false;
-    if(indeces_ != cnode->indeces_)
+    if(indices_ != cnode->indices_)
       return false;
     return true;
   }
 
-  std::vector<size_t> indeces_;
+  std::vector<size_t> indices_;
 };
 
 struct SelectNodeOp : public UnaryNodeOp {
   SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
       : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
-        indeces_(indeces) {}
+        indices_(indeces) {}
 
   NodeOps forwardOps() {
     return {NodeOp(
-        Select(graph()->allocator(), val_, child(0)->val(), axis_, indeces_))};
+        Select(graph()->allocator(), val_, child(0)->val(), axis_, indices_))};
   }
 
   NodeOps backwardOps() {
     return {NodeOp(
-        Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indeces_))};
+        Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indices_))};
   }
 
   Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
@@ -716,7 +771,7 @@ struct SelectNodeOp : public UnaryNodeOp {
     if(!hash_) {
       size_t seed = NaryNodeOp::hash();
       boost::hash_combine(seed, axis_);
-      for(auto i : indeces_)
+      for(auto i : indices_)
         boost::hash_combine(seed, i);
       hash_ = seed;
     }
@@ -731,12 +786,12 @@ struct SelectNodeOp : public UnaryNodeOp {
       return false;
     if(axis_ != cnode->axis_)
       return false;
-    if(indeces_ != cnode->indeces_)
+    if(indices_ != cnode->indices_)
      return false;
     return true;
   }
 
-  std::vector<size_t> indeces_;
+  std::vector<size_t> indices_;
   int axis_{0};
 };
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 3c017cb3..4088223a 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -1,4 +1,4 @@
-#include <thrust/transform_reduce.h>
+#include <thrust/transform_reduce.h>
 
 #include "kernels/cuda_helpers.h"
 #include "kernels/tensor_operators.h"
@@ -34,7 +34,6 @@ bool IsNan(Tensor in) {
 
 void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
   cudaSetDevice(out->getDevice());
-
   int step = 1;
   for(int i = 0; i < axis; ++i)
     step *= out->shape()[i];
@@ -1309,21 +1308,20 @@ __global__ void gAtt(float* out,
                      const float* va,
                      const float* ctx,
                      const float* state,
-                     const float* cov,
                      int m,  // total rows (batch x time x beam)
                      int k,  // depth
                      int b,  // batch size
-                     int t   // time of ctx
+                     int t  // time of ctx
                      ) {
   int rows = m;
   int cols = k;
-  for(int bid = 0; bid < m; bid += gridDim.x) {
+
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
    int j = bid + blockIdx.x;
     if(j < rows) {
       const float* vaRow = va;
       const float* ctxRow = ctx + (j % (b * t)) * cols;
-      const float* stateRow = state + (j / (b * t) + j % b) * cols;
-      const float* covRow = cov ? cov + (j % (b * t)) * cols : nullptr;
+      const float* stateRow = state + ((j / (b * t)) * b + j % b) * cols;
 
       extern __shared__ float _share[];
       float* _sum = _share + blockDim.x;
@@ -1333,8 +1331,6 @@ __global__ void gAtt(float* out,
         int id = tid + threadIdx.x;
         if(id < cols) {
           float z = ctxRow[id] + stateRow[id];
-          if(cov)
-            z += covRow[id];
           float ex = tanhf(z) * vaRow[id];
           _sum[threadIdx.x] += ex;
         }
@@ -1354,7 +1350,7 @@ __global__ void gAtt(float* out,
   }
 }
 
-void Att(Tensor out, Tensor va, Tensor context, Tensor state, Tensor coverage) {
+void Att(Tensor out, Tensor va, Tensor context, Tensor state) {
   cudaSetDevice(out->getDevice());
 
   size_t m = out->shape().elements() / out->shape().back();
@@ -1372,7 +1368,6 @@ void Att(Tensor out, Tensor va, Tensor context, Tensor state, Tensor coverage) {
                          va->data(),
                          context->data(),
                          state->data(),
-                         coverage ? coverage->data() : nullptr,
                          m,
                          k,
                          b,
@@ -1382,11 +1377,9 @@ void Att(Tensor out, Tensor va, Tensor context, Tensor state, Tensor coverage) {
 
 __global__ void gAttBack(float* gVa,
                          float* gContext,
                          float* gState,
-                         float* gCoverage,
                          const float* va,
                          const float* context,
                          const float* state,
-                         const float* coverage,
                          const float* adj,
                          int m,  // rows
                          int k,  // cols
@@ -1399,26 +1392,20 @@ __global__ void gAttBack(float* gVa,
     if(j < rows) {
       float* gcRow = gContext + j * cols;
       float* gsRow = gState + (j % n) * cols;
-      float* gcovRow = gCoverage ? gCoverage + j * cols : nullptr;
 
       const float* cRow = context + j * cols;
       const float* sRow = state + (j % n) * cols;
-      const float* covRow = coverage ? coverage + j * cols : nullptr;
 
       for(int tid = 0; tid < cols; tid += blockDim.x) {
         int id = tid + threadIdx.x;
         if(id < cols) {
           float z = cRow[id] + sRow[id];
-          if(coverage)
-            z += covRow[id];
 
           float t = tanhf(z);
           float r = va[id] * (1.f - t * t);
 
           gcRow[id] += r * adj[j];
           gsRow[id] += r * adj[j];
-          if(gCoverage)
-            gcovRow[id] += r * adj[j];
           atomicAdd(gVa + id, t * adj[j]);
         }
       }
@@ -1429,11 +1416,9 @@ __global__ void gAttBack(float* gVa,
 void AttBack(Tensor gVa,
              Tensor gContext,
              Tensor gState,
-             Tensor gCoverage,
              Tensor va,
              Tensor context,
             Tensor state,
-            Tensor coverage,
             Tensor adj) {
   cudaSetDevice(adj->getDevice());
 
@@ -1449,12 +1434,10 @@ void AttBack(Tensor gVa,
   gAttBack<<<blocks, threads>>>(gVa->data(),
                                 gContext->data(),
                                 gState->data(),
-                                gCoverage ? gCoverage->data() : nullptr,
                                 va->data(),
                                 context->data(),
                                 state->data(),
-                                coverage ? coverage->data() : nullptr,
                                 adj->data(),
                                 m,
diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h
index 06cb188c..d317ec34 100644
--- a/src/kernels/tensor_operators.h
+++ b/src/kernels/tensor_operators.h
@@ -339,15 +339,13 @@ void GRUFastBackward(std::vector<Tensor> outputs,
                      Tensor adj,
                      bool final = false);
 
-void Att(Tensor out, Tensor va, Tensor context, Tensor state, Tensor coverage);
+void Att(Tensor out, Tensor va, Tensor context, Tensor state);
 
 void AttBack(Tensor gva,
              Tensor gContext,
              Tensor gState,
-             Tensor gCoverage,
              Tensor va,
              Tensor context,
              Tensor state,
-             Tensor coverage,
              Tensor adj);
 
 void LayerNormalization(Tensor out,
diff --git a/src/models/encdec.h b/src/models/encdec.h
index 8447a3a8..cfe304ae 100644
--- a/src/models/encdec.h
+++ b/src/models/encdec.h
@@ -124,15 +124,23 @@ public:
 
   virtual void selectEmbeddings(Ptr<ExpressionGraph> graph,
                                 Ptr<DecoderState> state,
-                                const std::vector<size_t>& embIdx) {
+                                const std::vector<size_t>& embIdx,
+                                int beamSize) {
     using namespace keywords;
 
     int dimTrgEmb = opt<int>("dim-emb");
     int dimTrgVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
 
+    int dimBatch = 1;
+    if(state->getEncoderStates().size() > 0)
+      dimBatch = state->getEncoderStates()[0]->getContext()->shape()[-2];
+
+    int dimBeam = embIdx.size() / dimBatch;
+
     Expr selectedEmbs;
     if(embIdx.empty()) {
-      selectedEmbs = graph->constant({1, 1, 1, dimTrgEmb}, init = inits::zeros);
+      selectedEmbs = graph->constant({1, 1, dimBatch, dimTrgEmb},
+                                     init = inits::zeros);
     } else {
       // embeddings are loaded from model during translation, no fixing required
       auto yEmbFactory = embedding(graph)  //
@@ -148,7 +156,7 @@ public:
       selectedEmbs = rows(yEmb, embIdx);
 
       selectedEmbs
-          = reshape(selectedEmbs, {(int)embIdx.size(), 1, 1, dimTrgEmb});
+          = reshape(selectedEmbs, {dimBeam, 1, dimBatch, dimTrgEmb});
     }
     state->setTargetEmbeddings(selectedEmbs);
   }
@@ -167,13 +175,15 @@ class EncoderDecoderBase : public models::ModelBase {
 public:
   virtual void selectEmbeddings(Ptr<ExpressionGraph> graph,
                                 Ptr<DecoderState> state,
-                                const std::vector<size_t>&)
+                                const std::vector<size_t>&,
+                                int beamSize)
       = 0;
 
   virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
                                  Ptr<DecoderState>,
                                  const std::vector<size_t>&,
-                                 const std::vector<size_t>&)
+                                 const std::vector<size_t>&,
+                                 int beamSize)
       = 0;
 
   virtual Ptr<DecoderState> step(Ptr<ExpressionGraph>, Ptr<DecoderState>) = 0;
@@ -300,9 +310,10 @@ public:
   virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
                                  Ptr<DecoderState> state,
                                  const std::vector<size_t>& hypIndices,
-                                 const std::vector<size_t>& embIndices) {
-    auto selectedState = hypIndices.empty() ? state : state->select(hypIndices);
-    selectEmbeddings(graph, selectedState, embIndices);
+                                 const std::vector<size_t>& embIndices,
+                                 int beamSize) {
+    auto selectedState = hypIndices.empty() ? state : state->select(hypIndices, beamSize);
+    selectEmbeddings(graph, selectedState, embIndices, beamSize);
     selectedState->setSingleStep(true);
     auto nextState = step(graph, selectedState);
     nextState->setProbs(logsoftmax(nextState->getProbs()));
@@ -311,8 +322,9 @@ public:
 
   virtual void selectEmbeddings(Ptr<ExpressionGraph> graph,
                                 Ptr<DecoderState> state,
-                                const std::vector<size_t>& embIdx) {
-    decoders_[0]->selectEmbeddings(graph, state, embIdx);
+                                const std::vector<size_t>& embIdx,
+                                int beamSize) {
+    decoders_[0]->selectEmbeddings(graph, state, embIdx, beamSize);
   }
 
   virtual Expr build(Ptr<ExpressionGraph> graph,
diff --git a/src/models/hardatt.h b/src/models/hardatt.h
index d8ca8961..e6383a43 100644
--- a/src/models/hardatt.h
+++ b/src/models/hardatt.h
@@ -16,13 +16,13 @@ public:
       : DecoderState(states, probs, encStates),
         attentionIndices_(attentionIndices) {}
 
-  virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx) {
+  virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx, int beamSize) {
     std::vector<size_t> selectedAttentionIndices;
     for(auto i : selIdx)
       selectedAttentionIndices.push_back(attentionIndices_[i]);
 
     return New<DecoderStateHardAtt>(
-        states_.select(selIdx), probs_, encStates_, selectedAttentionIndices);
+        states_.select(selIdx, beamSize), probs_, encStates_, selectedAttentionIndices);
   }
 
   virtual void setAttentionIndices(
@@ -259,8 +259,9 @@ public:
 
   virtual void selectEmbeddings(Ptr<ExpressionGraph> graph,
                                 Ptr<DecoderState> state,
-                                const std::vector<size_t>& embIdx) {
-    DecoderBase::selectEmbeddings(graph, state, embIdx);
+                                const std::vector<size_t>& embIdx,
+                                int beamSize) {
+    DecoderBase::selectEmbeddings(graph, state, embIdx, beamSize);
 
     auto stateHardAtt = std::dynamic_pointer_cast<DecoderStateHardAtt>(state);
diff --git a/src/models/s2s.h b/src/models/s2s.h
index 8da73bdd..ba738793 100644
--- a/src/models/s2s.h
+++ b/src/models/s2s.h
@@ -299,6 +299,8 @@ public:
 
     // apply RNN to embeddings, initialized with encoder context mapped into
     // decoder space
+
+    auto states = state->getStates();
     auto decoderContext = rnn_->transduce(embeddings, state->getStates());
 
     // retrieve the last state per layer. They are required during translation
diff --git a/src/models/states.h b/src/models/states.h
index e7cef04c..d3ad2dfa 100644
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -50,8 +50,8 @@ public:
   virtual Expr getProbs() { return probs_; }
   virtual void setProbs(Expr probs) { probs_ = probs; }
 
-  virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx) {
-    return New<DecoderState>(states_.select(selIdx), probs_, encStates_);
+  virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx, int beamSize) {
+    return New<DecoderState>(states_.select(selIdx, beamSize), probs_, encStates_);
   }
 
   virtual const rnn::States& getStates() { return states_; }
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 0c718db9..92030c83 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -50,7 +50,7 @@ public:
     // convert 0/1 mask to transformer style -inf mask
     auto ms = mask->shape();
     mask = (1 - mask) * -99999999.f;
-    return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]});
+    return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]}) ;
   }
 
   Expr SplitHeads(Expr input, int dimHeads) {
@@ -171,9 +171,10 @@ public:
     // @TODO: do this better
     int dimBeamQ = q->shape()[-4];
     int dimBeamK = k->shape()[-4];
-    if(dimBeamQ != dimBeamK) {
-      k = concatenate(std::vector<Expr>(dimBeamQ, k), axis = -4);
-      v = concatenate(std::vector<Expr>(dimBeamQ, v), axis = -4);
+    int dimBeam = dimBeamQ / dimBeamK;
+    if(dimBeam > 1) {
+      k = repeat(k, dimBeam, axis = -4);
+      v = repeat(v, dimBeam, axis = -4);
     }
 
     auto weights = softmax(bdot(q, k, false, true, scale) + mask);
@@ -237,6 +238,7 @@ public:
     // apply multi-head attention to downscaled inputs
     auto output
         = Attention(graph, options, prefix, qh, kh, vh, masks[i], inference);
+
     output = JoinHeads(output, q->shape()[-4]);
 
     outputs.push_back(output);
@@ -456,12 +458,23 @@ public:
                    std::vector<Ptr<EncoderState>> &encStates)
       : DecoderState(states, probs, encStates) {}
 
-  virtual Ptr<DecoderState> select(const std::vector<size_t> &selIdx) {
+  virtual Ptr<DecoderState> select(const std::vector<size_t> &selIdx, int beamSize) {
     rnn::States selectedStates;
 
-    for(auto state : states_)
-      selectedStates.push_back(
-          {marian::select(state.output, -4, selIdx), nullptr});
+    int dimDepth = states_[0].output->shape()[-1];
+    int dimTime = states_[0].output->shape()[-2];
+
+    int dimBatch = selIdx.size() / beamSize;
+
+    std::vector<size_t> selIdx2;
+    for(auto i : selIdx)
+      for(int j = 0; j < dimTime; ++j)
+        selIdx2.push_back(i * dimTime + j);
+
+    for(auto state : states_) {
+      auto sel = rows(flatten_2d(state.output), selIdx2);
+      sel = reshape(sel, {beamSize, dimBatch, dimTime, dimDepth});
+      selectedStates.push_back({sel, nullptr});
+    }
 
     return New<TransformerState>(selectedStates, probs_, encStates_);
   }
@@ -497,6 +510,9 @@ public:
     //************************************************************************//
 
     int dimEmb = embeddings->shape()[-1];
+    int dimBeam = 1;
+    if(embeddings->shape().size() > 3)
+      dimBeam = embeddings->shape()[-4];
 
     // according to paper embeddings are scaled by \sqrt(d_m)
     auto scaledEmbeddings = std::sqrt(dimEmb) * embeddings;
@@ -528,6 +544,8 @@ public:
       decoderMask = reshape(TransposeTimeBatch(decoderMask),
                             {1, dimBatch, 1, dimTrgWords});
       selfMask = selfMask * decoderMask;
+      //if(dimBeam > 1)
+      //  selfMask = repeat(selfMask, dimBeam, axis = -4);
     }
 
     selfMask = InverseMask(selfMask);
@@ -548,6 +566,8 @@ public:
         encoderMask = reshape(TransposeTimeBatch(encoderMask),
                               {1, dimBatch, 1, dimSrcWords});
         encoderMask = InverseMask(encoderMask);
+        if(dimBeam > 1)
+          encoderMask = repeat(encoderMask, dimBeam, axis = -4);
 
         encoderContexts.push_back(encoderContext);
         encoderMasks.push_back(encoderMask);
@@ -557,8 +577,7 @@ public:
     for(int i = 1; i <= opt<int>("dec-depth"); ++i) {
       auto values = query;
       if(prevDecoderStates.size() > 0)
-        values
-            = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
+        values = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
 
       decoderStates.push_back({values, nullptr});
diff --git a/src/rnn/attention.cu b/src/rnn/attention.cu
index 752692bc..232b99f3 100644
--- a/src/rnn/attention.cu
+++ b/src/rnn/attention.cu
@@ -25,8 +25,7 @@ struct AttentionNodeOp : public NaryNodeOp {
     return {NodeOp(Att(val_,
                        child(0)->val(),
                        child(1)->val(),
-                       child(2)->val(),
-                       children_.size() == 4 ? child(3)->val() : nullptr))};
+                       child(2)->val()))};
   }
 
   NodeOps backwardOps() {
@@ -34,11 +33,9 @@ struct AttentionNodeOp : public NaryNodeOp {
         NodeOp(AttBack(child(0)->grad(),
                        child(1)->grad(),
                        child(2)->grad(),
-                       children_.size() == 4 ? child(3)->grad() : nullptr,
                        child(0)->val(),
                        child(1)->val(),
                        child(2)->val(),
-                       children_.size() == 4 ? child(3)->val() : nullptr,
                        adj_);)
     };
   }
@@ -54,10 +51,8 @@ struct AttentionNodeOp : public NaryNodeOp {
   const std::string color() { return "yellow"; }
 };
 
-Expr attOps(Expr va, Expr context, Expr state, Expr coverage) {
+Expr attOps(Expr va, Expr context, Expr state) {
   std::vector<Expr> nodes{va, context, state};
-  if(coverage)
-    nodes.push_back(coverage);
 
   int dimBatch = context->shape()[-2];
   int dimWords = context->shape()[-3];
diff --git a/src/rnn/attention.h b/src/rnn/attention.h
index 6d182920..751d3a3c 100644
--- a/src/rnn/attention.h
+++ b/src/rnn/attention.h
@@ -11,7 +11,7 @@ namespace marian {
 
 namespace rnn {
 
-Expr attOps(Expr va, Expr context, Expr state, Expr coverage = nullptr);
+Expr attOps(Expr va, Expr context, Expr state);
 
 class GlobalAttention : public CellInput {
 private:
@@ -147,7 +147,7 @@ public:
 
     auto alignedSource
         = scalar_product(encState_->getAttended(), e, axis = -3);
-    
+
     contexts_.push_back(alignedSource);
     alignments_.push_back(e);
     return alignedSource;
diff --git a/src/rnn/types.h b/src/rnn/types.h
index 7358685b..17fa9bab 100644
--- a/src/rnn/types.h
+++ b/src/rnn/types.h
@@ -14,24 +14,26 @@ struct State {
   Expr output;
   Expr cell;
 
-  State select(const std::vector<size_t>& indices) {
-    if(output->shape().size() < 4) {
-      int dimState = output->shape()[-1];
-      int dimBatch = output->shape()[-2];
-      int dimTime = output->shape()[-3];
-
-      output = reshape(output, {1, dimTime, dimBatch, dimState});
-      if(cell)
-        cell = reshape(cell, {1, dimTime, dimBatch, dimState});
-    }
+  State select(const std::vector<size_t>& indices, int beamSize) {
+    output = atleast_4d(output);
+    if(cell)
+      cell = atleast_4d(cell);
+
+    int dimDepth = output->shape()[-1];
+    int dimTime = output->shape()[-3];
+
+    int dimBatch = indices.size() / beamSize;
 
     if(cell) {
       return State{
-          marian::select(output, 0, indices),
-          marian::select(cell, 0, indices)};
+          reshape(rows(flatten_2d(output), indices),
+                  {beamSize, dimTime, dimBatch, dimDepth}),
+          reshape(rows(flatten_2d(cell), indices),
+                  {beamSize, dimTime, dimBatch, dimDepth})};
     } else {
       return State{
-          marian::select(output, 0, indices),
+          reshape(rows(flatten_2d(output), indices),
+                  {beamSize, dimTime, dimBatch, dimDepth}),
           nullptr};
     }
   }
@@ -72,10 +74,10 @@ public:
 
   void push_back(const State& state) { states_.push_back(state); }
 
-  States select(const std::vector<size_t>& indices) {
+  States select(const std::vector<size_t>& indices, int beamSize) {
     States selected;
     for(auto& state : states_)
-      selected.push_back(state.select(indices));
+      selected.push_back(state.select(indices, beamSize));
     return selected;
   }
diff --git a/src/tensors/allocator.h b/src/tensors/allocator.h
index 8a84e4d7..72d4805a 100644
--- a/src/tensors/allocator.h
+++ b/src/tensors/allocator.h
@@ -179,7 +179,6 @@ public:
     auto ptr = gap.data();
     auto mp = New<MemoryPiece>(ptr, bytes);
     allocated_[ptr] = mp;
-
     return mp;
   }
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index 465d5478..f962594c 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -1,17 +1,19 @@
 # Unit tests
-set(TEST_SOURCES
-  graph_tests.cpp
-  operator_tests.cpp
-  rnn_tests.cpp
-  attention_tests.cpp
+set(UNIT_TESTS
+  graph_tests
+  operator_tests
+  rnn_tests
+  attention_tests
 )
 
-add_executable(run_tests run_tests.cpp ${TEST_SOURCES})
-target_link_libraries(run_tests marian ${EXT_LIBS} Catch)
-cuda_add_cublas_to_target(run_tests)
-set_target_properties(run_tests PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
+foreach(test ${UNIT_TESTS})
+  add_executable("run_${test}" run_tests.cpp "${test}.cpp")
+  target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
+  cuda_add_cublas_to_target("run_${test}")
+
+  add_test(NAME ${test} COMMAND "run_${test}")
+endforeach(test)
 
-add_test(NAME GraphTest COMMAND run_tests)
 
 # Testing apps
 add_executable(logger_test logger_test.cpp)
diff --git a/src/tests/README.md b/src/tests/README.md
index 31750313..77b23a6b 100644
--- a/src/tests/README.md
+++ b/src/tests/README.md
@@ -2,7 +2,19 @@ Marian tests
 ============
 
 Unit tests and application tests are enabled with CMake option
-`-DCOMPILE_TESTS=ON`.
+`-DCOMPILE_TESTS=ON`, e.g.:
+
+    cd build
+    cmake .. -DCOMPILE_TESTS=ON
+    make -j8
+
+Running all unit tests:
+
+    make test
+
+Running a single unit test is also possible:
+
+    ./src/tests/run_graph_tests
 
 We use [Catch framework](https://github.com/philsquared/Catch) for unit testing.
```
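The test changes that follow loosen exact float comparisons to `Approx(y).epsilon(0.01)`, which tolerates roughly 1% relative error instead of requiring bit-exact results. For readers unfamiliar with Catch, here is a hypothetical, self-contained test showing the pattern; it is not part of the marian sources:

```cpp
// Hypothetical Catch test illustrating the Approx(..).epsilon(..) pattern
// adopted by the attention/RNN test fixes below.
#define CATCH_CONFIG_MAIN
#include "catch.hpp"

TEST_CASE("Loose float comparison", "[example]") {
  auto floatApprox = [](float x, float y) {
    return x == Approx(y).epsilon(0.01);  // allow ~1% relative error
  };
  REQUIRE(floatApprox(1.004f, 1.0f));    // small numeric noise: accepted
  REQUIRE(!floatApprox(1.05f, 1.0f));    // 5% off: rejected
}
```

Loosening the tolerance is what makes these tests stable under the batched kernels, whose floating-point summation order differs from the single-sentence path.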
```diff
diff --git a/src/tests/attention_tests.cpp b/src/tests/attention_tests.cpp
index 722b82a7..2bf5e1f2 100644
--- a/src/tests/attention_tests.cpp
+++ b/src/tests/attention_tests.cpp
@@ -5,7 +5,7 @@ using namespace marian;
 
 TEST_CASE("Model components, Attention", "[attention]") {
 
-  auto floatApprox = [](float x, float y) { return x == Approx(y); };
+  auto floatApprox = [](float x, float y) { return x == Approx(y).epsilon(0.01); };
 
   std::vector<size_t> vWords = {
     43,  2, 83, 78,
diff --git a/src/tests/rnn_tests.cpp b/src/tests/rnn_tests.cpp
index 1827a9bf..54e0e7b1 100644
--- a/src/tests/rnn_tests.cpp
+++ b/src/tests/rnn_tests.cpp
@@ -5,7 +5,7 @@ using namespace marian;
 
 TEST_CASE("Model components, RNN etc.", "[model]") {
 
-  auto floatApprox = [](float x, float y) { return x == Approx(y); };
+  auto floatApprox = [](float x, float y) { return x == Approx(y).epsilon(0.01); };
 
   std::vector<size_t> vWords = {
     43,  2, 83, 78,
diff --git a/src/training/validator.h b/src/training/validator.h
index 1c6233e9..fc42c390 100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -58,6 +58,7 @@ public:
       opts->set("max-length", options_->get<size_t>("valid-max-length"));
     if(options_->has("valid-mini-batch"))
       opts->set("mini-batch", options_->get<size_t>("valid-mini-batch"));
+    opts->set("mini-batch-sort", "src");
 
     // Create corpus
     auto validPaths = options_->get<std::vector<std::string>>("valid-sets");
@@ -224,8 +225,8 @@ public:
 
     // Temporary options for translation
     auto opts = New<Config>(*options_);
-    opts->set("mini-batch", 1);
-    opts->set("maxi-batch", 1);
+    //opts->set("mini-batch", 1);
+    //opts->set("maxi-batch", 1);
     opts->set("max-length", 1000);
 
     // Create corpus
@@ -298,15 +299,17 @@ public:
         }
 
         auto search = New<BeamSearch>(options_, std::vector<Ptr<Scorer>>{scorer});
-        auto history = search->search(graph, batch, id);
-
-        std::stringstream best1;
-        std::stringstream bestn;
-        Printer(options_, vocabs_.back(), history, best1, bestn);
-        collector->Write(history->GetLineNum(),
-                         best1.str(),
-                         bestn.str(),
-                         options_->get<bool>("n-best"));
+        auto histories = search->search(graph, batch);
+
+        for(auto history : histories) {
+          std::stringstream best1;
+          std::stringstream bestn;
+          Printer(options_, vocabs_.back(), history, best1, bestn);
+          collector->Write(history->GetLineNum(),
+                           best1.str(),
+                           bestn.str(),
+                           options_->get<bool>("n-best"));
+        }
       };
 
       threadPool.enqueue(task, sentenceId);
diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h
index eb25d98f..178e9238 100644
--- a/src/translator/beam_search.h
+++ b/src/translator/beam_search.h
@@ -1,4 +1,5 @@
 #pragma once
+#include <algorithm>
 
 #include "marian.h"
 #include "translator/history.h"
@@ -26,51 +27,92 @@ public:
             ? options_->get<size_t>("beam-size")
             : 3) {}
 
-  Beam toHyps(const std::vector<uint> keys,
-              const std::vector<float> costs,
-              size_t vocabSize,
-              const Beam& beam,
-              std::vector<Ptr<ScorerState>>& states) {
-    Beam newBeam;
-    for(int i = 0; i < keys.size(); ++i) {
-      int embIdx = keys[i] % vocabSize;
-      int hypIdx = keys[i] / vocabSize;
-      float cost = costs[i];
-
-      std::vector<float> breakDown(states.size(), 0);
-      beam[hypIdx]->GetCostBreakdown().resize(states.size(), 0);
-
-      for(int j = 0; j < states.size(); ++j)
-        breakDown[j] = states[j]->breakDown(keys[i])
-                       + beam[hypIdx]->GetCostBreakdown()[j];
-
-      auto hyp = New<Hypothesis>(beam[hypIdx], embIdx, hypIdx, cost);
-      hyp->GetCostBreakdown() = breakDown;
-      newBeam.push_back(hyp);
+  Beams toHyps(const std::vector<uint> keys,
+               const std::vector<float> costs,
+               size_t vocabSize,
+               const Beams& beams,
+               std::vector<Ptr<ScorerState>>& states,
+               size_t beamSize,
+               bool first) {
+
+    Beams newBeams(beams.size());
+    for(int i = 0; i < keys.size(); ++i) {
+      int embIdx = keys[i] % vocabSize;
+      int beamIdx = i / beamSize;
+
+      if(newBeams[beamIdx].size() < beams[beamIdx].size()) {
+        auto& beam = beams[beamIdx];
+        auto& newBeam = newBeams[beamIdx];
+
+        int hypIdx = keys[i] / vocabSize;
+        float cost = costs[i];
+
+        int hypIdxTrans = (hypIdx / beamSize)
+                          + (hypIdx % beamSize) * beams.size();
+        if(first)
+          hypIdxTrans = hypIdx;
+
+        int beamHypIdx = hypIdx % beamSize;
+        if(beamHypIdx >= beam.size())
+          beamHypIdx = beamHypIdx % beam.size();
+
+        if(first)
+          beamHypIdx = 0;
+
+        auto hyp = New<Hypothesis>(beam[beamHypIdx], embIdx, hypIdxTrans, cost);
+
+        if(options_->get<bool>("n-best")) {
+          std::vector<float> breakDown(states.size(), 0);
+          beam[beamHypIdx]->GetCostBreakdown().resize(states.size(), 0);
+          for(int j = 0; j < states.size(); ++j) {
+            int key = embIdx + hypIdxTrans * vocabSize;
+            breakDown[j] = states[j]->breakDown(key)
+                           + beam[beamHypIdx]->GetCostBreakdown()[j];
+          }
+          hyp->GetCostBreakdown() = breakDown;
+        }
+        newBeam.push_back(hyp);
+      }
     }
-    return newBeam;
+    return newBeams;
   }
 
-  Beam pruneBeam(const Beam& beam) {
-    Beam newBeam;
-    for(auto hyp : beam) {
-      if(hyp->GetWord() > 0) {
-        newBeam.push_back(hyp);
+  Beams pruneBeam(const Beams& beams) {
+    Beams newBeams;
+    for(auto beam: beams) {
+      Beam newBeam;
+      for(auto hyp : beam) {
+        if(hyp->GetWord() > 0) {
+          newBeam.push_back(hyp);
+        }
       }
+      newBeams.push_back(newBeam);
     }
-    return newBeam;
+    return newBeams;
   }
 
-  Ptr<History> search(Ptr<ExpressionGraph> graph,
-                      Ptr<data::CorpusBatch> batch,
-                      size_t sentenceId = 0) {
-    auto history = New<History>(sentenceId, options_->get<float>("normalize"));
-    Beam beam(1, New<Hypothesis>());
+  Histories search(Ptr<ExpressionGraph> graph,
+                   Ptr<data::CorpusBatch> batch) {
+
+    int dimBatch = batch->size();
+    Histories histories;
+    for(int i = 0; i < dimBatch; ++i) {
+      size_t sentId = batch->getSentenceIds()[i];
+      auto history = New<History>(sentId, options_->get<float>("normalize"));
+      histories.push_back(history);
+    }
+
+    size_t localBeamSize = beamSize_;
+    auto nth = New<NthElement>(localBeamSize, dimBatch);
+
+    Beams beams(dimBatch);
+    for(auto& beam : beams)
+      beam.resize(localBeamSize, New<Hypothesis>());
+
     bool first = true;
     bool final = false;
-    std::vector<size_t> beamSizes(1, beamSize_);
-    auto nth = New<NthElement>(beamSize_, batch->size());
-    history->Add(beam);
+
+    for(int i = 0; i < dimBatch; ++i)
+      histories[i]->Add(beams[i]);
 
     std::vector<Ptr<ScorerState>> states;
 
@@ -94,13 +136,28 @@ public:
                                  keywords::init = inits::from_value(0));
       } else {
         std::vector<float> beamCosts;
-        for(auto hyp : beam) {
-          hypIndices.push_back(hyp->GetPrevStateIndex());
-          embIndices.push_back(hyp->GetWord());
-          beamCosts.push_back(hyp->GetCost());
+
+        int dimBatch = batch->size();
+
+        for(int i = 0; i < localBeamSize; ++i) {
+          for(int j = 0; j < beams.size(); ++j) {
+            auto& beam = beams[j];
+            if(i < beam.size()) {
+              auto hyp = beam[i];
+              hypIndices.push_back(hyp->GetPrevStateIndex());
+              embIndices.push_back(hyp->GetWord());
+              beamCosts.push_back(hyp->GetCost());
+            }
+            else {
+              hypIndices.push_back(0);
+              embIndices.push_back(0);
+              beamCosts.push_back(-9999);
+            }
+          }
         }
+
         prevCosts
-            = graph->constant({(int)beamCosts.size(), 1, 1, 1},
+            = graph->constant({(int)localBeamSize, 1, dimBatch, 1},
                               keywords::init = inits::from_vector(beamCosts));
       }
 
@@ -109,11 +166,18 @@ public:
       auto totalCosts = prevCosts;
 
       for(int i = 0; i < scorers_.size(); ++i) {
-        states[i] = scorers_[i]->step(graph, states[i], hypIndices, embIndices);
-        totalCosts
-            = totalCosts + scorers_[i]->getWeight() * states[i]->getProbs();
+        states[i] = scorers_[i]->step(graph, states[i], hypIndices, embIndices, localBeamSize);
+
+        if(scorers_[i]->getWeight() != 1.f)
+          totalCosts = totalCosts + scorers_[i]->getWeight() * states[i]->getProbs();
+        else
+          totalCosts = totalCosts + states[i]->getProbs();
       }
 
+      // make beams continuous
+      if(dimBatch > 1 && localBeamSize > 1)
+        totalCosts = transpose(totalCosts, {2, 1, 0, 3});
+
       if(first)
         graph->forward();
       else
@@ -131,21 +195,33 @@ public:
       std::vector<unsigned> outKeys;
       std::vector<float> outCosts;
 
-      beamSizes[0] = first ? beamSize_ : beam.size();
+      std::vector<size_t> beamSizes(dimBatch, localBeamSize);
       nth->getNBestList(beamSizes, totalCosts->val(), outCosts, outKeys, first);
 
       int dimTrgVoc = totalCosts->shape()[-1];
-      beam = toHyps(outKeys, outCosts, dimTrgVoc, beam, states);
-
-      final = history->size() >= 3 * batch->words();
-      history->Add(beam, final);
-      beam = pruneBeam(beam);
+      beams = toHyps(outKeys, outCosts, dimTrgVoc, beams, states, localBeamSize, first);
+
+      auto prunedBeams = pruneBeam(beams);
+      for(int i = 0; i < dimBatch; ++i) {
+        if(!beams[i].empty()) {
+          final = final || histories[i]->size() >= 3 * batch->front()->batchWidth();
+          histories[i]->Add(beams[i], prunedBeams[i].empty() || final);
+        }
+      }
+      beams = prunedBeams;
+
+      if(!first) {
+        size_t maxBeam = 0;
+        for(auto& beam : beams)
+          if(beam.size() > maxBeam)
+            maxBeam = beam.size();
+        localBeamSize = maxBeam;
+      }
       first = false;
 
-    } while(!beam.empty() && !final);
+    } while(localBeamSize != 0 && !final);
 
-    return history;
+    return histories;
   }
 };
 }
diff --git a/src/translator/history.h b/src/translator/history.h
index dbdf18a5..0070b723 100644
--- a/src/translator/history.h
+++ b/src/translator/history.h
@@ -27,6 +27,7 @@ public:
       if(beam[j]->GetWord() == 0 || last) {
         float cost = beam[j]->GetCost() / LengthPenalty(history_.size());
         topHyps_.push({history_.size(), j, cost});
+        //std::cerr << "Add " << history_.size() << " " << j << " " << cost << std::endl;
       }
     }
     history_.push_back(beam);
@@ -43,11 +44,14 @@ public:
 
       size_t start = bestHypCoord.i;
       size_t j = bestHypCoord.j;
+      //float c = bestHypCoord.cost;
+      //std::cerr << "h: " << start << " " << j << " " << c << std::endl;
 
       Words targetWords;
       Ptr<Hypothesis> bestHyp = history_[start][j];
       while(bestHyp->GetPrevHyp() != nullptr) {
         targetWords.push_back(bestHyp->GetWord());
+        //std::cerr << bestHyp->GetWord() << " " << bestHyp << std::endl;
         bestHyp = bestHyp->GetPrevHyp();
       }
 
@@ -70,5 +74,5 @@ private:
   float alpha_;
 };
 
-typedef std::vector<History> Histories;
+typedef std::vector<Ptr<History>> Histories;
 }
diff --git a/src/translator/scorers.h b/src/translator/scorers.h
index 171abccd..e963efb0 100644
--- a/src/translator/scorers.h
+++ b/src/translator/scorers.h
@@ -33,7 +33,8 @@ public:
   virtual Ptr<ScorerState> step(Ptr<ExpressionGraph>,
                                 Ptr<ScorerState>,
                                 const std::vector<size_t>&,
-                                const std::vector<size_t>&)
+                                const std::vector<size_t>&,
+                                int beamSize)
       = 0;
 
   virtual void init(Ptr<ExpressionGraph> graph) {}
@@ -88,12 +89,13 @@ public:
   virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
                                 Ptr<ScorerState> state,
                                 const std::vector<size_t>& hypIndices,
-                                const std::vector<size_t>& embIndices) {
+                                const std::vector<size_t>& embIndices,
+                                int beamSize) {
     graph->switchParams(getName());
     auto wrappedState
         = std::dynamic_pointer_cast<ScorerWrapperState>(state)->getState();
     return New<ScorerWrapperState>(
-        encdec_->step(graph, wrappedState, hypIndices, embIndices));
+        encdec_->step(graph, wrappedState, hypIndices, embIndices, beamSize));
   }
 };
 
@@ -138,7 +140,8 @@ public:
   virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
                                 Ptr<ScorerState> state,
                                 const std::vector<size_t>& hypIndices,
-                                const std::vector<size_t>& embIndices) {
+                                const std::vector<size_t>& embIndices,
+                                int beamSize) {
     return state;
   }
 };
@@ -173,7 +176,8 @@ public:
   virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
                                 Ptr<ScorerState> state,
                                 const std::vector<size_t>& hypIndices,
-                                const std::vector<size_t>& embIndices) {
+                                const std::vector<size_t>& embIndices,
+                                int beamSize) {
     return state;
   }
 };
diff --git a/src/translator/translator.h b/src/translator/translator.h
index bc96803d..598aee2b 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -62,7 +62,7 @@ public:
     auto devices = options_->get<std::vector<int>>("devices");
     ThreadPool threadPool(devices.size(), devices.size());
 
-    size_t sentenceId = 0;
+    size_t batchId = 0;
     auto collector = New<OutputCollector>();
     if(options_->get<bool>("quiet-translation"))
       collector->setPrintingStrategy(New<QuietPrinting>());
@@ -83,18 +83,20 @@ public:
         }
 
         auto search = New<Search>(options_, scorers);
-        auto history = search->search(graph, batch, id);
-
-        std::stringstream best1;
-        std::stringstream bestn;
-        Printer(options_, trgVocab_, history, best1, bestn);
-        collector->Write(history->GetLineNum(),
-                         best1.str(),
-                         bestn.str(),
-                         options_->get<bool>("n-best"));
+        auto histories = search->search(graph, batch);
+
+        for(auto history : histories) {
+          std::stringstream best1;
+          std::stringstream bestn;
+          Printer(options_, trgVocab_, history, best1, bestn);
+          collector->Write(history->GetLineNum(),
+                           best1.str(),
+                           bestn.str(),
+                           options_->get<bool>("n-best"));
+        }
       };
 
-      threadPool.enqueue(task, sentenceId++);
+      threadPool.enqueue(task, batchId++);
     }
   }
 };
@@ -150,7 +152,7 @@ public:
     data::BatchGenerator<data::TextInput> bg(corpus_, options_);
 
     auto collector = New<StringCollector>();
-    size_t sentenceId = 0;
+    size_t batchId = 0;
 
     bg.prepare(false);
 
@@ -171,16 +173,18 @@ public:
         }
 
         auto search = New<Search>(options_, scorers);
-        auto history = search->search(graph, batch, id);
+        auto histories = search->search(graph, batch);
 
-        std::stringstream best1;
-        std::stringstream bestn;
-        Printer(options_, trgVocab_, history, best1, bestn);
-        collector->add(history->GetLineNum(), best1.str(), bestn.str());
+        for(auto history : histories) {
+          std::stringstream best1;
+          std::stringstream bestn;
+          Printer(options_, trgVocab_, history, best1, bestn);
+          collector->add(history->GetLineNum(), best1.str(), bestn.str());
+        }
       };
 
-      threadPool_.enqueue(task, sentenceId);
-      sentenceId++;
+      threadPool_.enqueue(task, batchId);
+      batchId++;
     }
   }
```
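The densest part of the new batched beam search is the index arithmetic in `toHyps`: the n-best keys come from `NthElement` over the transposed (batch-major) cost tensor, while the decoder states stay beam-major, so each hypothesis index has to be transposed before it can address the previous state. Below is a self-contained sketch of that arithmetic with toy sizes; the variable names follow the hunk above, but the surrounding scaffolding is illustrative, not marian code:

```cpp
#include <cstdio>

int main() {
  // After transpose(totalCosts, {2, 1, 0, 3}) the cost tensor is laid out
  // batch-major as [batch][beam][vocab]; NthElement returns beamSize flat
  // keys per sentence into that layout.
  const int vocabSize = 6;  // toy numbers
  const int beamSize = 2;
  const int numBeams = 3;   // == dimBatch, one beam per sentence

  unsigned key = 3 * vocabSize + 4;  // a hit for sentence 1: hypothesis 3, word 4
  int i = 2;                         // its position in the flat n-best output

  int embIdx = key % vocabSize;   // word id within the vocabulary
  int hypIdx = key / vocabSize;   // hypothesis index, batch-major
  int beamIdx = i / beamSize;     // which sentence's beam this key extends

  // Decoder states are beam-major ({beam, ..., batch, ...}), so the
  // batch-major hypothesis index is transposed into a state row index:
  int hypIdxTrans = (hypIdx / beamSize) + (hypIdx % beamSize) * numBeams;

  int beamHypIdx = hypIdx % beamSize;  // slot inside the sentence's beam

  std::printf("word=%d hyp=%d -> state row %d, beam %d slot %d\n",
              embIdx, hypIdx, hypIdxTrans, beamIdx, beamHypIdx);
}
```

The same transposition explains the `// make beams continuous` step in `beam_search.h`: transposing the costs once per step lets a single `NthElement` call serve all sentences, while `hypIdxTrans` undoes the transposition when gathering the surviving decoder states.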