Welcome to mirror list, hosted at ThFree Co, Russian Federation.

LookupTableBag.cu « THCUNN « lib - github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 1042214d7420fa93fd5e52ac25709f2b884f14a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include "THCUNN.h"
#include "common.h"

#include "THCThrustAllocator.cuh"
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/transform_reduce.h>
#if CUDA_VERSION >= 7000
#include <thrust/system/cuda/execution_policy.h>
#endif
#include <thrust/unique.h>
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCTensorSort.cuh"

// Feature-dimension step between the successive values handled by one thread
// in cunn_LookupTableBag_accGradParametersKernel (featureDim = start + ii * WARP_SIZE).
const int WARP_SIZE = 32;
// Bag reduction modes passed to the kernels as `mode`. The kernels in this
// file only test for MODE_MEAN (divide by the bag size); any other value
// leaves the plain sum untouched.
const int MODE_SUM = 0;
const int MODE_MEAN = 1;

// Forward pass of LookupTableBag: for every bag, accumulates the rows of
// `weight` selected by that bag's slice of `input` and writes the sum (or the
// mean, in MODE_MEAN) to `output`.
//
// Parameters:
//   input      - row indices into `weight`, offset by TH_INDEX_BASE
//   offsets    - offsets[b] (minus TH_INDEX_BASE) is the position in `input`
//                where bag b begins; bag b ends where bag b + 1 begins and
//                the last bag ends at numIndices
//   weight     - read as weight[row * stride + feature]
//   output     - written as output[bag * stride + feature]
//   offset2bag - written: offset2bag[i] = (bag containing input[i]) + TH_INDEX_BASE
//   numIndices - end position of the last bag (presumably the length of
//                `input` -- confirm against the caller)
//   numBags    - number of bags processed
//   stride     - number of features per row of `weight` / `output`
//   mode       - MODE_MEAN divides the sum by the bag size; other values sum
//   bag_size   - written only in MODE_MEAN: number of indices in each bag
//
// Work distribution: a "chunk" is (one bag) x (blockDim.x consecutive
// features). threadIdx.y selects the chunk within a block, threadIdx.x the
// feature within the chunk, and the outer loop advances by
// gridDim.x * blockDim.y chunks, so any grid size covers all chunks.
template <typename Dtype, typename Acctype>
__global__ void cunn_LookupTableBag_updateOutputKernel(
  long *input, long *offsets, Dtype *weight, Dtype *output,
  long *offset2bag, long numIndices, long numBags, long stride, int mode,
  long *bag_size) {

  // the strategy here is that each bag x feature is handled by a single thread

  long chunksPerBag = THCCeilDiv(stride, (long) blockDim.x);
  long numChunks = numBags * chunksPerBag;
  // Widen before multiplying: the built-in index variables are 32-bit
  // unsigned, so their product could wrap before being stored in a long.
  long chunkOffset = (long) blockIdx.x * blockDim.y + threadIdx.y;
  long chunkStride = (long) gridDim.x * blockDim.y;

  for (long chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    long featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < stride) {
      long bag = chunk / chunksPerBag;
      Dtype *weightFeat = weight + featureDim;
      long begin = offsets[bag] - TH_INDEX_BASE;
      long end = (bag < numBags - 1) ? (offsets[bag + 1] - TH_INDEX_BASE) : numIndices;
      assert(end >= begin);
      Acctype weightFeatSum = ScalarConvert<float, Acctype>::to(0);
      long bag_size_ = 0;
      for (long emb = begin; emb < end; emb++) {
        // Keep the row offset in 64 bits: row * stride exceeds INT_MAX for
        // large embedding tables, and truncating it addresses the wrong row.
        const long weightRow = (input[emb] - TH_INDEX_BASE) * stride;
        weightFeatSum += ScalarConvert<Dtype, Acctype>::to(weightFeat[weightRow]);
        bag_size_++;
        // Only one thread per bag needs to record the index -> bag mapping.
        if (featureDim == 0) {
          offset2bag[emb] = bag + TH_INDEX_BASE;
        }
      }
      if (mode == MODE_MEAN) {
        // An empty bag has nothing to average; leave the sum at zero instead
        // of computing 0 / 0 and writing NaN to the output.
        if (bag_size_ > 0) {
          weightFeatSum = weightFeatSum / ScalarConvert<long, Acctype>::to(bag_size_);
        }
        // Every thread of the bag stores the same value, so the concurrent
        // writes are harmless.
        bag_size[bag] = bag_size_;
      }
      output[bag * stride + featureDim] = ScalarConvert<Acctype, Dtype>::to(weightFeatSum);
    }
  }
}

