Welcome to mirror list, hosted at ThFree Co, Russian Federation.

nth_element.cpp « translator « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: f99b0be424d91f1c3649b41e838082fad04e6f0b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/* All or part of this file was contributed by Intel under license:
 *   Copyright (C) 2017-2018 Intel Corporation
 *   SPDX-License-Identifier: MIT
 */

#include "translator/nth_element.h"
#include <algorithm>
#include <iterator>
#include <limits>
#include <numeric>

namespace marian {

/**
 * CPU implementation of n-best selection over a buffer of scores.
 *
 * For every batch entry the N highest-scoring elements are selected; their flat
 * indices into the score buffer and their score values are staged in the member
 * buffers below and then appended to the caller's output vectors.
 */
class NthElementCPU {
  std::vector<int> h_res_idx;  // flat indices (into the score buffer) of the selected elements
  std::vector<float> h_res;    // scores of the selected elements, parallel to h_res_idx
  //size_t lastN_;

public:
  NthElementCPU() {}
  NthElementCPU(const NthElementCPU& copy) = delete;

private:
  /**
   * For each batch entry, select the max N elements, where N is the beam size
   * for this batch entry. The selected elements (their current value and their
   * index in 'scores') are recorded in h_res / h_res_idx, and each selected
   * element of 'scores' is then overwritten with the lowest float value, such
   * that it won't be a maximum if we're called again on the same input.
   *
   * @param scores                 score buffer; modified in place (see above)
   * @param batchFirstElementIdxs  offsets into 'scores': entry b is the first
   *                               element of batch entry b, the last entry is
   *                               the total number of scores
   * @param cumulativeBeamSizes    entry b is the first output position of batch
   *                               entry b; consecutive differences are the
   *                               per-batch beam sizes
   *
   * Aborts if a beam size exceeds the number of scores of its batch entry.
   * h_res and h_res_idx must already be large enough to hold all results.
   */
  void selectNBest(float* scores,
                   const std::vector<int>& batchFirstElementIdxs,
                   const std::vector<int>& cumulativeBeamSizes) {
    const int numProbs = batchFirstElementIdxs.back();
    std::vector<int> idxs(numProbs);
    std::iota(idxs.begin(), idxs.end(), 0);

    const size_t numBatches = batchFirstElementIdxs.size() - 1;
    for(size_t batchIdx = 0; batchIdx < numBatches; ++batchIdx) {
      int pos = cumulativeBeamSizes[batchIdx];
      const int beamSize = cumulativeBeamSizes[batchIdx + 1] - pos;

      const int firstIdx = batchFirstElementIdxs[batchIdx];
      const int lastIdx  = batchFirstElementIdxs[batchIdx + 1];
      // std::partial_sort requires middle to lie within [begin, end]; without this
      // check a beam larger than the candidate range would be undefined behaviour.
      ABORT_IF(beamSize > lastIdx - firstIdx,
               "Beam size exceeds the number of scores available for a batch entry");

      std::vector<int>::iterator begin  = idxs.begin() + firstIdx;
      std::vector<int>::iterator middle = begin + beamSize;
      std::vector<int>::iterator end    = idxs.begin() + lastIdx;
      // Moves the indices of the beamSize largest scores to the front, in descending score order.
      std::partial_sort(
          begin, middle, end, [scores](int a, int b) { return scores[a] > scores[b]; });

      for(; begin != middle; ++begin, ++pos) {
        const int idx = *begin;
        h_res_idx[pos] = idx;
        h_res[pos] = scores[idx];
        scores[idx] = std::numeric_limits<float>::lowest();
      }
    }
  }

public:
  /**
   * Select the N best-scoring elements for every batch entry of 'scores'.
   *
   * @param scores         score tensor; the selected entries of its data are
   *                       overwritten with the lowest float value
   * @param N              number of elements to select per batch entry
   * @param outPathScores  the selected scores are appended to this vector
   * @param outKeys        the flat indices (into the tensor data) of the selected
   *                       scores are appended to this vector
   * @param isFirst        if true, the tensor's beam dimension must be 1,
   *                       otherwise it must be N
   *
   * Aborts if the beam dimension of 'scores' does not match the above, or if N
   * exceeds the number of scores per batch entry.
   */
  // @BUGBUG: This API mixes input and output beam size.
  void getNBestList(Tensor scores, // [dimBatch, 1, beamSize, dimVocab or dimShortlist]
                    size_t N,
                    std::vector<float>& outPathScores,
                    std::vector<unsigned>& outKeys,
                    const bool isFirst) {
    const auto vocabSize = scores->shape()[-1];
    const auto inputN    = scores->shape()[-2];
    const auto dimBatch  = scores->shape()[-4];
    ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??");

    const size_t numBatches = static_cast<size_t>(dimBatch);
    const int beamSize = static_cast<int>(N);
    std::vector<int> cumulativeBeamSizes(numBatches + 1, 0);
    std::vector<int> batchFirstElementIdxs(numBatches + 1, 0);

    for(size_t batchIdx = 0; batchIdx < numBatches; ++batchIdx) {
      const size_t next = batchIdx + 1;
      cumulativeBeamSizes[next] = cumulativeBeamSizes[batchIdx] + beamSize;
      ABORT_IF(static_cast<size_t>(cumulativeBeamSizes[next]) != next * N, "cumulativeBeamSizes wrong??");

      // Number of input hypotheses up to and including this batch entry: one per
      // batch entry on the first step, N per batch entry afterwards.
      const int numInputHyps = isFirst ? static_cast<int>(next) : cumulativeBeamSizes[next];
      batchFirstElementIdxs[next] = static_cast<int>(numInputHyps * vocabSize);
      ABORT_IF(numInputHyps != static_cast<int>(next) * inputN, "inputN wrong??");
    }
    ABORT_IF(static_cast<size_t>(cumulativeBeamSizes.back()) != numBatches * N, "cumulativeBeamSizes.back() wrong??");

    const size_t maxSize = N * numBatches;
    h_res.resize(maxSize);
    h_res_idx.resize(maxSize);

    selectNBest(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes);
    getPairs(cumulativeBeamSizes.back(), outKeys, outPathScores);
  }

private:
  /**
   * Append the first 'number' staged results to the output vectors:
   * indices from h_res_idx go to 'outKeys', scores from h_res go to 'outValues'.
   * Existing contents of the output vectors are kept.
   */
  void getPairs(size_t number,
                std::vector<unsigned>& outKeys,
                std::vector<float>& outValues) {
    outKeys.insert(outKeys.end(), h_res_idx.begin(), h_res_idx.begin() + number);
    outValues.insert(outValues.end(), h_res.begin(), h_res.begin() + number);
    //lastN_ = number;
  }

