Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-04-20 13:38:31 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-04-20 13:38:31 +0300
commit4e8bbcd9f7589bc3527aee554d7ad37bb626e46a (patch)
tree04e9a3b89e9ebc4c52ee603cf684f64980988084 /src/graph/node_operators_unary.h
parentaab18b26ef0ade95b1eb50f1874fc458f385c449 (diff)
revert to single parameter, fix derivative
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h16
1 files changed, 8 insertions, 8 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 8fab1a5e..20298e2a 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -804,8 +804,9 @@ struct ShiftNodeOp : public UnaryNodeOp {
struct LexicalProbNodeOp : public NaryNodeOp {
template <typename ...Args>
- LexicalProbNodeOp(Expr logits, Expr att, Expr exp, Ptr<sparse::CSR> lf, Args ...args)
- : NaryNodeOp({logits, att, exp}, keywords::shape=logits->shape(), args...),
+ LexicalProbNodeOp(Expr logits, Expr att, float eps, Ptr<sparse::CSR> lf, Args ...args)
+ : NaryNodeOp({logits, att}, keywords::shape=logits->shape(), args...),
+ eps_(eps),
lf_(lf) {
}
@@ -814,18 +815,17 @@ struct LexicalProbNodeOp : public NaryNodeOp {
children_[0]->val(),
children_[1]->val(),
lf_);
- // val = x + ln(p + eps + 1e-9)
- Element(_1 = (Log(_1 + _3 + 1e-9) + _2),
- val_, children_[0]->val(), children_[2]->val());
+ // val = x + ln(p + eps)
+ Element(_1 = (Log(_1 + eps_) + _2),
+ val_, children_[0]->val());
}
void backward() {
Add(_1, children_[0]->grad(), adj_);
- // adj' = adj / ( p + eps + 1e-9) = adj / exp(val - x)
+ // adj' = adj / (p + eps) = adj / exp(val - x)
Element(_1 = _1 / Exp(_2 - _3),
adj_, val_, children_[0]->val());
sparse::LfaBackward(children_[1]->grad(), adj_, lf_);
- Add(_1, children_[2]->grad(), adj_);
}
const std::string type() {
@@ -841,7 +841,7 @@ struct LexicalProbNodeOp : public NaryNodeOp {
return hash_;
}
-
+ float eps_;
Ptr<sparse::CSR> lf_;
};