blob: 520476aed0a1668cd1b208f930ad1eb013608641 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
#include <sstream>
#include "graph/expression_graph.h"
#include "tensors/tensor_operators.h"
namespace marian {

/**
 * Constructs an expression graph.
 *
 * @param inference stored in inferenceOnly_ (presumably marks the graph as
 *                  inference-only — confirm against the header).
 *
 * No backend is created here: backend_ stays null until setDevice() is called.
 */
ExpressionGraph::ExpressionGraph(bool inference)
    : inferenceOnly_(inference), backend_(nullptr) {}

/**
 * Lazily binds the graph to a device.
 *
 * On the first call this:
 *   - creates the backend for `deviceId`, seeded with Config::seed;
 *   - creates a Parameters object and initialises it on that backend;
 *   - creates a TensorAllocator for that backend.
 *
 * If a backend already exists, the call does nothing. A later call with a
 * different deviceId is therefore silently ignored.
 */
void ExpressionGraph::setDevice(DeviceId deviceId) {
  if(!backend_) {
    backend_ = BackendByDevice(deviceId, Config::seed);
    params_ = New<Parameters>();
    params_->init(backend_);
    tensors_ = New<TensorAllocator>(backend_);
  }
}

/**
 * Creates a constant node of the given shape.
 *
 * The node's initializer (passed as keywords::init) calls Dropout(t, prob) on
 * the node's tensor.
 *
 * @param prob  forwarded to Dropout(); presumably the drop probability —
 *              confirm in tensors/tensor_operators.h.
 * @param shape shape of the resulting constant node.
 * @return the new constant expression attached to this graph.
 */
Expr ExpressionGraph::dropout(float prob, Shape shape) {
  // Capture only `prob` by value. The initializer never uses `this`, and a
  // stored callback holding a raw graph pointer could dangle if it ever
  // outlived the graph.
  return Expression<ConstantNode>(shared_from_this(),
                                  keywords::init = [prob](Tensor t) {
                                    Dropout(t, prob);
                                  },
                                  keywords::shape = shape);
}

/**
 * NaN check hook for a tensor.
 *
 * The actual NaN test is not implemented yet:
 *   - when throwNaN_ is set, this aborts unconditionally with
 *     "Not implemented";
 *   - otherwise it does nothing.
 *
 * @param t tensor to check (currently unused).
 */
void ExpressionGraph::checkNan(Tensor t) {
  // TODO: replace with the real check once an IsNan(Tensor) helper exists:
  //   ABORT_IF(throwNaN_ && IsNan(t), "Tensor has NaN");
  (void)t;  // silence unused-parameter warnings until the check is implemented
  ABORT_IF(throwNaN_, "Not implemented");
}

}  // namespace marian
|