// Source: src/data/batch_stats.h from github.com/marian-nmt/marian.git
// (blob 846d61097992ef5a83063d57d76a9e31fc0d5b7e), retrieved from a web mirror hosted at ThFree Co.
// The mirror's page navigation text and line-number gutter (1-87) were extraction residue, not part of the file.
#pragma once

#include <deque>
#include <queue>

#include "data/corpus.h"
#include "data/vocab.h"

namespace marian {
namespace data {

/**
 * Lookup table from a tuple of per-stream batch widths to a batch size.
 *
 * Entries are recorded with add() and queried with getBatchSize(), which returns the batch size
 * of the first entry (in lexicographic key order) whose width is at least the requested length
 * in every stream. flatten() and the deserializing constructor convert the table to and from a
 * flat std::vector<size_t>.
 */
class BatchStats {
private:
  std::map<std::vector<size_t>, size_t> map_; // [(src len, tgt len)] -> batch size

  // Returns true if an entry with widths `capacity` can hold a tuple of lengths `lengths`,
  // i.e. capacity[i] >= lengths[i] for every stream i. An entry with fewer streams than
  // `lengths` never fits (and is never indexed out of range).
  static bool fits(const std::vector<size_t>& capacity, const std::vector<size_t>& lengths) {
    if(capacity.size() < lengths.size())
      return false;
    for(size_t i = 0; i < lengths.size(); ++i)
      if(capacity[i] < lengths[i])
        return false;
    return true;
  }

public:
  BatchStats() { }

  /**
   * Finds the batch size for a tuple of lengths.
   *
   * @param lengths one length per stream
   * @return the batch size stored for the first entry, in lexicographic key order, for which
   *         all item.first[i] >= lengths[i]
   *
   * Aborts via ABORT_IF with "Missing batch statistics" if no entry fits.
   */
  size_t getBatchSize(const std::vector<size_t>& lengths) const {
    // Every entry that fits compares lexicographically >= lengths, so the search can start
    // at lower_bound instead of begin().
    auto it = map_.lower_bound(lengths); // typ. 20 items, ~4..5 steps

    // All streams must be checked on each candidate: moving forward in lexicographic order only
    // guarantees that stream 0 does not decrease. Later streams can drop again whenever an
    // earlier stream increases, so checking the streams one after another is not sufficient
    // once there are three or more of them.
    while(it != map_.end() && !fits(it->first, lengths))
      ++it;

    ABORT_IF(it == map_.end(), "Missing batch statistics");
    return it->second;
  }

  /**
   * Records a batch. The key is the batchWidth() of each of the batch's sets(); the value is
   * batch->size() * multiplier. If the key is already present, the larger value is kept.
   *
   * @param batch      batch to record
   * @param multiplier factor applied to batch->size() before storing (default 1)
   */
  void add(Ptr<data::CorpusBatch> batch, size_t multiplier = 1) {
    std::vector<size_t> lengths;
    lengths.reserve(batch->sets());
    for(size_t i = 0; i < batch->sets(); ++i)
      lengths.push_back((*batch)[i]->batchWidth());
    size_t batchSize = batch->size() * multiplier;

    // operator[] value-initializes a missing entry to 0, so one lookup covers insert and update
    auto& stored = map_[lengths];
    if(stored < batchSize)
      stored = batchSize;
  }

  // helpers for multi-node  --note: presently unused, but keeping them around for later use
  /**
   * Serializes the table into a flat vector, for MPI data exchange.
   *
   * Format:
   *  - num streams
   *  - for each entry: (stream sizes..., batch size)
   *
   * An empty table serializes to an empty vector. Aborts via ABORT_IF if the entries do not all
   * have the same number of streams.
   */
  std::vector<size_t> flatten() const {
    std::vector<size_t> res;
    if(map_.empty())
      return res;
    auto numStreams = map_.begin()->first.size();
    res.reserve(1 + map_.size() * (numStreams + 1)); // header + one fixed-size record per entry
    res.push_back(numStreams);
    for (const auto& entry : map_) {
      ABORT_IF(entry.first.size() != numStreams, "inconsistent number of streams??");
      for (auto streamLen : entry.first)
        res.push_back(streamLen);
      res.push_back(entry.second);
    }
    return res;
  }

  /**
   * Deserializes a flattened batchStats, as produced by flatten().
   * Used as part of MPI data exchange.
   *
   * An empty vector yields an empty table. Aborts via ABORT_IF with "invalid flattenedVector??"
   * if the data after the header is not a whole number of records.
   */
  BatchStats(const std::vector<size_t>& flattenedStats) {
    if (flattenedStats.empty())
      return;
    const size_t total = flattenedStats.size();
    const size_t numStreams = flattenedStats[0];
    size_t i = 1;
    std::vector<size_t> lengths;
    while (i < total) {
      // A record is numStreams lengths followed by one batch size. Validate before reading so a
      // truncated vector is rejected instead of being read out of bounds. The comparison is
      // written this way so that it cannot overflow for any value of numStreams.
      ABORT_IF(total - i <= numStreams, "invalid flattenedVector??");
      lengths.assign(flattenedStats.begin() + i, flattenedStats.begin() + i + numStreams);
      i += numStreams;
      map_[lengths] = flattenedStats[i++];
    }
    //dump();
  }

  // Prints one line per entry to std::cerr: the stream lengths, then ": " and the batch size.
  void dump() const { // (for debugging)
    for (const auto& entry : map_) {
      for (auto streamLen : entry.first)
        std::cerr << streamLen << " ";
      std::cerr << ": " << entry.second << std::endl;
    }
  }
};
}  // namespace data
}  // namespace marian