Welcome to mirror list, hosted at ThFree Co, Russian Federation.

node_operators.h « graph « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 01c7f59047e18385b9ee85f876b5bd709d531499 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#pragma once

#include "graph/node.h"
#include "graph/node_initializers.h"
#include "tensors/tensor.h"

namespace marian {

struct ConstantNode : public Node {
  ConstantNode(Ptr<ExpressionGraph> graph,
               const Shape& shape,
               const NodeInitializer& init,
               Type value_type = Type::float32)
      : Node(graph, shape, value_type),
        init_(new NodeInitializer(init)),
        initialized_(false) {
    setTrainable(false);
  }

  ~ConstantNode() {}

  virtual size_t allocate() override;
  virtual void init() override;

  const std::string type() override { return "const"; }

  const std::string form() override { return "diamond"; }

  const std::string color() override { return "white"; }

  virtual size_t hash() override {
    size_t seed = util::hash<size_t>()((size_t)this);
    return seed;
  }

  virtual bool equal(Expr node) override { return this == node.get(); }
  virtual void record(Ptr<AutoTunerRecorder>, size_t, bool) override{};

private:
  UPtr<NodeInitializer> init_;
  bool initialized_;
};

struct ParamNode : public Node {
  ParamNode(Ptr<ExpressionGraph> graph,
            const Shape& shape,
            const NodeInitializer& init,
            bool fixed = false);

  ~ParamNode() {}

  virtual size_t allocate() override {
    ABORT_IF(!val_, "Parameters should be allocated by their graph. Parameter {} was not", name_);
    return 0;
  }

  virtual void init() override;

  const std::string type() override { return "param"; }

  const std::string form() override { return "hexagon"; }

  const std::string color() override { return "orangered"; }

  virtual size_t hash() override {
    size_t seed = util::hash<size_t>()((size_t)this);
    return seed;
  }

  virtual bool equal(Expr node) override { return name() == node->name(); }

  virtual void record(Ptr<AutoTunerRecorder>, size_t, bool) override{};

private:
  UPtr<NodeInitializer> init_;
  bool initialized_;
};
}  // namespace marian