diff options
author | Frank Seide <fseide@microsoft.com> | 2019-02-04 02:51:41 +0300 |
---|---|---|
committer | Frank Seide <fseide@microsoft.com> | 2019-02-04 02:51:41 +0300 |
commit | 0c14096d7f0a44185db2ba6b327b041f33be23c1 (patch) | |
tree | 2d10a67610b2bfbd86798800fde447567b713e11 /src/graph | |
parent | 5d01d870537e8853f28604e36367ccb6b1904234 (diff) |
(comments)
Diffstat (limited to 'src/graph')
-rwxr-xr-x | src/graph/expression_operators.cpp | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 6a07611d..b74b822e 100755 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -250,7 +250,12 @@ Expr gather(Expr a, int axis, Expr indices) { return Expression<GatherNodeOp>(a, axis, indices); } -// index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector) +// index_select() -- gather arbitrary elements along an axis from an unbatched +// input 'a'. Indices are specified as a 1D vector. +// This is used e.g. for embedding lookup. +// Note: To use a batch of index vectors, reshape them into a single vector, +// call index_select(), then reshape the result back. Reshapes are cheap. +// This function has the same semantics as PyTorch operation of the same name. Expr index_select(Expr a, int axis, Expr indices) { ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor"); // We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor. |