Welcome to mirror list, hosted at ThFree Co, Russian Federation.

IndexLinear.cu « generic « THCUNN « lib - github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: ae961484552da30b1ee782095e78c24f24e0ceba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/IndexLinear.cu"
#else

// Returns true when `keys` and `values` are both 1-D and hold the same number
// of elements.
//
// The rank checks are evaluated first: because `&&` short-circuits,
// size(keys, 0) is only queried once `keys` is known to be 1-D. A keys tensor
// of any other rank therefore just makes the check fail, instead of querying a
// dimension that may not exist.
static bool THCUNN_checkKeysValues(THCState *state, THCudaLongTensor* keys,
                                   THCTensor* values)
{
    return THCudaLongTensor_nDimension(state, keys) == 1
        && THCTensor_(nDimension)(state, values) == 1
        && THCudaLongTensor_size(state, keys, 0) == THCTensor_(nElement)(state, values);
}

// Forward pass of IndexLinear: resizes `output` to (batchSize x outDim) and
// fills it from the sparse batch described by (keys, values, cumSumSizes)
// together with `weight` and `bias`.
//
//   keys, values      1-D tensors of the same length (keysSize); enforced below.
//   keysOffset        forwarded to the kernel together with the keys
//                     (NOTE(review): presumably a base subtracted from each key).
//   sizes             only size[0] is read here, and it is used as batchSize.
//   cumSumSizes       per-row boundaries read by the kernel (presumably the
//                     running sum of `sizes`; confirm against the caller).
//   weight            has wDim columns; maxNormalize = wDim - outDim is
//                     forwarded to the kernel.
//   normalizedValues  resized to keysSize and written by the kernel only when
//                     maxNormalize != 0 and train != 0; untouched otherwise.
//
// All work is queued on the current THC stream. Launch-configuration errors are
// checked immediately after the launch. Faults inside the kernel surface at the
// next synchronizing call.
void THNN_(IndexLinear_updateOutput)(
    THCState *state,
    THCudaLongTensor *keys,
    long keysOffset,
    THCTensor *values,
    THCudaLongTensor *sizes,
    THCudaLongTensor *cumSumSizes,
    THCTensor *output,
    THCTensor *weight,
    THCTensor *bias,
    THCTensor *normalizedValues,
    int   train)
{
    // Make sure these inputs are contiguous to accelerate computations
    THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
               "keys vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, values), 3,
               "values vector must be contiguous");
    THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
               "sizes vector must be contiguous");
    THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
               "cumSumSizes vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, output), 6,
               "output vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, weight), 7,
               "weight matrix must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, bias), 8,
               "bias vector must be contiguous");
    THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
               "Keys and values should have the same number of elements");

    long batchSize = sizes->size[0];
    long outDim = bias->size[0];
    long wDim = weight->size[1];
    long weightStride = weight->stride[0];
    int maxNormalize = wDim - outDim;
    long keysSize = keys->size[0];

    THCTensor_(resize2d)(state, output, batchSize, outDim);

    // An empty batch yields an empty output. Return before divup(keysSize,
    // batchSize) is evaluated with a zero divisor, and before a grid whose y
    // dimension is 0 (an invalid launch configuration) is launched.
    if (batchSize == 0) {
        return;
    }

    long nnzPerRow = divup(keysSize, batchSize);

    long *keysData        = THCudaLongTensor_data (state, keys);
    real *valuesData      = THCTensor_(data)      (state, values);
    long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
    real *biasData        = THCTensor_(data)      (state, bias);
    real *weightData      = THCTensor_(data)      (state, weight);
    real *outData         = THCTensor_(data)      (state, output);

    cudaStream_t stream = THCState_getCurrentStream(state);
    dim3 threads(THREADS_X, THREADS_Y);
    // x: output columns, y: batch rows, z: chunks of each row's nonzeros.
    int blocks_x = divup(outDim, threads.x);
    int blocks_y = batchSize;
    // With a single output column or a single batch row, each z-block takes
    // THREADS_X nonzeros; otherwise it takes NNZ_PER_BLOCK_MAX.
    int nnzPerBlock = ((outDim == 1 || batchSize == 1) ? THREADS_X : NNZ_PER_BLOCK_MAX);
    // NOTE(review): with keysSize == 0, nnzPerRow and therefore blocks_z are
    // presumably 0, and the launch below is rejected; the launch check after it
    // reports that here rather than at a later, unrelated call.
    int blocks_z = divup(nnzPerRow, nnzPerBlock);

    dim3 blocks(blocks_x, blocks_y, blocks_z);

    // With more than one z-block per row, several blocks contribute to the
    // same output element, so the output is cleared first (presumably the
    // kernel accumulates into it in that case; see the kernel).
    if (blocks_z > 1) {
        THCudaCheck(cudaMemsetAsync(outData, 0, outDim * batchSize * sizeof(real), stream));
    }

    real *normalizedValuesData = NULL;
    if (maxNormalize && train) {
        THCTensor_(resize1d)(state, normalizedValues, keysSize);
        normalizedValuesData = THCTensor_(data)(state, normalizedValues);
        updateOutput<real, true><<<blocks, threads, 0, stream>>>
            (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
             batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock);
    } else {
        updateOutput<real, false><<<blocks, threads, 0, stream>>>
            (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
             batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock);
    }
    THCudaCheck(cudaGetLastError());
}