  //void getValueByKey(std::vector<float>& out, float* d_in) {
  //  for(size_t i = 0; i < lastN_; ++i) {
  //    out[i] = d_in[h_res_idx[i]];
  //  }
  //}
};

// Forward declaration of the GPU factory; it is only declared when CUDA_FOUND is defined.
#ifdef CUDA_FOUND
GetNBestListFn createGetNBestListGPUFn(size_t beamSize, size_t dimBatch, DeviceId deviceId); // in .cu file
#endif

// Factory function.
// Returns a lambda with the same signature as the getNBestList() function.
// When CUDA_FOUND is defined and deviceId.type is DeviceType::gpu, the result of
// createGetNBestListGPUFn() is returned. Otherwise a new NthElementCPU is created and
// the returned lambda, which captures the handle from New<> by value, forwards to it.
GetNBestListFn createGetNBestListFn(size_t beamSize, size_t dimBatch, DeviceId deviceId) {
#ifdef CUDA_FOUND
  if(deviceId.type == DeviceType::gpu)
    return createGetNBestListGPUFn(beamSize, dimBatch, deviceId);
#else
  // Without CUDA none of the arguments are used. Cast them to void to mark this
  // explicitly; a bare `name;` statement itself triggers unused-value warnings.
  (void)deviceId;
  (void)beamSize;
  (void)dimBatch;
#endif
  auto nth = New<NthElementCPU>();
  return [nth](Tensor logProbs, size_t N, std::vector<float>& outCosts, std::vector<unsigned>& outKeys, const bool isFirst) {
    return nth->getNBestList(logProbs, N, outCosts, outKeys, isFirst);
  };
}

}  // namespace marian