diff options
author | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-11-22 06:32:54 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-11-22 06:32:54 +0300 |
commit | c85d0608483789d446361ea28d95f7d7c9545f2d (patch) | |
tree | 126a83e9fcbe9618cb21cbdd516d1dad1e24f2d1 | |
parent | 1404201926b5b4e27993776d52dfac809e8556f4 (diff) |
Merged PR 20729: Add top-k sampling
This adds Top-K sampling to Marian and extends the --output-sampling option to take arguments
m--------- | regression-tests | 0 | ||||
-rw-r--r-- | src/common/config_parser.cpp | 7 | ||||
-rw-r--r-- | src/graph/expression_operators.cpp | 7 | ||||
-rw-r--r-- | src/graph/expression_operators.h | 15 | ||||
-rw-r--r-- | src/graph/node_operators_binary.h | 61 | ||||
-rw-r--r-- | src/graph/node_operators_tuple.h | 2 | ||||
-rw-r--r-- | src/models/costs.cpp | 35 | ||||
-rw-r--r-- | src/models/costs.h | 32 | ||||
-rw-r--r-- | src/models/model_factory.cpp | 21 | ||||
-rwxr-xr-x | src/tensors/cpu/tensor_operators.cpp | 9 | ||||
-rw-r--r-- | src/tensors/gpu/tensor_operators.cu | 63 | ||||
-rw-r--r-- | src/tensors/tensor_operators.h | 23 | ||||
-rw-r--r-- | src/translator/translator.h | 2 |
13 files changed, 225 insertions, 52 deletions
diff --git a/regression-tests b/regression-tests -Subproject 7d612ca5e4b27a76f92584dad76d240e34f216d +Subproject 0aa7b6b7632732d1f22f3d8169d3262a7e6b1e9 diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 51764cdc..59b328e9 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -695,9 +695,10 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) { "Use softmax shortlist: path first best prune"); cli.add<std::vector<float>>("--weights", "Scorer weights"); - cli.add<bool>("--output-sampling", - "Noise output layer with gumbel noise", - false); + cli.add<std::vector<std::string>>("--output-sampling", + "Noise output layer with gumbel noise. Implicit default is 'full' for sampling from full distribution. " + " Also accepts 'topk num' (e.g. topk 100) for top-100 sampling.") + ->implicit_val("full"); cli.add<std::vector<int>>("--output-approx-knn", "Use approximate knn search in output layer (currently only in transformer)") ->implicit_val("100 1024"); diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 560ab4e7..b26c2ae0 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -357,6 +357,13 @@ Expr gather(Expr a, int axis, Expr indices) { return Expression<GatherNodeOp>(a, axis, indices); } +// scatter() -- scatter arbitrary elements along an axis; batched or non-batched +// This is the reverse operation to gather. +Expr scatter(Expr a, int axis, Expr indices, Expr source) { + return Expression<ScatterNodeOp>(a, axis, indices, source); +} + + // index_select() -- gather arbitrary elements along an axis from an unbatched // input 'a'. Indices are specified as a 1D vector. // This is used e.g. for embedding lookup. diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index e34ddc8a..d032e8d3 100644 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -707,10 +707,23 @@ Expr stopGradient(Expr a); * @param indices The indices to be gathered * @returns Gathered expression with the same shape as @p indices * @note @p a and @p indices must have the same rank - * @note The non-target axes of @p a and @p indicies must have the same size, or be broadcastable. + * @note The non-target axes of @p a and @p indices must have the same size, or be broadcastable. */ Expr gather(Expr a, int axis, Expr indices); +/** + * Scatter elements from source along an axis into a. Unindexed elements from a remain unchanged. + * This is the reverse operation to gather. + * @param a The input expression + * @param axis The axis along which to index + * @param indices The indices to be scattered + * @param source Expression with values to scatter. + * @returns Scattered expression with the same shape as @p a now containing values from @p source in positions @p indices + * @note @p source and @p indices must have the same rank + * @note In this version @p source and @p indicies must have the same shape + */ +Expr scatter(Expr a, int axis, Expr indices, Expr source); + #if 0 // reverse operation to gather. a is expression into with values from b are inserted and positions indices along axis. // with broadcasting diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index a180bb5c..b2a646b1 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -1033,12 +1033,14 @@ struct GatherNodeOp : public NaryNodeOp { NodeOps forwardOps() override { return {NodeOp( + // @TODO: rename to gather Select(val_, child(0)->val(), child(1)->val(), axis_))}; } NodeOps backwardOps() override { return {NodeOp( - Insert(child(0)->grad(), adj_, child(1)->val(), axis_))}; + // @TODO: rename to scatter + Insert</*add=*/true>(child(0)->grad(), adj_, child(1)->val(), axis_))}; } Shape newShape(Expr a, int axis, Expr indices) { @@ -1046,7 +1048,6 @@ struct GatherNodeOp : public NaryNodeOp { axis = shape.axis(axis); auto rank = shape.size(); ABORT_IF(rank != indices->shape().size(), "Mismatching ranks for input ({}) and indices ({})", std::string(shape), std::string(indices->shape())); - axis = a->shape().axis(axis); shape.set(axis, indices->shape()[axis]); for (size_t i = 0; i < rank; ++i) { if (i != axis) { @@ -1086,6 +1087,62 @@ private: int axis_; }; +struct ScatterNodeOp : public NaryNodeOp { + ScatterNodeOp(Expr a, int axis, Expr indices, Expr source) + : NaryNodeOp({a, indices, source}, newShape(a, axis, indices, source), a->value_type()), + axis_(a->shape().axis(axis)) { + matchOrAbort<IndexType>(indices->value_type()); + } + + NodeOps forwardOps() override { + return {NodeOp( + CopyCast(val_, child(0)->val()); // @TODO: use normal copy + Insert</*add=*/false>(val_, child(2)->val(), child(1)->val(), axis_) + )}; + } + + NodeOps backwardOps() override { + ABORT("backward for ScatterNodeOp not yet implemented"); + } + + Shape newShape(Expr a, int axis, Expr indices, Expr source) { + ABORT_IF(axis != -1, "only last dimensions"); + ABORT_IF(indices->shape() != source->shape(), "Shapes must match"); + + Shape shape = a->shape(); + // @TODO: do proper checking + return shape; + } + + const std::string type() override { return "scatter"; } + + const std::string color() override { return "orange"; } + + virtual size_t hash() override { + if(!hash_) { + size_t seed = NaryNodeOp::hash(); + util::hash_combine(seed, axis_); + hash_ = seed; + } + return hash_; + } + + virtual bool equal(Expr node) override { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<ScatterNodeOp>(node); + if(!cnode) + return false; + if(axis_ != cnode->axis_) + return false; + return true; + } + +private: + friend class SerializationHelpers; + int axis_; +}; + struct ColsNodeOp : public NaryNodeOp { ColsNodeOp(Expr a, Expr indices) : NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) { diff --git a/src/graph/node_operators_tuple.h b/src/graph/node_operators_tuple.h index c7a9531a..8acb1bc8 100644 --- a/src/graph/node_operators_tuple.h +++ b/src/graph/node_operators_tuple.h @@ -133,7 +133,7 @@ public: } void backward() override { - Insert(/*out*/child(0)->grad(), adj_, val_, axis_); + Insert</*add=*/true>(/*out*/child(0)->grad(), adj_, val_, axis_); } const std::string type() override { return "topk"; } diff --git a/src/models/costs.cpp b/src/models/costs.cpp index c688b211..4b15bcb3 100644 --- a/src/models/costs.cpp +++ b/src/models/costs.cpp @@ -10,5 +10,40 @@ Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) { return state; } +Ptr<DecoderState> GumbelSoftmaxStep::apply(Ptr<DecoderState> state) { + state->setLogProbs(state->getLogProbs().applyUnaryFunctions( + [](Expr logits) { // lemma gets gumbelled + return logsoftmax(logits + constant_like(logits, inits::gumbel())); + }, + logsoftmax)); // factors don't + return state; +} + +TopkGumbelSoftmaxStep::TopkGumbelSoftmaxStep(int k) : k_{k} {} + +Ptr<DecoderState> TopkGumbelSoftmaxStep::apply(Ptr<DecoderState> state) { + state->setLogProbs(state->getLogProbs().applyUnaryFunctions( + [=](Expr logits) { // lemma gets gumbelled + // create logits-sized tensor consisting only of invalid path scores + float invalidPathScore = NumericLimits<float>(logits->value_type()).lowest; + Expr invalidLogits = constant_like(logits, inits::fromValue(invalidPathScore)); + + // select top-k values + Expr val, idx; + std::tie(val, idx) = topk(logits, k_, /*axis=*/-1, /*descending=*/true); + + // uncomment below to display probability mass in top-k selection + // debug(sum(gather(softmax(logits), -1, idx), -1), "sum"); + + // Add Gumbel noise to top-k values only and compute logsoftmax, used for argmax sampling later in beam-search + Expr gumbelVal = logsoftmax(val + constant_like(val, inits::gumbel())); + + // Scatter gumbelled values back into logits to fill with usable values + return scatter(invalidLogits, -1, idx, gumbelVal); + }, + logsoftmax)); // factors don't + return state; +} + } // namespace models } // namespace marian diff --git a/src/models/costs.h b/src/models/costs.h index e5463bfd..a087ed6a 100644 --- a/src/models/costs.h +++ b/src/models/costs.h @@ -297,20 +297,30 @@ public: virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override; }; -// Gumbel-max noising for sampling during beam-search -// Seems to work well enough with beam-size=1. Turn on -// with --output-sampling during translation with marian-decoder +// Gumbel-max noising for sampling during translation. +// Produces accurate sampling with beam=1. Turn on +// with --output-sampling [full] during translation +// with marian-decoder for samnpling from the full +// softmax distribution. class GumbelSoftmaxStep : public ILogProbStep { public: virtual ~GumbelSoftmaxStep() {} - virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override { - state->setLogProbs(state->getLogProbs().applyUnaryFunctions( - [](Expr logits) { // lemma gets gumbelled - return logsoftmax(logits + constant_like(logits, inits::gumbel())); - }, - logsoftmax)); // factors don't - return state; - } + virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override; +}; + + +// Gumbel-max noising for top-k sampling during translation. +// Produces accurate sampling with beam=1. Turn on +// with --output-sampling topk [10] during translation +// with marian-decoder for top-10 sampling. +class TopkGumbelSoftmaxStep : public ILogProbStep { +private: + int k_{1}; + +public: + TopkGumbelSoftmaxStep(int k); + virtual ~TopkGumbelSoftmaxStep() {} + virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override; }; // class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence, diff --git a/src/models/model_factory.cpp b/src/models/model_factory.cpp index e176e6a4..52a87e72 100644 --- a/src/models/model_factory.cpp +++ b/src/models/model_factory.cpp @@ -370,10 +370,25 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) { // add (log)softmax if requested if (use == usage::translation) { if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) { - if(options->get<bool>("output-sampling", false)) - return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>()); - else + if(options->hasAndNotEmpty("output-sampling")) { + auto sampling = options->get<std::vector<std::string>>("output-sampling", {}); + std::string method = sampling.size() > 0 ? sampling[0] : "full"; + + if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) { + LOG(info, "Output sampling from the full softmax distribution"); + return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>()); + } else if(method == "topk") { + int k = sampling.size() > 1 ? std::stoi(sampling[1]) : 10; + if(k == 1) + LOG(info, "Output sampling with k=1 is equivalent to beam search with beam size 1"); + LOG(info, "Output sampling via top-{} sampling", k); + return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<TopkGumbelSoftmaxStep>(k)); + } else { + ABORT("Unknown sampling method: {}", method); + } + } else { return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>()); + } } #ifdef COMPILE_EXAMPLES // note: 'usage::translation' here means 'inference' diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp index f3964f91..1e1adc38 100755 --- a/src/tensors/cpu/tensor_operators.cpp +++ b/src/tensors/cpu/tensor_operators.cpp @@ -739,6 +739,7 @@ void Select(Tensor out, } } +template <bool add> void Insert(Tensor out, const Tensor in, const Tensor indices, @@ -760,10 +761,16 @@ void Insert(Tensor out, int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex]; int outIndex = outShape.index(dims); - out->data()[outIndex] += in->data()[index]; + if(add) + out->data()[outIndex] += in->data()[index]; + else + out->data()[outIndex] = in->data()[index]; } } +template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis); +template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis); + void GRUFastForward(Tensor out_, std::vector<Tensor> inputs, bool final) { int rows = out_->shape().elements() / out_->shape().back(); int cols = out_->shape().back(); diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu index 1347c3bb..2103ca9d 100644 --- a/src/tensors/gpu/tensor_operators.cu +++ b/src/tensors/gpu/tensor_operators.cu @@ -1309,7 +1309,7 @@ __global__ void gSelect(T* out, } } -template <typename T> +template <bool add, typename T> __global__ void gInsert(T* out, functional::Shape outShape, const T* in, @@ -1327,7 +1327,10 @@ __global__ void gInsert(T* out, int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor dims[axis] = (int)d_indices[idxIndex]; int outIndex = outShape.index(dims); - out[outIndex] += in[index]; // this is probably wrong, atomicAdd? + if(add) + out[outIndex] += in[index]; // this is probably wrong, atomicAdd? + else + out[outIndex] = in[index]; } } } @@ -1349,21 +1352,21 @@ void Select(Tensor out, if(out->type() == Type::float32) { gSelect<<<blocks, threads>>>(out->data<float>(), - out->shape(), - in->data<float>(), - in->shape(), - axisGPU, - indices->data<IndexType>(), - indices->shape()); + out->shape(), + in->data<float>(), + in->shape(), + axisGPU, + indices->data<IndexType>(), + indices->shape()); #if COMPILE_FP16 } else if (out->type() == Type::float16) { gSelect<<<blocks, threads>>>(out->data<half>(), - out->shape(), - in->data<half>(), - in->shape(), - axisGPU, - indices->data<IndexType>(), - indices->shape()); + out->shape(), + in->data<half>(), + in->shape(), + axisGPU, + indices->data<IndexType>(), + indices->shape()); #endif } else if(out->type() == Type::uint32) { gSelect<<<blocks, threads>>>(out->data<IndexType>(), @@ -1378,6 +1381,7 @@ void Select(Tensor out, } } +template <bool add> void Insert(Tensor out, const Tensor in, const Tensor indices, @@ -1393,28 +1397,31 @@ void Insert(Tensor out, int axisGPU = axis + functional::Shape::size() - out->shape().size(); if(out->type() == Type::float32) { - gInsert<<<blocks, threads>>>(out->data<float>(), - out->shape(), - in->data<float>(), - in->shape(), - axisGPU, - indices->data<IndexType>(), - indices->shape()); + gInsert<add><<<blocks, threads>>>(out->data<float>(), + out->shape(), + in->data<float>(), + in->shape(), + axisGPU, + indices->data<IndexType>(), + indices->shape()); #if COMPILE_FP16 } else if (out->type() == Type::float16) { - gInsert<<<blocks, threads>>>(out->data<half>(), - out->shape(), - in->data<half>(), - in->shape(), - axisGPU, - indices->data<IndexType>(), - indices->shape()); + gInsert<add><<<blocks, threads>>>(out->data<half>(), + out->shape(), + in->data<half>(), + in->shape(), + axisGPU, + indices->data<IndexType>(), + indices->shape()); #endif } else { ABORT("Insert not implemented for type {}", out->type()); } } +template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis); +template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis); + template <typename T> __global__ void gGRUFastForward(T* out, const T* state, diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h index dc29bf35..1fc4542d 100644 --- a/src/tensors/tensor_operators.h +++ b/src/tensors/tensor_operators.h @@ -297,7 +297,28 @@ DISPATCH3(CopyCols, marian::Tensor, const marian::Tensor, const marian::Tensor) DISPATCH3(PasteCols, marian::Tensor, const marian::Tensor, const marian::Tensor) DISPATCH4(Select, marian::Tensor, const marian::Tensor, const marian::Tensor, int) -DISPATCH4(Insert, marian::Tensor, const marian::Tensor, const marian::Tensor, int) + +#ifdef CUDA_FOUND +namespace gpu { + template <bool add> + void Insert(Tensor out, const Tensor in, const Tensor indices, int axis); +} +#endif + +namespace cpu { + template <bool add> + void Insert(Tensor out, const Tensor in, const Tensor indices, int axis); +} + +template <bool add> +static inline void Insert(Tensor out, const Tensor in, const Tensor indices, int axis) { +#ifdef CUDA_FOUND + if(out->getBackend()->getDeviceId().type == DeviceType::gpu) + gpu::Insert<add>(out, in, indices, axis); + else +#endif + cpu::Insert<add>(out, in, indices, axis); +} DISPATCH7(TopK, marian::Tensor, marian::Tensor, Ptr<Allocator>, const marian::Tensor, int, int, bool); diff --git a/src/translator/translator.h b/src/translator/translator.h index db1f3d03..3e375f65 100644 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -119,7 +119,7 @@ public: threadPool.enqueue(task, device, id++); } - if(options_->get<bool>("output-sampling", false)) { + if(options_->hasAndNotEmpty("output-sampling")) { if(options_->get<size_t>("beam-size") > 1) LOG(warn, "[warning] Output sampling and beam search (beam-size > 1) are contradictory methods " |