// Computes the parameter gradients of IndexLinear.
//
// gradWeight is resized to (keysSize x outDim), or to (keysSize x 2*outDim)
// when maxNormalize = weight->size[1] - outDim is positive. It is then filled
// by the accGradWeight kernel from gradOutput, values and cumSumSizes.
// gradBias is written by the accGradBias kernel from gradOutput. scale and
// weightDecay are forwarded to both kernels.
//
// keys, weight and bias are only used here for shape information and argument
// checks. keysOffset is not used, and valuesBuffer is only checked for
// contiguity (NOTE(review): presumably kept for interface compatibility).
void THNN_(IndexLinear_accGradParameters)(
    THCState *state,
    THCudaLongTensor *keys,
    long keysOffset,
    THCTensor *values,
    THCudaLongTensor *sizes,
    THCudaLongTensor *cumSumSizes,
    THCTensor *gradOutput,
    THCTensor *gradWeight,
    THCTensor *gradBias,
    THCTensor *weight,
    THCTensor *bias,
    THCTensor* valuesBuffer,
    accreal weightDecay,
    accreal scale)
{
    long keysSize = keys->size[0];
    long batchSize = sizes->size[0];
    long outDim = bias->size[0];
    long wDim = weight->size[1];
    int maxNormalize = wDim - outDim;

    // Make sure these inputs are contiguous to accelerate computations
    THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
               "keys vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, values), 3,
               "values vector must be contiguous");
    THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
               "sizes vector must be contiguous");
    THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
               "cumSumSizes vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 6,
               "gradOutput vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 7,
               "gradWeight matrix must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, gradBias), 8,
               "gradBias vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, weight), 9,
               "weight matrix must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, bias), 10,
               "bias vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, valuesBuffer), 11,
               "valuesBuffer vector must be contiguous");
    THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
               "Keys and values should have the same number of elements");

    // The gradient row width doubles when the weight carries extra
    // (maxNormalize) columns.
    THCTensor_(resize2d)(state, gradWeight, keysSize, outDim * (maxNormalize > 0 ? 2 : 1));

    real *valuesData      = THCTensor_(data)      (state, values);
    long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
    real *gradOutputData  = THCTensor_(data)      (state, gradOutput);
    real *gradBiasData    = THCTensor_(data)      (state, gradBias);
    real *gradWeightData  = THCTensor_(data)      (state, gradWeight);
    long gradWeightStride = gradWeight->stride[0];

    cudaStream_t stream = THCState_getCurrentStream(state);
    dim3 threads(THREADS_X, THREADS_Y);
    int blocks_x = divup(outDim, threads.x);
    accGradBias<real, false><<<blocks_x, threads, 0, stream>>>
        (gradBiasData, gradOutputData, outDim, batchSize, scale, weightDecay);
    THCudaCheck(cudaGetLastError());

    // The weight-gradient grid has one y-block per batch row. An empty batch
    // would make that dimension 0, which is an invalid launch configuration.
    if (batchSize == 0) {
        return;
    }

    dim3 blocks(blocks_x, batchSize);
    accGradWeight<real><<<blocks, threads, 0, stream>>>
        (gradWeightData, gradOutputData, valuesData, cumSumSizesData, outDim,
         gradWeightStride, scale, weightDecay, maxNormalize);
    THCudaCheck(cudaGetLastError());
}

