Welcome to mirror list, hosted at ThFree Co, Russian Federation.

expression_graph.cpp « graph « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 520476aed0a1668cd1b208f930ad1eb013608641 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include <sstream>
#include "graph/expression_graph.h"

#include "tensors/tensor_operators.h"

namespace marian {

// Construct an expression graph. `inference` is stored in inferenceOnly_.
// No backend is created here: backend_ stays null until setDevice() is called.
ExpressionGraph::ExpressionGraph(bool inference)
    : inferenceOnly_(inference), backend_(nullptr) {}

// Bind the graph to a device. On the first call this creates the backend for
// `deviceId` (seeded with Config::seed). It also creates a Parameters store
// initialized on that backend and a TensorAllocator for the same backend.
// Only the first call has an effect. Once backend_ exists, later calls return
// without doing anything, so a different deviceId passed afterwards is ignored.
void ExpressionGraph::setDevice(DeviceId deviceId) {
  if(!backend_) {
    backend_ = BackendByDevice(deviceId, Config::seed);
    params_ = New<Parameters>();
    params_->init(backend_);
    tensors_ = New<TensorAllocator>(backend_);
  }
}

// Create a constant expression of the given `shape`. Its initializer runs
// Dropout(t, prob) on the node's tensor.
// NOTE(review): this presumably fills the tensor with a dropout mask for
// probability `prob`. Confirm against Dropout in tensors/tensor_operators.h.
// The init functor captures only `prob` by value. It never touches the graph,
// so capturing `this` (as before) was unused and only stored a raw pointer.
Expr ExpressionGraph::dropout(float prob, Shape shape) {
  return Expression<ConstantNode>(shared_from_this(),
                                  keywords::init = [prob](Tensor t) {
                                    Dropout(t, prob);
                                  },
                                  keywords::shape = shape);
}

// NaN checking is not implemented yet. If throwNaN_ is set, this aborts with
// "Not implemented". Otherwise it returns without inspecting the tensor, which
// is why the parameter is left unnamed.
// TODO: restore the real check below once a NaN-detection operator exists.
void ExpressionGraph::checkNan(Tensor /*t*/) {
  ABORT_IF(throwNaN_, "Not implemented");
  //ABORT_IF(throwNaN_ && IsNan(t), "Tensor has NaN");
}
}  // namespace marian