diff options
author | Frank Seide <fseide@microsoft.com> | 2019-02-07 07:25:43 +0300 |
---|---|---|
committer | Frank Seide <fseide@microsoft.com> | 2019-02-07 07:25:43 +0300 |
commit | f88eb0d3687575cfe55f908ddad4617c9dc68ee2 (patch) | |
tree | 7fef836345671a8b3e806fbf968554a67997454d /src/graph | |
parent | 7c7f94c416dea09df1285bdd0614bea34de456e4 (diff) |
commenting and minor refactoring of beam search
Diffstat (limited to 'src/graph')
-rwxr-xr-x[-rw-r--r--] | src/graph/expression_operators.cpp | 39 |
1 files changed, 31 insertions, 8 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index e558ffd0..8a79cbe0 100644..100755 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -135,31 +135,49 @@ Expr le(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant( /*********************************************************/ Expr operator+(Expr a, float b) { - return Expression<ScalarAddNodeOp>(a, b); + if (b == 0) + return a; + else + return Expression<ScalarAddNodeOp>(a, b); } Expr operator+(float a, Expr b) { - return Expression<ScalarAddNodeOp>(b, a); + if (a == 0) + return b; + else + return Expression<ScalarAddNodeOp>(b, a); } Expr operator-(Expr a, float b) { - return Expression<ScalarAddNodeOp>(a, -b); + if (b == 0) + return a; + else + return Expression<ScalarAddNodeOp>(a, -b); } Expr operator-(float a, Expr b) { - return Expression<ScalarAddNodeOp>(-b, a); + if (a == 0) + return -b; + else + return Expression<ScalarAddNodeOp>(-b, a); } Expr operator*(float a, Expr b) { - return Expression<ScalarMultNodeOp>(b, a); + if (a == 1.0f) + return b; + else + return Expression<ScalarMultNodeOp>(b, a); } Expr operator*(Expr a, float b) { - return Expression<ScalarMultNodeOp>(a, b); + if (b == 1.0f) + return a; + else + return Expression<ScalarMultNodeOp>(a, b); } Expr operator/(Expr a, float b) { - return Expression<ScalarMultNodeOp>(a, 1.f / b); + return a * (1.f / b); } // TODO: efficient version of this without constant() @@ -254,7 +272,12 @@ Expr gather(Expr a, int axis, Expr indices) { return Expression<GatherNodeOp>(a, axis, indices); } -// index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector) +// index_select() -- gather arbitrary elements along an axis from an unbatched +// input 'a'. Indices are specified as a 1D vector. +// This is used e.g. for embedding lookup. +// Note: To use a batch of index vectors, reshape them into a single vector, +// call index_select(), then reshape the result back. Reshapes are cheap. +// This function has the same semantics as PyTorch operation of the same name. Expr index_select(Expr a, int axis, Expr indices) { ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor"); // We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor. |