Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-11-22 06:32:54 +0300
committerMarcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-11-22 06:32:54 +0300
commitc85d0608483789d446361ea28d95f7d7c9545f2d (patch)
tree126a83e9fcbe9618cb21cbdd516d1dad1e24f2d1
parent1404201926b5b4e27993776d52dfac809e8556f4 (diff)
Merged PR 20729: Add top-k sampling
This adds Top-K sampling to Marian and extends the --output-sampling option to take arguments
m---------regression-tests0
-rw-r--r--src/common/config_parser.cpp7
-rw-r--r--src/graph/expression_operators.cpp7
-rw-r--r--src/graph/expression_operators.h15
-rw-r--r--src/graph/node_operators_binary.h61
-rw-r--r--src/graph/node_operators_tuple.h2
-rw-r--r--src/models/costs.cpp35
-rw-r--r--src/models/costs.h32
-rw-r--r--src/models/model_factory.cpp21
-rwxr-xr-xsrc/tensors/cpu/tensor_operators.cpp9
-rw-r--r--src/tensors/gpu/tensor_operators.cu63
-rw-r--r--src/tensors/tensor_operators.h23
-rw-r--r--src/translator/translator.h2
13 files changed, 225 insertions, 52 deletions
diff --git a/regression-tests b/regression-tests
-Subproject 7d612ca5e4b27a76f92584dad76d240e34f216d
+Subproject 0aa7b6b7632732d1f22f3d8169d3262a7e6b1e9
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 51764cdc..59b328e9 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -695,9 +695,10 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
"Use softmax shortlist: path first best prune");
cli.add<std::vector<float>>("--weights",
"Scorer weights");
- cli.add<bool>("--output-sampling",
- "Noise output layer with gumbel noise",
- false);
+ cli.add<std::vector<std::string>>("--output-sampling",
+ "Noise output layer with gumbel noise. Implicit default is 'full' for sampling from full distribution. "
+ " Also accepts 'topk num' (e.g. topk 100) for top-100 sampling.")
+ ->implicit_val("full");
cli.add<std::vector<int>>("--output-approx-knn",
"Use approximate knn search in output layer (currently only in transformer)")
->implicit_val("100 1024");
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 560ab4e7..b26c2ae0 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -357,6 +357,13 @@ Expr gather(Expr a, int axis, Expr indices) {
return Expression<GatherNodeOp>(a, axis, indices);
}
+// scatter() -- scatter arbitrary elements along an axis; batched or non-batched
+// This is the reverse operation to gather.
+Expr scatter(Expr a, int axis, Expr indices, Expr source) {
+ return Expression<ScatterNodeOp>(a, axis, indices, source);
+}
+
+
// index_select() -- gather arbitrary elements along an axis from an unbatched
// input 'a'. Indices are specified as a 1D vector.
// This is used e.g. for embedding lookup.
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index e34ddc8a..d032e8d3 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -707,10 +707,23 @@ Expr stopGradient(Expr a);
* @param indices The indices to be gathered
* @returns Gathered expression with the same shape as @p indices
* @note @p a and @p indices must have the same rank
- * @note The non-target axes of @p a and @p indicies must have the same size, or be broadcastable.
+ * @note The non-target axes of @p a and @p indices must have the same size, or be broadcastable.
*/
Expr gather(Expr a, int axis, Expr indices);
+/**
+ * Scatter elements from source along an axis into a. Unindexed elements from a remain unchanged.
+ * This is the reverse operation to gather.
+ * @param a The input expression
+ * @param axis The axis along which to index
+ * @param indices The indices to be scattered
+ * @param source Expression with values to scatter.
+ * @returns Scattered expression with the same shape as @p a now containing values from @p source in positions @p indices
+ * @note @p source and @p indices must have the same rank
+ * @note In this version @p source and @p indicies must have the same shape
+ */
+Expr scatter(Expr a, int axis, Expr indices, Expr source);
+
#if 0
// reverse operation to gather. a is expression into with values from b are inserted and positions indices along axis.
// with broadcasting
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index a180bb5c..b2a646b1 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -1033,12 +1033,14 @@ struct GatherNodeOp : public NaryNodeOp {
NodeOps forwardOps() override {
return {NodeOp(
+ // @TODO: rename to gather
Select(val_, child(0)->val(), child(1)->val(), axis_))};
}
NodeOps backwardOps() override {
return {NodeOp(
- Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
+ // @TODO: rename to scatter
+ Insert</*add=*/true>(child(0)->grad(), adj_, child(1)->val(), axis_))};
}
Shape newShape(Expr a, int axis, Expr indices) {
@@ -1046,7 +1048,6 @@ struct GatherNodeOp : public NaryNodeOp {
axis = shape.axis(axis);
auto rank = shape.size();
ABORT_IF(rank != indices->shape().size(), "Mismatching ranks for input ({}) and indices ({})", std::string(shape), std::string(indices->shape()));
- axis = a->shape().axis(axis);
shape.set(axis, indices->shape()[axis]);
for (size_t i = 0; i < rank; ++i) {
if (i != axis) {
@@ -1086,6 +1087,62 @@ private:
int axis_;
};
+struct ScatterNodeOp : public NaryNodeOp {
+ ScatterNodeOp(Expr a, int axis, Expr indices, Expr source)
+ : NaryNodeOp({a, indices, source}, newShape(a, axis, indices, source), a->value_type()),
+ axis_(a->shape().axis(axis)) {
+ matchOrAbort<IndexType>(indices->value_type());
+ }
+
+ NodeOps forwardOps() override {
+ return {NodeOp(
+ CopyCast(val_, child(0)->val()); // @TODO: use normal copy
+ Insert</*add=*/false>(val_, child(2)->val(), child(1)->val(), axis_)
+ )};
+ }
+
+ NodeOps backwardOps() override {
+ ABORT("backward for ScatterNodeOp not yet implemented");
+ }
+
+ Shape newShape(Expr a, int axis, Expr indices, Expr source) {
+ ABORT_IF(axis != -1, "only last dimensions");
+ ABORT_IF(indices->shape() != source->shape(), "Shapes must match");
+
+ Shape shape = a->shape();
+ // @TODO: do proper checking
+ return shape;
+ }
+
+ const std::string type() override { return "scatter"; }
+
+ const std::string color() override { return "orange"; }
+
+ virtual size_t hash() override {
+ if(!hash_) {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, axis_);
+ hash_ = seed;
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ScatterNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(axis_ != cnode->axis_)
+ return false;
+ return true;
+ }
+
+private:
+ friend class SerializationHelpers;
+ int axis_;
+};
+
struct ColsNodeOp : public NaryNodeOp {
ColsNodeOp(Expr a, Expr indices)
: NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) {
diff --git a/src/graph/node_operators_tuple.h b/src/graph/node_operators_tuple.h
index c7a9531a..8acb1bc8 100644
--- a/src/graph/node_operators_tuple.h
+++ b/src/graph/node_operators_tuple.h
@@ -133,7 +133,7 @@ public:
}
void backward() override {
- Insert(/*out*/child(0)->grad(), adj_, val_, axis_);
+ Insert</*add=*/true>(/*out*/child(0)->grad(), adj_, val_, axis_);
}
const std::string type() override { return "topk"; }
diff --git a/src/models/costs.cpp b/src/models/costs.cpp
index c688b211..4b15bcb3 100644
--- a/src/models/costs.cpp
+++ b/src/models/costs.cpp
@@ -10,5 +10,40 @@ Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
return state;
}
+Ptr<DecoderState> GumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
+ state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
+ [](Expr logits) { // lemma gets gumbelled
+ return logsoftmax(logits + constant_like(logits, inits::gumbel()));
+ },
+ logsoftmax)); // factors don't
+ return state;
+}
+
+TopkGumbelSoftmaxStep::TopkGumbelSoftmaxStep(int k) : k_{k} {}
+
+Ptr<DecoderState> TopkGumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
+ state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
+ [=](Expr logits) { // lemma gets gumbelled
+ // create logits-sized tensor consisting only of invalid path scores
+ float invalidPathScore = NumericLimits<float>(logits->value_type()).lowest;
+ Expr invalidLogits = constant_like(logits, inits::fromValue(invalidPathScore));
+
+ // select top-k values
+ Expr val, idx;
+ std::tie(val, idx) = topk(logits, k_, /*axis=*/-1, /*descending=*/true);
+
+ // uncomment below to display probability mass in top-k selection
+ // debug(sum(gather(softmax(logits), -1, idx), -1), "sum");
+
+ // Add Gumbel noise to top-k values only and compute logsoftmax, used for argmax sampling later in beam-search
+ Expr gumbelVal = logsoftmax(val + constant_like(val, inits::gumbel()));
+
+ // Scatter gumbelled values back into logits to fill with usable values
+ return scatter(invalidLogits, -1, idx, gumbelVal);
+ },
+ logsoftmax)); // factors don't
+ return state;
+}
+
} // namespace models
} // namespace marian
diff --git a/src/models/costs.h b/src/models/costs.h
index e5463bfd..a087ed6a 100644
--- a/src/models/costs.h
+++ b/src/models/costs.h
@@ -297,20 +297,30 @@ public:
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
-// Gumbel-max noising for sampling during beam-search
-// Seems to work well enough with beam-size=1. Turn on
-// with --output-sampling during translation with marian-decoder
+// Gumbel-max noising for sampling during translation.
+// Produces accurate sampling with beam=1. Turn on
+// with --output-sampling [full] during translation
+// with marian-decoder for samnpling from the full
+// softmax distribution.
class GumbelSoftmaxStep : public ILogProbStep {
public:
virtual ~GumbelSoftmaxStep() {}
- virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
- state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
- [](Expr logits) { // lemma gets gumbelled
- return logsoftmax(logits + constant_like(logits, inits::gumbel()));
- },
- logsoftmax)); // factors don't
- return state;
- }
+ virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
+};
+
+
+// Gumbel-max noising for top-k sampling during translation.
+// Produces accurate sampling with beam=1. Turn on
+// with --output-sampling topk [10] during translation
+// with marian-decoder for top-10 sampling.
+class TopkGumbelSoftmaxStep : public ILogProbStep {
+private:
+ int k_{1};
+
+public:
+ TopkGumbelSoftmaxStep(int k);
+ virtual ~TopkGumbelSoftmaxStep() {}
+ virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence,
diff --git a/src/models/model_factory.cpp b/src/models/model_factory.cpp
index e176e6a4..52a87e72 100644
--- a/src/models/model_factory.cpp
+++ b/src/models/model_factory.cpp
@@ -370,10 +370,25 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
// add (log)softmax if requested
if (use == usage::translation) {
if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
- if(options->get<bool>("output-sampling", false))
- return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
- else
+ if(options->hasAndNotEmpty("output-sampling")) {
+ auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
+ std::string method = sampling.size() > 0 ? sampling[0] : "full";
+
+ if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
+ LOG(info, "Output sampling from the full softmax distribution");
+ return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
+ } else if(method == "topk") {
+ int k = sampling.size() > 1 ? std::stoi(sampling[1]) : 10;
+ if(k == 1)
+ LOG(info, "Output sampling with k=1 is equivalent to beam search with beam size 1");
+ LOG(info, "Output sampling via top-{} sampling", k);
+ return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<TopkGumbelSoftmaxStep>(k));
+ } else {
+ ABORT("Unknown sampling method: {}", method);
+ }
+ } else {
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
+ }
}
#ifdef COMPILE_EXAMPLES
// note: 'usage::translation' here means 'inference'
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index f3964f91..1e1adc38 100755
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -739,6 +739,7 @@ void Select(Tensor out,
}
}
+template <bool add>
void Insert(Tensor out,
const Tensor in,
const Tensor indices,
@@ -760,10 +761,16 @@ void Insert(Tensor out,
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex];
int outIndex = outShape.index(dims);
- out->data()[outIndex] += in->data()[index];
+ if(add)
+ out->data()[outIndex] += in->data()[index];
+ else
+ out->data()[outIndex] = in->data()[index];
}
}
+template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis);
+template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis);
+
void GRUFastForward(Tensor out_, std::vector<Tensor> inputs, bool final) {
int rows = out_->shape().elements() / out_->shape().back();
int cols = out_->shape().back();
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu
index 1347c3bb..2103ca9d 100644
--- a/src/tensors/gpu/tensor_operators.cu
+++ b/src/tensors/gpu/tensor_operators.cu
@@ -1309,7 +1309,7 @@ __global__ void gSelect(T* out,
}
}
-template <typename T>
+template <bool add, typename T>
__global__ void gInsert(T* out,
functional::Shape outShape,
const T* in,
@@ -1327,7 +1327,10 @@ __global__ void gInsert(T* out,
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
dims[axis] = (int)d_indices[idxIndex];
int outIndex = outShape.index(dims);
- out[outIndex] += in[index]; // this is probably wrong, atomicAdd?
+ if(add)
+ out[outIndex] += in[index]; // this is probably wrong, atomicAdd?
+ else
+ out[outIndex] = in[index];
}
}
}
@@ -1349,21 +1352,21 @@ void Select(Tensor out,
if(out->type() == Type::float32) {
gSelect<<<blocks, threads>>>(out->data<float>(),
- out->shape(),
- in->data<float>(),
- in->shape(),
- axisGPU,
- indices->data<IndexType>(),
- indices->shape());
+ out->shape(),
+ in->data<float>(),
+ in->shape(),
+ axisGPU,
+ indices->data<IndexType>(),
+ indices->shape());
#if COMPILE_FP16
} else if (out->type() == Type::float16) {
gSelect<<<blocks, threads>>>(out->data<half>(),
- out->shape(),
- in->data<half>(),
- in->shape(),
- axisGPU,
- indices->data<IndexType>(),
- indices->shape());
+ out->shape(),
+ in->data<half>(),
+ in->shape(),
+ axisGPU,
+ indices->data<IndexType>(),
+ indices->shape());
#endif
} else if(out->type() == Type::uint32) {
gSelect<<<blocks, threads>>>(out->data<IndexType>(),
@@ -1378,6 +1381,7 @@ void Select(Tensor out,
}
}
+template <bool add>
void Insert(Tensor out,
const Tensor in,
const Tensor indices,
@@ -1393,28 +1397,31 @@ void Insert(Tensor out,
int axisGPU = axis + functional::Shape::size() - out->shape().size();
if(out->type() == Type::float32) {
- gInsert<<<blocks, threads>>>(out->data<float>(),
- out->shape(),
- in->data<float>(),
- in->shape(),
- axisGPU,
- indices->data<IndexType>(),
- indices->shape());
+ gInsert<add><<<blocks, threads>>>(out->data<float>(),
+ out->shape(),
+ in->data<float>(),
+ in->shape(),
+ axisGPU,
+ indices->data<IndexType>(),
+ indices->shape());
#if COMPILE_FP16
} else if (out->type() == Type::float16) {
- gInsert<<<blocks, threads>>>(out->data<half>(),
- out->shape(),
- in->data<half>(),
- in->shape(),
- axisGPU,
- indices->data<IndexType>(),
- indices->shape());
+ gInsert<add><<<blocks, threads>>>(out->data<half>(),
+ out->shape(),
+ in->data<half>(),
+ in->shape(),
+ axisGPU,
+ indices->data<IndexType>(),
+ indices->shape());
#endif
} else {
ABORT("Insert not implemented for type {}", out->type());
}
}
+template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis);
+template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis);
+
template <typename T>
__global__ void gGRUFastForward(T* out,
const T* state,
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index dc29bf35..1fc4542d 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -297,7 +297,28 @@ DISPATCH3(CopyCols, marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH3(PasteCols, marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH4(Select, marian::Tensor, const marian::Tensor, const marian::Tensor, int)
-DISPATCH4(Insert, marian::Tensor, const marian::Tensor, const marian::Tensor, int)
+
+#ifdef CUDA_FOUND
+namespace gpu {
+ template <bool add>
+ void Insert(Tensor out, const Tensor in, const Tensor indices, int axis);
+}
+#endif
+
+namespace cpu {
+ template <bool add>
+ void Insert(Tensor out, const Tensor in, const Tensor indices, int axis);
+}
+
+template <bool add>
+static inline void Insert(Tensor out, const Tensor in, const Tensor indices, int axis) {
+#ifdef CUDA_FOUND
+ if(out->getBackend()->getDeviceId().type == DeviceType::gpu)
+ gpu::Insert<add>(out, in, indices, axis);
+ else
+#endif
+ cpu::Insert<add>(out, in, indices, axis);
+}
DISPATCH7(TopK, marian::Tensor, marian::Tensor, Ptr<Allocator>, const marian::Tensor, int, int, bool);
diff --git a/src/translator/translator.h b/src/translator/translator.h
index db1f3d03..3e375f65 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -119,7 +119,7 @@ public:
threadPool.enqueue(task, device, id++);
}
- if(options_->get<bool>("output-sampling", false)) {
+ if(options_->hasAndNotEmpty("output-sampling")) {
if(options_->get<size_t>("beam-size") > 1)
LOG(warn,
"[warning] Output sampling and beam search (beam-size > 1) are contradictory methods "