// Applies the gradient directly to the parameters, without separate gradient
// buffers.
//
// accGradBias<real, true> is launched with `bias` as its destination, using
// gradOutput. accUpdateWeight is then launched with `weight` as its
// destination, once per batch row (batchId), using gradOutput, values,
// cumSumSizes, keys and keysOffset. scale and weightDecay are forwarded to
// both kernels. The per-row launches run in order on the current stream
// (NOTE(review): presumably so that rows sharing a key are applied
// sequentially; confirm).
void THNN_(IndexLinear_accUpdateGradParameters)(
    THCState *state,
    THCudaLongTensor *keys,
    long keysOffset,
    THCTensor *values,
    THCudaLongTensor *sizes,
    THCudaLongTensor *cumSumSizes,
    THCTensor *gradOutput,
    THCTensor *weight,
    THCTensor *bias,
    accreal weightDecay,
    accreal scale)
{
    // Make sure these inputs are contiguous to accelerate computations
    THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
               "keys vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, values), 3,
               "values vector must be contiguous");
    THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
               "sizes vector must be contiguous");
    THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
               "cumSumSizes vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 6,
               "gradOutput vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, weight), 7,
               "weight matrix must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, bias), 8,
               "bias vector must be contiguous");
    THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
               "Keys and values should have the same number of elements");

    long batchSize = sizes->size[0];
    long outDim = bias->size[0];
    long keysSize = keys->size[0];
    long wDim = weight->size[1];
    int maxNormalize = wDim - outDim;

    real *biasData         = THCTensor_(data)      (state, bias);
    real *weightData       = THCTensor_(data)      (state, weight);
    real *gradOutputData   = THCTensor_(data)      (state, gradOutput);
    real *valuesData       = THCTensor_(data)      (state, values);
    long *keysData         = THCudaLongTensor_data (state, keys);
    long *cumSumSizesData  = THCudaLongTensor_data (state, cumSumSizes);
    long weightStride = weight->stride[0];

    cudaStream_t stream = THCState_getCurrentStream(state);
    dim3 threads(THREADS_X, THREADS_Y);
    int blocks_x = divup(outDim, threads.x);

    accGradBias<real, true><<<blocks_x, threads, 0, stream>>>
        (biasData, gradOutputData, outDim, batchSize, scale, weightDecay);
    THCudaCheck(cudaGetLastError());

    // No weight row is referenced when the batch or the key set is empty.
    // Return early: divup(keysSize, batchSize) would otherwise be evaluated
    // with a zero divisor, and a zero keysSize makes nnzPerRow, and hence
    // blocks_y, zero, which is an invalid launch configuration.
    if (batchSize == 0 || keysSize == 0) {
        return;
    }

    long nnzPerRow = divup(keysSize, batchSize);
    int blocks_y = divup(nnzPerRow, REPEAT * threads.y);
    dim3 blocks(blocks_x, blocks_y);

    for (long batchId = 0; batchId < batchSize; batchId++) {
        accUpdateWeight<real><<<blocks, threads, 0, stream>>>
            (weightData, weightStride, gradOutputData, outDim, valuesData,
             cumSumSizesData, keysData, keysOffset, scale, weightDecay, maxNormalize,
             batchId);
        THCudaCheck(cudaGetLastError());
    }
}

// Applies previously accumulated gradients to the parameters.
//
// The bias is updated on the host side of THC as
// bias = bias + (-learningRate) * gradBias. The weight rows are then updated by
// the updateWeight kernel, launched once per batch row (batchId), using
// gradWeight, runningKeys, cumSumSizes and keysOffset. learningRate and
// weightDecay are forwarded to it.
//
// Here the number of rows is taken from cumSumSizes->size[0], and the number
// of keys from runningKeys->size[0].
void THNN_(IndexLinear_updateParameters)(
    THCState *state,
    THCTensor *gradWeight,
    THCTensor *gradBias,
    THCTensor *weight,
    THCTensor *bias,
    THCudaLongTensor *runningKeys,
    THCudaLongTensor *cumSumSizes,
    long keysOffset,
    accreal weightDecay,
    accreal learningRate)
{
    // Make sure these inputs are contiguous to accelerate computations
    THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 1,
               "gradWeight matrix must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, gradBias), 2,
               "gradBias vector must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, weight), 3,
               "weight matrix must be contiguous");
    THArgCheck(THCTensor_(isContiguous)(state, bias), 4,
               "bias vector must be contiguous");
    THArgCheck(THCudaLongTensor_isContiguous(state, runningKeys), 5,
               "runningKeys vector must be contiguous");
    THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 6,
               "cumSumSizes vector must be contiguous");

    long outDim = bias->size[0];
    long wDim = weight->size[1];
    int maxNormalize = wDim - outDim;
    long keysSize = runningKeys->size[0];
    long batchSize = cumSumSizes->size[0];

    THCTensor_(cadd)(state, bias, bias, -learningRate, gradBias);

    // No weight row is referenced when there are no rows or no keys. Return
    // early: divup(keysSize, batchSize) would otherwise be evaluated with a
    // zero divisor, and a zero keysSize makes blocks_y zero, which is an
    // invalid launch configuration.
    if (batchSize == 0 || keysSize == 0) {
        return;
    }

    long gradWeightStride = gradWeight->stride[0];
    long weightStride = weight->stride[0];

    long *keysData        = THCudaLongTensor_data (state, runningKeys);
    long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
    real *gradWeightData  = THCTensor_(data)      (state, gradWeight);
    real *weightData      = THCTensor_(data)      (state, weight);

    dim3 threads(THREADS_X, THREADS_Y);
    long nnzPerRow = divup(keysSize, batchSize);
    int blocks_x = divup(outDim, threads.x);
    int blocks_y = divup(nnzPerRow, REPEAT * threads.y);
    dim3 blocks(blocks_x, blocks_y);
    cudaStream_t stream = THCState_getCurrentStream(state);

    for (long batchId = 0; batchId < batchSize; batchId++) {
        updateWeight<real><<<blocks, threads, 0, stream>>>
            (weightData, gradWeightData, keysData, cumSumSizesData, outDim,
             gradWeightStride, weightStride, keysOffset, learningRate, weightDecay,
             maxNormalize, batchId);
        THCudaCheck(cudaGetLastError());
    }
}
#endif