Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorFrank Seide <fseide@microsoft.com>2019-02-04 02:51:41 +0300
committerFrank Seide <fseide@microsoft.com>2019-02-04 02:51:41 +0300
commit0c14096d7f0a44185db2ba6b327b041f33be23c1 (patch)
tree2d10a67610b2bfbd86798800fde447567b713e11 /src/graph
parent5d01d870537e8853f28604e36367ccb6b1904234 (diff)
(comments)
Diffstat (limited to 'src/graph')
-rwxr-xr-xsrc/graph/expression_operators.cpp7
1 files changed, 6 insertions, 1 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 6a07611d..b74b822e 100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -250,7 +250,12 @@ Expr gather(Expr a, int axis, Expr indices) {
return Expression<GatherNodeOp>(a, axis, indices);
}
-// index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector)
+// index_select() -- gather arbitrary elements along an axis from an unbatched
+// input 'a'. Indices are specified as a 1D vector.
+// This is used e.g. for embedding lookup.
+// Note: To use a batch of index vectors, reshape them into a single vector,
+// call index_select(), then reshape the result back. Reshapes are cheap.
+// This function has the same semantics as PyTorch operation of the same name.
Expr index_select(Expr a, int axis, Expr indices) {
ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
// We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.