Welcome to mirror list, hosted at ThFree Co, Russian Federation.

factory.h « layers « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: f9e4ddf92a0ca08dafd4672198e9c02c9c9c577a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#pragma once

#include "marian.h"

namespace marian {

/**
 * Base class for layer factories: holds a pointer to an Options object into which
 * named configuration values are accumulated before a layer is built.
 *
 * Copying a Factory is shallow: the copy holds the same `options_` pointer, so
 * options set through one copy are visible through the other.
 */
class Factory : public std::enable_shared_from_this<Factory> {
protected:
  Ptr<Options> options_;  // accumulated options; allocated by every non-copy constructor

public:
  // construct with empty options
  Factory() : options_(New<Options>()) {}
  // construct with options; `options` is merged into a freshly allocated Options object
  Factory(Ptr<Options> options) : Factory() {
    options_->merge(options);
  }
  // construct with one or more individual option parameters
  // Factory("var1", val1, "var2", val2, ...)
  template <typename T, typename... Args>
  Factory(const std::string& key, T value, Args&&... moreArgs) : Factory() {
    setOpts(key, value, std::forward<Args>(moreArgs)...);
  }
  // construct with options and one or more individual option parameters
  // Factory(options, "var1", val1, "var2", val2, ...)
  template <typename... Args>
  Factory(Ptr<Options> options, Args&&... args) : Factory(options) {
    setOpts(std::forward<Args>(args)...);
  }
  // copies share the same Options object (the pointer is copied, not the options)
  Factory(const Factory& factory) = default;
  // Declared explicitly: with a user-declared copy constructor and destructor the
  // implicitly generated copy assignment is deprecated. Same shallow semantics.
  Factory& operator=(const Factory& factory) = default;

  virtual ~Factory() = default;

  // render the currently accumulated options as a YAML string
  std::string asYamlString() const { return options_->asYamlString(); }

  // retrieve an option
  // auto val = opt<T>("var");
  template <typename T>
  T opt(const char* const key) const { return options_->get<T>(key); }

  // retrieve an option, falling back to `defaultValue` (semantics of Options::get)
  template <typename T>
  T opt(const char* const key, T defaultValue) const { return options_->get<T>(key, defaultValue); }

  template <typename T>
  T opt(const std::string& key) const { return options_->get<T>(key.c_str()); }

  template <typename T>
  T opt(const std::string& key, T defaultValue) const { return options_->get<T>(key.c_str(), defaultValue); }

  // set a single option
  // setOpt("var", val);
  template <typename T>
  void setOpt(const std::string& key, T value) { options_->set(key, value); }

  // set one or more options at once
  // setOpts("var1", val1, "var2", val2, ...);
  template <typename T, typename... Args>
  void setOpts(const std::string& key, T value, Args&&... moreArgs) { options_->set(key, value, std::forward<Args>(moreArgs)...); }

  // merge `options` into this factory's options
  void mergeOpts(Ptr<Options> options) { options_->merge(options); }

  // Downcast to a shared pointer of type Cast (null if *this is not a Cast).
  // Requires *this to be owned by a shared_ptr: otherwise shared_from_this() is
  // undefined behaviour before C++17 and throws std::bad_weak_ptr since C++17.
  template <class Cast>
  inline Ptr<Cast> as() { return std::dynamic_pointer_cast<Cast>(shared_from_this()); }

  // Test whether *this is a Cast. Uses a plain dynamic_cast on `this` (note the
  // pointer target type `const Cast*`), so unlike as() it does not require shared
  // ownership and does not touch the reference count.
  template <class Cast>
  inline bool is() const { return dynamic_cast<const Cast*>(this) != nullptr; }
};

/**
 * Simplest kind of Factory: hands its accumulated options, unchanged, to the
 * constructor of the layer type `Class`.
 */
template <class Class>
struct ConstructingFactory : public Factory {
  // reuse Factory's constructors (empty, from options, from key/value pairs, ...)
  using Factory::Factory;

  // Build a new `Class` from `graph` and this factory's options pointer.
  Ptr<Class> construct(Ptr<ExpressionGraph> graph) {
    auto layer = New<Class>(graph, options_);
    return layer;
  }
};

/**
 * Wraps a factory type `BaseFactory` (expected to derive from Factory) and adds
 * the legacy chaining syntax: acc("key1", v1)("key2", v2)(options)...
 * Every constructor forwards its arguments to the matching BaseFactory constructor.
 */
template <class BaseFactory> // where BaseFactory : Factory
class Accumulator : public BaseFactory {
  typedef BaseFactory Factory;

public:
  Accumulator() : Factory() {}
  Accumulator(Ptr<Options> options) : Factory(options) {}
  template <typename... Args>
  Accumulator(Ptr<Options> options, Args&&... moreArgs) : Factory(options, std::forward<Args>(moreArgs)...) {}
  template <typename T, typename... Args>
  Accumulator(const std::string& key, T value, Args&&... moreArgs) : Factory(key, value, std::forward<Args>(moreArgs)...) {}
  Accumulator(const Factory& factory) : Factory(factory) {}
  Accumulator(const Accumulator&) = default;
  // NOTE(review): if BaseFactory has no move constructor, this falls back to the
  // base copy constructor.
  Accumulator(Accumulator&&) = default;
  // The user-declared move constructor above makes the implicit copy assignment
  // deleted. The user-declared copy constructor suppresses the implicit move
  // assignment. Default both so an Accumulator can be assigned (whenever
  // BaseFactory itself is assignable).
  Accumulator& operator=(const Accumulator&) = default;
  Accumulator& operator=(Accumulator&&) = default;

  // deprecated chaining syntax: set option `key` to `value` and return *this
  template <typename T>
  Accumulator& operator()(const std::string& key, T value) {
    Factory::setOpt(key, value);
    return *this;
  }

  // deprecated chaining syntax: merge `options` into this factory and return *this
  Accumulator& operator()(Ptr<Options> options) {
    Factory::mergeOpts(options);
    return *this;
  }

  // NOTE(review): only compiles when BaseFactory provides clone(); the Factory
  // base class does not define one, so this is instantiable only for such bases.
  Accumulator<Factory> clone() {
    return Accumulator<Factory>(Factory::clone());
  }
};
}  // namespace marian