Welcome to mirror list, hosted at ThFree Co, Russian Federation.

graph_group.h « training « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
blob: 0e4a68dccd477c0d604eb6d5cec8aac65c393d67 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#pragma once

#include <cmath>
#include <limits>

#include "common/definitions.h"
#include "common/options.h"
#include "data/batch_generator.h"
#include "graph/expression_graph.h"
#include "models/model_base.h"
#include "optimizers/optimizers.h"
#include "training/scheduler.h"
#include "training/communicator.h"

namespace marian {

// Returns true if x is neither NaN nor +/-Inf.
// With -Ofast (which implies -ffinite-math-only) gcc may assume NaN/Inf never occur and fold
// std::isfinite() to 'true', silently disabling every NaN/Inf check built on this function.
// Safeguard here: if a known NaN is reported as finite, abort instead of training blindly.
static inline bool isFinite(float x) {
#ifdef __GNUC__
  // Probe with a quiet-NaN constant rather than 0.f / 0.f: same NaN value, but without a
  // floating-point division by zero (formally undefined in ISO C++, flagged by UBSan).
  ABORT_IF(std::isfinite(std::numeric_limits<float>::quiet_NaN()), "NaN detection unreliable. Disable -Ofast compiler option.");
#endif
  return std::isfinite(x);
}

#ifdef _MSC_VER // MS Visual Studio claims this function is unreferenced, although it is referenced by name (passed as an argument)
#pragma warning(push)
#pragma warning(disable: 4505) // Unreferenced local function has been removed
#endif
// Accumulates gradient norms: combines two L2 norms into the norm of their concatenation,
// lhs = sqrt(lhs^2 + rhs^2) (i.e. undo sqrt, sum, re-apply sqrt).
// std::hypot performs this without forming the squares explicitly, so the intermediate result
// cannot overflow to Inf (for |x| > ~1.8e19) or underflow to zero for very small norms.
// If either value is non-finite, NaN is propagated into the reduction.
static inline void accNanOrNorm(float& lhs, float rhs) {
  if(isFinite(lhs) && isFinite(rhs))
    lhs = std::hypot(lhs, rhs);
  else
    lhs = std::numeric_limits<float>::quiet_NaN();
}
#ifdef _MSC_VER
#pragma warning(pop)
#endif

/**
 *  Base class for managing the training process across one or multiple GPUs,
 *  or even multiple machines with multiple GPUs.
 *
 *  Concrete graph groups implement update(), setScheduler() and
 *  collectStats(vocabs), and decide how to fill graphs_, models_ and
 *  optimizerShards_.
 */
class GraphGroup {
protected:
  Ptr<Options> options_;

  Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
  Ptr<IMPIWrapper> mpi_;    // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run)

  std::vector<DeviceId> devices_;                   // [deviceIndex]
  ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do full sync (faster). If global, shard across the entire set of GPUs (more RAM).

  // common for all graph groups, individual graph groups decide how to fill them
  std::vector<Ptr<ExpressionGraph>> graphs_;            // [deviceIndex]
  std::vector<Ptr<models::ICriterionFunction>> models_; // [deviceIndex]
  std::vector<Ptr<OptimizerBase>> optimizerShards_;     // [deviceIndex]

  Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed

  bool finalized_{false};          // 'true' if training has completed (further updates are no longer allowed)
  double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
  bool mbRoundUp_{true};           // round up batches for more efficient training but can make batch size less stable, disable with --mini-batch-round-up=false

  // Cost-scaling settings and state; the factor is adjusted by
  // increaseCostScaleFactor() / decreaseCostScaleFactor().
  bool costScale_{false};
  float costScaleFactor_{1.f}; // @TODO, add current costScaleFactor_ to trainingState for serialization
  size_t costScaleFreq_{2000};
  float costScaleMultiplier_{2.f};
  float costScaleNanTolerance_{0.f};
  size_t costScaleNanRange_{1};
  float costScaleFactorMinimum_{1.f}; // @TODO make this configurable
  size_t noNanSeen_{0}; // @TODO, add current noNanSeen_ to trainingState for serialization
  size_t nanSeen_{0};

  bool dynamicGradientScaling_{false};
  float dynamicGradientScalingFactor_{2.f};
  bool dynamicGradientScalingUseLogs_{false};

  bool checkGradientNan_{false};

  // determines the number of input streams (i.e. input files or fields in the TSV input) that need
  // to be included in the batch, i.e. without alignments and weights
  size_t numberOfInputFiles();

public:
  GraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi);
  // 'explicit': a bare options pointer must never silently convert into a GraphGroup.
  explicit GraphGroup(Ptr<Options> options);

  void initGraphsAndOpts();

  // Virtual so derived graph groups are destroyed correctly through a base-class pointer.
  virtual ~GraphGroup() = default;

  // Processes one batch; implemented by each concrete graph group.
  virtual void update(Ptr<data::Batch> batch) = 0;

  // increase cost-scaling factor if no NaN has been detected for a
  // given number of iterations. Usually we increase by 2 which adds
  // one more bit for precision.
  void increaseCostScaleFactor();

  // call when a NaN was seen to decrease cost-scaling factor
  void decreaseCostScaleFactor();

  virtual void load();
  virtual void save(bool isFinal = false);

private:
  void load(const OptimizerBase::ScatterStateFunc& scatterFn);
  void save(bool isFinal,
            const OptimizerBase::GatherStateFunc& gatherOptimizerStateFn);

  // NOTE(review): the bool result presumably reports whether a checkpoint was restored — confirm.
  bool restoreFromCheckpoint(const std::string& modelFileName,
                             const OptimizerBase::ScatterStateFunc& scatterFn);

  void saveCheckpoint(const std::string& modelFileName,
                      const OptimizerBase::GatherStateFunc& gatherFn);

public:
  void swapWithSmoothed();

  bool isMainProcess() const { return mpi_->isMainProcess(); } // (we need this test a few times)
  void barrier() const { mpi_->barrier(); } // (we need this several times)

  void validate();

  virtual void finalize();

  virtual void setScheduler(Ptr<Scheduler> scheduler) = 0;

  // checkNanOrNorm(i, begin, end) has exactly the task signature accepted by
  // executeAndCollectNorm. NOTE(review): presumably 'i' is a device index and
  // [begin, end) a shard range — confirm against the implementation.
  float checkNanOrNorm(size_t i, size_t begin, size_t end);
  float executeAndCollectNorm(const std::function<float(size_t, size_t, size_t)>& task);

  float computeNormalizationFactor(float gNorm, size_t updateTrgWords);

  /**
   * Determine maximal batch size that can fit into the given workspace
   * so that reallocation does not happen. Rather adjust the batch size
   * based on the statistics collected here. Activated with
   * `--mini-batch-fit`.
   * In a multi-GPU scenario, the first GPU is used to determine the size.
   * The actual allowed size is then determined by multiplying it with the
   * number of devices, which is passed in as the 'multiplier'.
   */
  // @TODO: Can this be made const? It seems wrong to have a stateful method that still returns a result.
  virtual Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
                                             Ptr<models::ICriterionFunction> model,
                                             const std::vector<Ptr<Vocab>>& vocabs,
                                             double multiplier = 1.);

  virtual Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) = 0;

  void setTypicalTrgBatchWords(size_t typicalTrgBatchWords);
  double getTypicalTrgBatchWords();
  void updateAverageTrgBatchWords(size_t trgBatchWords);
};

}  // namespace marian