diff options
author | Frank Seide <fseide@microsoft.com> | 2018-06-21 23:34:19 +0300 |
---|---|---|
committer | Frank Seide <fseide@microsoft.com> | 2018-06-21 23:34:19 +0300 |
commit | e87564c6643dc0b1d4344ed69e32bf72e009b12d (patch) | |
tree | 6fd27a7ac5abcca9a0ab661d6122c3a999640674 /src/graph/node_operators_unary.h | |
parent | 8adde0787ebc360fd82c455e8ed04e87d42d0b44 (diff) |
bug fix: History::Add() should obey the actual EOS symbol from the given vocabulary;
SubBatch now holds the vocabulary, allowing to debug-print the sequences with word ids;
new operators: logsum, max, min;
new overload operator/(float,Expr);
shift() now takes a padding value;
transformer.h is now a compiled .cpp file on Windows (in prep for renaming to transformer.cpp);
fixed two MSVC warnings;
bug fix: logger should not output two \r characters at the end of the line;
minor changes (const correctness, accessors)
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 259e6072..6749a585 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -550,7 +550,8 @@ struct LogNodeOp : public UnaryNodeOp { NodeOps backwardOps() { using namespace functional; return { - NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))}; + //NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))}; + NodeOp(Add(_1 / _2, child(0)->grad(), adj_, child(0)->val()))}; } const std::string type() { return "log"; } @@ -931,8 +932,9 @@ public: Shape outShape = a->shape(); axis_ = outShape.axis(axis); - for(int i = 0; i <= axis_; ++i) - outShape.set(i, 1); + for(int i = 0; i < axis_; ++i) + ABORT_IF(outShape[i] != 1, "non-consecutive slices are presently not supported by step()"); + outShape.set(axis_, 1); return outShape; } @@ -993,15 +995,15 @@ public: }; struct ShiftNodeOp : public UnaryNodeOp { - ShiftNodeOp(Expr a, Shape shift) - : UnaryNodeOp(a, a->shape()), shift_(shift) {} + ShiftNodeOp(Expr a, Shape shift, float padValue) + : UnaryNodeOp(a, a->shape()), shift_(shift), padValue_(padValue) {} NodeOps forwardOps() { - return {NodeOp(Shift(val_, child(0)->val(), shift_, false))}; + return {NodeOp(Shift(val_, child(0)->val(), shift_, padValue_, /*invert=*/false))}; } NodeOps backwardOps() { - return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true))}; + return {NodeOp(Shift(child(0)->grad(), adj_, shift_, /*padValue=*/0.f, /*invert=*/true))}; } const std::string type() { return "shift"; } @@ -1011,6 +1013,7 @@ struct ShiftNodeOp : public UnaryNodeOp { size_t seed = NaryNodeOp::hash(); for(auto i : shift_) boost::hash_combine(seed, i); + boost::hash_combine(seed, padValue_); hash_ = seed; } return hash_; @@ -1027,7 +1030,8 @@ struct ShiftNodeOp : public UnaryNodeOp { return true; } - Shape shift_; + Shape shift_; // shift offsets in each dimension + float padValue_; // what value to shift in }; // struct LexicalProbNodeOp : public NaryNodeOp { |