diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-04-20 13:38:31 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-04-20 13:38:31 +0300 |
commit | 4e8bbcd9f7589bc3527aee554d7ad37bb626e46a (patch) | |
tree | 04e9a3b89e9ebc4c52ee603cf684f64980988084 /src/graph/node_operators_unary.h | |
parent | aab18b26ef0ade95b1eb50f1874fc458f385c449 (diff) |
revert to single parameter, fix derivative
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 8fab1a5e..20298e2a 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -804,8 +804,9 @@ struct ShiftNodeOp : public UnaryNodeOp { struct LexicalProbNodeOp : public NaryNodeOp { template <typename ...Args> - LexicalProbNodeOp(Expr logits, Expr att, Expr exp, Ptr<sparse::CSR> lf, Args ...args) - : NaryNodeOp({logits, att, exp}, keywords::shape=logits->shape(), args...), + LexicalProbNodeOp(Expr logits, Expr att, float eps, Ptr<sparse::CSR> lf, Args ...args) + : NaryNodeOp({logits, att}, keywords::shape=logits->shape(), args...), + eps_(eps), lf_(lf) { } @@ -814,18 +815,17 @@ struct LexicalProbNodeOp : public NaryNodeOp { children_[0]->val(), children_[1]->val(), lf_); - // val = x + ln(p + eps + 1e-9) - Element(_1 = (Log(_1 + _3 + 1e-9) + _2), - val_, children_[0]->val(), children_[2]->val()); + // val = x + ln(p + eps) + Element(_1 = (Log(_1 + eps_) + _2), + val_, children_[0]->val()); } void backward() { Add(_1, children_[0]->grad(), adj_); - // adj' = adj / ( p + eps + 1e-9) = adj / exp(val - x) + // adj' = adj / (p + eps) = adj / exp(val - x) Element(_1 = _1 / Exp(_2 - _3), adj_, val_, children_[0]->val()); sparse::LfaBackward(children_[1]->grad(), adj_, lf_); - Add(_1, children_[2]->grad(), adj_); } const std::string type() { @@ -841,7 +841,7 @@ struct LexicalProbNodeOp : public NaryNodeOp { return hash_; } - + float eps_; Ptr<sparse::CSR> lf_; }; |