Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrank Seide <fseide@microsoft.com>2018-06-21 23:34:19 +0300
committerFrank Seide <fseide@microsoft.com>2018-06-21 23:34:19 +0300
commite87564c6643dc0b1d4344ed69e32bf72e009b12d (patch)
tree6fd27a7ac5abcca9a0ab661d6122c3a999640674 /src/graph/node_operators_unary.h
parent8adde0787ebc360fd82c455e8ed04e87d42d0b44 (diff)
bug fix: History::Add() should obey the actual EOS symbol from the given vocabulary;
SubBatch now holds the vocabulary, allowing to debug-print the sequences with word ids; new operators: logsum, max, min; new overload operator/(float,Expr); shift() now takes a padding value; transformer.h is now a compiled .cpp file on Windows (in prep for renaming to transformer.cpp); fixed two MSVC warnings; bug fix: logger should not output two \r characters at the end of the line; minor changes (const correctness, accessors)
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h20
1 files changed, 12 insertions, 8 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 259e6072..6749a585 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -550,7 +550,8 @@ struct LogNodeOp : public UnaryNodeOp {
NodeOps backwardOps() {
using namespace functional;
return {
- NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))};
+ //NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))};
+ NodeOp(Add(_1 / _2, child(0)->grad(), adj_, child(0)->val()))};
}
const std::string type() { return "log"; }
@@ -931,8 +932,9 @@ public:
Shape outShape = a->shape();
axis_ = outShape.axis(axis);
- for(int i = 0; i <= axis_; ++i)
- outShape.set(i, 1);
+ for(int i = 0; i < axis_; ++i)
+ ABORT_IF(outShape[i] != 1, "non-consecutive slices are presently not supported by step()");
+ outShape.set(axis_, 1);
return outShape;
}
@@ -993,15 +995,15 @@ public:
};
struct ShiftNodeOp : public UnaryNodeOp {
- ShiftNodeOp(Expr a, Shape shift)
- : UnaryNodeOp(a, a->shape()), shift_(shift) {}
+ ShiftNodeOp(Expr a, Shape shift, float padValue)
+ : UnaryNodeOp(a, a->shape()), shift_(shift), padValue_(padValue) {}
NodeOps forwardOps() {
- return {NodeOp(Shift(val_, child(0)->val(), shift_, false))};
+ return {NodeOp(Shift(val_, child(0)->val(), shift_, padValue_, /*invert=*/false))};
}
NodeOps backwardOps() {
- return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true))};
+ return {NodeOp(Shift(child(0)->grad(), adj_, shift_, /*padValue=*/0.f, /*invert=*/true))};
}
const std::string type() { return "shift"; }
@@ -1011,6 +1013,7 @@ struct ShiftNodeOp : public UnaryNodeOp {
size_t seed = NaryNodeOp::hash();
for(auto i : shift_)
boost::hash_combine(seed, i);
+ boost::hash_combine(seed, padValue_);
hash_ = seed;
}
return hash_;
@@ -1027,7 +1030,8 @@ struct ShiftNodeOp : public UnaryNodeOp {
return true;
}
- Shape shift_;
+ Shape shift_; // shift offsets in each dimension
+ float padValue_; // what value to shift in
};
// struct LexicalProbNodeOp : public NaryNodeOp {