blob: 846d61097992ef5a83063d57d76a9e31fc0d5b7e (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
#pragma once
#include <deque>
#include <iostream>
#include <map>
#include <queue>
#include <vector>

#include "data/corpus.h"
#include "data/vocab.h"
namespace marian {
namespace data {
// Lookup table from per-stream sentence lengths to a batch size.
// Entries are recorded with add() and queried with getBatchSize(). The table can be
// serialized to a flat vector with flatten() and rebuilt from it with the
// vector-taking constructor.
class BatchStats {
private:
  std::map<std::vector<size_t>, size_t> map_; // [(src len, tgt len)] -> batch size

  // Returns true if an entry with stream widths 'capacity' can hold a sentence tuple
  // of 'lengths', i.e. capacity[i] >= lengths[i] for every requested stream.
  // An entry with fewer streams than requested never fits (and is never indexed
  // out of bounds).
  static bool fits(const std::vector<size_t>& capacity, const std::vector<size_t>& lengths) {
    if(capacity.size() < lengths.size())
      return false;
    for(size_t i = 0; i < lengths.size(); ++i)
      if(capacity[i] < lengths[i])
        return false;
    return true;
  }

public:
  BatchStats() { }

  // Returns the batch size stored for the first entry (in map order) that can fit
  // sentence tuples of lengths[], i.e. where all item.first[i] >= lengths[i].
  // Aborts with "Missing batch statistics" if no entry fits.
  size_t getBatchSize(const std::vector<size_t>& lengths) const {
    // Keys are ordered lexicographically, so every entry that fits in all streams
    // compares >= lengths; entries before lower_bound() can be skipped safely.
    auto it = map_.lower_bound(lengths); // typ. 20 items, ~4..5 steps
    // All streams must be re-checked for each candidate: moving forward can increase
    // an earlier stream's width while a later stream's width drops again.
    while(it != map_.end() && !fits(it->first, lengths))
      ++it;
    ABORT_IF(it == map_.end(), "Missing batch statistics");
    return it->second;
  }

  // Records a batch: the key is the batchWidth() of each of the batch's sets, the
  // value is batch->size() * multiplier. If the key already exists, the larger
  // value is kept.
  void add(Ptr<data::CorpusBatch> batch, size_t multiplier = 1) {
    const size_t numSets = batch->sets();
    std::vector<size_t> lengths;
    lengths.reserve(numSets);
    for(size_t i = 0; i < numSets; ++i)
      lengths.push_back((*batch)[i]->batchWidth());
    const size_t batchSize = batch->size() * multiplier;
    size_t& stored = map_[lengths]; // value-initialized to 0 if the key is new
    if(stored < batchSize)
      stored = batchSize;
  }

  // helpers for multi-node --note: presently unused, but keeping them around for later use
  // serialize into a flat vector, for MPI data exchange
  // format:
  //  - num streams
  //  - for each entry: (stream sizes..., batch size)
  // An empty table serializes to an empty vector.
  std::vector<size_t> flatten() const {
    std::vector<size_t> res;
    if(map_.empty())
      return res;
    const size_t numStreams = map_.begin()->first.size();
    res.reserve(1 + map_.size() * (numStreams + 1));
    res.push_back(numStreams);
    for(const auto& entry : map_) {
      ABORT_IF(entry.first.size() != numStreams, "inconsistent number of streams??");
      for(auto streamLen : entry.first)
        res.push_back(streamLen);
      res.push_back(entry.second);
    }
    return res;
  }

  // deserialize a flattened batchStats (the format produced by flatten())
  // used as part of MPI data exchange
  // Aborts with "invalid flattenedVector??" if the vector ends in the middle of a record.
  BatchStats(const std::vector<size_t>& flattenedStats) {
    if(flattenedStats.empty())
      return;
    const size_t numStreams = flattenedStats[0];
    std::vector<size_t> lengths;
    size_t i = 1;
    while(i < flattenedStats.size()) {
      // A complete record needs numStreams lengths plus one batch size. Validate
      // before reading so that a truncated vector is never read past its end.
      ABORT_IF(flattenedStats.size() - i <= numStreams, "invalid flattenedVector??");
      const size_t* record = flattenedStats.data() + i;
      lengths.assign(record, record + numStreams);
      map_[lengths] = record[numStreams];
      i += numStreams + 1;
    }
    //dump();
  }

  // Prints one line per entry to stderr: the stream lengths, then ": " and the batch size.
  void dump() const { // (for debugging)
    for(const auto& entry : map_) {
      for(auto streamLen : entry.first)
        std::cerr << streamLen << " ";
      std::cerr << ": " << entry.second << std::endl;
    }
  }
};
} // namespace data
} // namespace marian
|