// FIXME: the accGradParametersKernelByFeature case present in LookupTable
// was removed here. That kernel is faster at small sizes (<768 indices), a
// regime that does not really need LookupTableBag (LookupTable + Sum works
// fine), but it would still be nice not to be slow in that case.

// Backward pass of LookupTableBag with respect to the weights: adds the
// (scaled) output gradient of each bag into the gradWeight rows of the
// indices that belong to that bag.
//
// Parameters:
//   input      - row indices into gradWeight, offset by TH_INDEX_BASE. Equal
//                values must be adjacent (presumably the sorted index list --
//                confirm against the caller); runs of equal values are
//                handled by a single idx chain so no atomics are used.
//   indices    - indices[i] (minus TH_INDEX_BASE) is the position used to
//                look up offset2bag for input[i] (presumably the pre-sort
//                position -- confirm against the caller)
//   gradOutput - read as gradOutput[bag * stride + feature]
//   gradWeight - updated in place as gradWeight[row * stride + feature]
//   offset2bag - offset2bag[p] (minus TH_INDEX_BASE) is the bag of position p
//   count      - optional (may be NULL); when given, the scale for entry i
//                is defaultScale / count[i]
//   defaultScale - multiplier applied to every gradient contribution
//   numel      - number of entries in input / indices
//   stride     - number of features per row
//   mode       - MODE_MEAN additionally divides by bag_size[bag]
//   bag_size   - per-bag sizes; only read in MODE_MEAN
//
// Launch layout as used by the index math below: blockIdx.x * 4 + threadIdx.y
// selects the input entry (the literal 4 presumably equals blockDim.y --
// confirm against the launch site), and blockIdx.y / threadIdx.x select the
// features, SZ per thread spaced WARP_SIZE apart.
template <typename Dtype, typename Acctype>
__global__ void cunn_LookupTableBag_accGradParametersKernel(
  long *input, long *indices, Dtype *gradOutput, Dtype *gradWeight, long *offset2bag,
  long *count, Dtype defaultScale, ptrdiff_t numel, long stride,
  int mode, long *bag_size) {

  // numel is a ptrdiff_t, so the running index must be 64-bit as well.
  ptrdiff_t idx = (ptrdiff_t) blockIdx.x * 4 + threadIdx.y;

  // Each warp is responsible for an input into the LookupTable.
  // If the preceding input has the same value as this input, then the warp
  // exits immediately. The warp also processes subsequent inputs with the
  // same value.
  //
  // Input Warp
  // 1     <warp 1>
  // 1     <warp 1> (<warp 2> exits without doing any work)
  // 5     <warp 3>
  // 8     <warp 4>

  // Number of values processed by each thread (grain size)
  const int SZ = 4;

  if (idx < numel
      && (idx == 0 || input[idx] != input[idx - 1])) {
    do {
      const long startFeature = (long) blockIdx.y * blockDim.x * SZ + threadIdx.x;
      // Row offsets stay in 64 bits: row * stride can exceed INT_MAX for
      // large tables, and a truncated offset would corrupt another row.
      const long weightRow = (input[idx] - TH_INDEX_BASE) * stride;

      // Note: only this line changes from LookupTable_accgradParametersKernel
      const long origRow = indices[idx] - TH_INDEX_BASE;
      const long seq_number = offset2bag[origRow] - TH_INDEX_BASE;
      const long gradOutputRow = seq_number * stride;

      const Acctype scale = count ? ScalarConvert<Dtype, Acctype>::to(defaultScale) / count[idx] : ScalarConvert<Dtype, Acctype>::to(defaultScale);

      Acctype weight[SZ];

      // Load and accumulate. The accumulation sits under the bounds check so
      // that out-of-range lanes never read uninitialized registers.
      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        const long featureDim = startFeature + ii * WARP_SIZE;
        if (featureDim < stride)
        {
          Acctype gradient = ScalarConvert<Dtype, Acctype>::to(gradOutput[gradOutputRow + featureDim]);
          if (mode == MODE_MEAN) {
            // The forward pass averaged over the bag, so the gradient of
            // each member is divided by the bag size.
            gradient /= bag_size[seq_number];
          }
          weight[ii] = ScalarConvert<Dtype, Acctype>::to(gradWeight[weightRow + featureDim]);
          weight[ii] += gradient * scale;
        }
      }

      // Write the updated values back.
      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        const long featureDim = startFeature + ii * WARP_SIZE;
        if (featureDim < stride)
        {
          gradWeight[weightRow + featureDim] = ScalarConvert<Acctype, Dtype>::to(weight[ii]);
        }
      }

      idx++;
    } while (idx < numel && input[idx] == input[idx - 1]);
  }
}


#include "generic/LookupTableBag.cu"
#include "THCGenerateFloatTypes.h"