Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src/graph
diff options
context:
space:
mode:
author: Frank Seide <fseide@microsoft.com> 2019-02-07 07:25:43 +0300
committer: Frank Seide <fseide@microsoft.com> 2019-02-07 07:25:43 +0300
commit: f88eb0d3687575cfe55f908ddad4617c9dc68ee2 (patch)
tree: 7fef836345671a8b3e806fbf968554a67997454d /src/graph
parent: 7c7f94c416dea09df1285bdd0614bea34de456e4 (diff)
commenting and minor refactoring of beam search
Diffstat (limited to 'src/graph')
-rwxr-xr-x [-rw-r--r--] src/graph/expression_operators.cpp | 39
1 files changed, 31 insertions, 8 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index e558ffd0..8a79cbe0 100644..100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -135,31 +135,49 @@ Expr le(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant(
/*********************************************************/
Expr operator+(Expr a, float b) {
- return Expression<ScalarAddNodeOp>(a, b);
+ if (b == 0)
+ return a;
+ else
+ return Expression<ScalarAddNodeOp>(a, b);
}
Expr operator+(float a, Expr b) {
- return Expression<ScalarAddNodeOp>(b, a);
+ if (a == 0)
+ return b;
+ else
+ return Expression<ScalarAddNodeOp>(b, a);
}
Expr operator-(Expr a, float b) {
- return Expression<ScalarAddNodeOp>(a, -b);
+ if (b == 0)
+ return a;
+ else
+ return Expression<ScalarAddNodeOp>(a, -b);
}
Expr operator-(float a, Expr b) {
- return Expression<ScalarAddNodeOp>(-b, a);
+ if (a == 0)
+ return -b;
+ else
+ return Expression<ScalarAddNodeOp>(-b, a);
}
Expr operator*(float a, Expr b) {
- return Expression<ScalarMultNodeOp>(b, a);
+ if (a == 1.0f)
+ return b;
+ else
+ return Expression<ScalarMultNodeOp>(b, a);
}
Expr operator*(Expr a, float b) {
- return Expression<ScalarMultNodeOp>(a, b);
+ if (b == 1.0f)
+ return a;
+ else
+ return Expression<ScalarMultNodeOp>(a, b);
}
Expr operator/(Expr a, float b) {
- return Expression<ScalarMultNodeOp>(a, 1.f / b);
+ return a * (1.f / b);
}
// TODO: efficient version of this without constant()
@@ -254,7 +272,12 @@ Expr gather(Expr a, int axis, Expr indices) {
return Expression<GatherNodeOp>(a, axis, indices);
}
-// index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector)
+// index_select() -- gather arbitrary elements along an axis from an unbatched
+// input 'a'. Indices are specified as a 1D vector.
+// This is used e.g. for embedding lookup.
+// Note: To use a batch of index vectors, reshape them into a single vector,
+// call index_select(), then reshape the result back. Reshapes are cheap.
+// This function has the same semantics as PyTorch operation of the same name.
Expr index_select(Expr a, int axis, Expr indices) {
ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
// We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.