commit    536f41ad8044ec61afaab9045ab8c84a4137514b
tree      9b7a664c7266a85faa88bbc428c342b8b4e61459
parent    14b181bbcf21e6cba67357eaa36a7bd611c00324
parent    97940f0a81b689657234ee456ac60e35fe72d043
author    Soumith Chintala <soumith@gmail.com>  2017-04-09 20:52:56 +0300
committer GitHub <noreply@github.com>           2017-04-09 20:52:56 +0300
Merge pull request #455 from twitter-forks/indexlinear
Adding Indexlinear
-rw-r--r--   lib/THCUNN/IndexLinear.cu          490
-rw-r--r--   lib/THCUNN/SparseLinear.cu           1
-rw-r--r--   lib/THCUNN/generic/IndexLinear.cu  273
-rw-r--r--   lib/THCUNN/generic/THCUNN.h         54
-rw-r--r--   test.lua                           217
5 files changed, 1034 insertions(+), 1 deletion(-)
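Note on usage (not part of the patch): nn.IndexLinear consumes sparse input as parallel key/value tensors, and the test.lua cases added below exercise two equivalent layouts. The following minimal sketch mirrors those layouts; the sizes, indices, and constructor arguments are illustrative assumptions, not taken from this commit.

-- Minimal usage sketch (illustrative only; assumes cunn is installed).
require 'cunn'

local inputSize, outputSize = 100, 10
local model = nn.IndexLinear(inputSize, outputSize):cuda()

-- Layout 1: a pair of tables, one key tensor and one value tensor per sample.
local keys   = { torch.CudaLongTensor{1, 4, 7}, torch.CudaLongTensor{2, 3} }
local values = { torch.CudaTensor{0.5, -1.0, 2.0}, torch.CudaTensor{1.5, 0.25} }
local out = model:forward({keys, values})                      -- 2 x outputSize

-- Layout 2: flattened keys/values plus per-sample sizes.
local flatKeys   = torch.CudaLongTensor{1, 4, 7, 2, 3}
local flatValues = torch.CudaTensor{0.5, -1.0, 2.0, 1.5, 0.25}
local sizes      = torch.CudaLongTensor{3, 2}
local outFlat = model:forward({flatKeys, flatValues, sizes})   -- should match Layout 1's output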
diff --git a/lib/THCUNN/IndexLinear.cu b/lib/THCUNN/IndexLinear.cu
new file mode 100644
index 0000000..fb2dc93
--- /dev/null
+++ b/lib/THCUNN/IndexLinear.cu
@@ -0,0 +1,490 @@
+#include "THCUNN.h"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+#include "THCAtomics.cuh"
+
+#define divup(a, b) ((a) + (b) - 1) / (b)
+const int THREADS_PER_BLOCK = 256;
+const int THREADS_X = 32;
+const int THREADS_Y = THREADS_PER_BLOCK / THREADS_X;
+const int REPEAT = 32;
+const long NNZ_PER_BLOCK_MAX = 1024;
+
+/* sign MACRO */
+#ifndef clamp
+#define clamp(a, low, high) max(min((a), (high)), (low))
+#endif
+
+#ifndef ATOMIC_REAL_MINMAX
+#define ATOMIC_REAL_MINMAX(func) \
+  __device__ void atomic_##func(double *address, double val) { \
+    unsigned long long int* address_as_ull = (unsigned long long int*)address; \
+    unsigned long long int old = *address_as_ull; \
+    unsigned long long int assumed; \
+    do { \
+      assumed = old; \
+      old = atomicCAS(address_as_ull, assumed, \
+                      __double_as_longlong(func(val, __longlong_as_double(assumed)))); \
+    } while (assumed != old); \
+  } \
+  __device__ void atomic_##func(float *address, float val) { \
+    int* address_as_int = (int*)address; \
+    int old = *address_as_int; \
+    int assumed; \
+    do { \
+      assumed = old; \
+      old = atomicCAS(address_as_int, assumed, \
+                      __float_as_int(func(val, __int_as_float(assumed)))); \
+    } while (assumed != old); \
+  } \
+
+ATOMIC_REAL_MINMAX(max)
+ATOMIC_REAL_MINMAX(min)
+#endif
+
+template<typename Ty, bool train>
+__global__ static
+void updateOutput(
+    Ty *output,
+    Ty *normalizedValues,
+    const Ty *values,
+    const long *cumSumSizes,
+    const long *keys,
+    const long batchSize,
+    const long outDim,
+    Ty *weight,
+    const Ty *bias,
+    const long weightStride,
+    const long keysOffset,
+    const int maxNormalize,
+    const int nnzPerBlock)
+{
+  /*******************************************************
+   * Adapted from the following file in arrayfire
+   * https://github.com/arrayfire/arrayfire/blob/v3.4.1/src/backend/opencl/kernel/csrmm.cl
+   *
+   *******************************************************
+   * Original copyright notice can be seen below:
+   *
+   * Copyright (c) 2016, ArrayFire
+   * All rights reserved.
+   *
+   * This file is distributed under 3-clause BSD license.
+   * The complete license agreement can be obtained at:
+   * http://arrayfire.com/licenses/BSD-3-Clause
+   ********************************************************/
+
+  const long tidx = threadIdx.x;
+  const long tidy = threadIdx.y;
+  const long tid = tidy * blockDim.x + tidx;
+  const long gidx = blockIdx.x * blockDim.x + tidx;
+
+
+  Ty *nWeight = weight;
+  // Offset the number of elements specified by maxNormalize
+  weight += gidx + maxNormalize;
+  output += gidx;
+
+  bool within_N = (gidx < outDim);
+
+  __shared__ Ty s_values[THREADS_PER_BLOCK];
+  __shared__ long s_keys[THREADS_PER_BLOCK];
+
+  const long rowId = blockIdx.y;
+  // if (rowId >= batchSize) return;
+
+  // Load the nonzero column offsets for current row
+  const long batchStart = (rowId == 0 ? 0 : cumSumSizes[rowId - 1]) + blockIdx.z * nnzPerBlock;
+  const long batchEnd = min(batchStart + nnzPerBlock, cumSumSizes[rowId]);
+  const long batchStride = blockDim.x * blockDim.y;
+
+  Ty outVal = 0;
+  // Since the number of nonzero elements might be greater than local memory available,
+  // Load only part of the row into local memory, perform partial dot, repeat until done.
+ for (long id = batchStart; id < batchEnd; id += batchStride) { + // Load the current chunk of the row into local memory + long lim = min(batchEnd - id, (long)batchStride); + + long key = tid < lim ? keys[id + tid] + keysOffset : -1; + Ty val = tid < lim ? values[id + tid] : 0; + long nWeightOffset = key * weightStride; + + if (tid < lim && maxNormalize) { + Ty *nWeightCurr = nWeight + nWeightOffset; + if (train) { + Ty absVal = fabs(val); + Ty maxVal = nWeight[key * weightStride + 0]; + if (absVal > maxVal) { + // Updating maxVal and invMaxVal. Go hogwild! + atomic_max(nWeightCurr + 0, absVal); + atomic_min(nWeightCurr + 1, 1.0/absVal); + } + val = val * nWeightCurr[1] + nWeightCurr[3]; + normalizedValues[id + tid] = val; + } else { + val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3]; + } + } + + s_keys[tid] = key; + s_values[tid] = val; + __syncthreads(); + + // Perform a single "dot" operation for each thread + for (long idy = tidy; within_N && idy < lim; idy += blockDim.y) { + outVal += s_values[idy] * weight[weightStride * s_keys[idy]]; + } + __syncthreads(); + } + + // s_values is no longer used at this point. Reuse it for reducing outVal. + // A reduction along the y dimension now gives a single output value along x. + s_values[tid] = outVal; + for (long y = blockDim.y / 2; y >= 1; y /= 2) { + __syncthreads(); + if (tidy < y) s_values[tid] = s_values[tid] + s_values[tid + y * blockDim.x]; + } + + if (within_N && tidy == 0) { + Ty val = s_values[tid] + (blockIdx.z == 0 ? bias[gidx] : 0); + if (gridDim.z == 1) { + output[rowId * outDim] = val; + } else { + atomicAdd(output + rowId * outDim, val); + } + } +} + +// This kernel takes in the following inputs: +// values of size [keysSize x 1] and gradOutput of size [batchSize x outDim], +// to generate gradWeight of size [keysSize x outDim] +// nth block along y dimension computes on the non zero elements from the nth batch. +template<typename Ty> +__global__ static +void accGradWeight( + Ty *gradWeight, + const Ty *gradOutput, + const Ty *values, + const long *cumSumSizes, + const long outDim, + const long gradWeightStride, + const Ty scale, + const Ty weightDecay, + const int maxNormalize) +{ + const long bidy = blockIdx.y; + const long tidx = threadIdx.x; + const long tidy = threadIdx.y; + const long tid = tidy * blockDim.x + tidx; + const long ntid = blockDim.x * blockDim.y; + const long gidx = blockIdx.x * blockDim.x + tidx; + + // All the y threads in the block will use the same gradOutput value + gradOutput += bidy * outDim; + Ty gradOutVal = scale * (gidx < outDim ? gradOutput[gidx] : 0); + + // Calculate the amount of work for the current block / batch. + const long batchStart = bidy == 0 ? 0 : cumSumSizes[bidy - 1]; + const long batchEnd = cumSumSizes[bidy]; + const long batchLimit = batchEnd - batchStart; + + // Number of iterations required to finish the work for the current batch. + const long iters = divup(batchLimit, ntid); + + // Offset the values to the current batch. + values += batchStart; + + // When maxNormalize is enabled, gradWeight will be twice the size. + // The first half will contain the gradients required for maxNormalization. + // The second half will contain the gradients required for updating weights. + // if maxNormalize is false, both will evaluate to the same pointer. + Ty *gradWeight0 = gradWeight + batchStart * gradWeightStride + gidx; + Ty *gradWeight1 = gradWeight0 + (maxNormalize ? 
outDim : 0); + + __shared__ Ty s_values[THREADS_PER_BLOCK]; + + // Using iters to avoid divergence + synchtreads + for (long n = 0; n < iters; n++) { + long off = n * ntid; + long id = off + tid; + long lim = min(ntid, batchLimit - off); + + // Read the values required for the current iteration. + s_values[tid] = id < batchLimit ? values[id] : 0; + __syncthreads(); + + if (gidx < outDim) { + if (maxNormalize) { + for (long idy = tidy; idy < lim; idy += blockDim.y) { + // gradOutVal is already scaled + gradWeight0[(off + idy) * gradWeightStride] = gradOutVal; + } + } + + for (long idy = tidy; idy < lim; idy += blockDim.y) { + gradWeight1[(off + idy) * gradWeightStride] = s_values[idy] * gradOutVal; + } + } + __syncthreads(); + } +} + +// The gradBias is just a reduction of gradOutput along the batches. +// There is only one block along y dimension performing the reduction. +template<typename Ty, bool update> +__global__ static +void accGradBias( + Ty *buffer, + const Ty *gradOutput, + const long outDim, + const long batchSize, + const Ty scale, + const Ty weightDecay) +{ + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int tid = tidy * blockDim.x + tidx; + const long idx = blockIdx.x * blockDim.x + tidx; + + + Ty gradBiasVal = 0; + gradOutput += idx; + __shared__ Ty s_gradBiasVals[THREADS_PER_BLOCK]; + + // Each thread along y calculates the partial sum. + if (idx < outDim) { + for (long idy = tidy; idy < batchSize; idy += blockDim.y) { + gradBiasVal += gradOutput[idy * outDim]; + } + } + s_gradBiasVals[tid] = gradBiasVal * scale; + __syncthreads(); + + // Perform reduction is performed along y. + for (int y = blockDim.y / 2; y >= 1; y /= 2) { + if (tidy < y) { + s_gradBiasVals[tid] += s_gradBiasVals[tid + y * blockDim.x]; + } + __syncthreads(); + } + + // Write the output only from the first lane. + if (tidy == 0 && idx < outDim) { + if (update) { + // If performing inplace update, subtract from bias. + Ty *bias = buffer; + bias[idx] = (bias[idx] - s_gradBiasVals[tid]); + } else { + // If just accumulating gradients, write to gradBias. + Ty *gradBias = buffer; + gradBias[idx] = s_gradBiasVals[tid]; + } + } +} + +// Use gradWeight from accGradWeight to update the weight. +// This kernel is launched batchSize number of times. +// At each step in the iteration, the weights are updated in a sparse manner. +template<typename Ty> +__global__ static +void updateWeight( + Ty *weight, + const Ty *gradWeight, + const long *keys, + const long *cumSumSizes, + const long outDim, + const long gradWeightStride, + const long weightStride, + const long keysOffset, + const Ty learningRate, + const Ty weightDecay, + const int maxNormalize, + const long batchId) +{ + long gidx = blockIdx.x * blockDim.x + threadIdx.x; + long gidy = blockIdx.y * blockDim.y + threadIdx.y; + + // Find the limits of the work to be done + const long batchStart = batchId == 0 ? 0 : cumSumSizes[batchId - 1]; + const long batchEnd = cumSumSizes[batchId]; + + // When maxNormalize is turned on, the weight tensor will contain + // an extra "maxNormalize" number of terms per output at the beginning. + // When maxNormalize is false, both will evaluate to same pointer. + // when maxNormalize is true, + // - nWeight[2] will contain the individual scaling factor. + // - nWeight[3] will contain the individual bias for the normalized input. + Ty *nWeight = weight; + weight += maxNormalize + gidx; + + // When maxNormalize is enabled, gradWeight will be twice the size. 
+ // The first half will contain the gradients required for maxNormalization. + // The second half will contain the gradients required for updating weights. + // if maxNormalize is false, both will evaluate to the same pointer. + const Ty *gradWeight0 = gradWeight + gidx; + const Ty *gradWeight1 = gradWeight0 + (maxNormalize ? outDim : 0); + + if (gidx >= outDim) return; + for (long id = batchStart + gidy; id < batchEnd; id += blockDim.y * gridDim.y) { + Ty lr = learningRate; + Ty wd = weightDecay; + long weightOffset = (keys[id] + keysOffset) * weightStride; + Ty weightVal = weight[weightOffset]; + + if (maxNormalize) { + Ty scale = nWeight[weightOffset + 2]; + lr *= scale; + wd *= scale; + // nWeight[3] needs to be updated in the following manner for a given input. + // nWeight[3] = nWeight[3] - sum(gradWeight0[gidx] * weight[gidx]); + // Since problem is parallelized along gidx, use atomicAdd for the update. + Ty gradNormBias = lr * weightVal * gradWeight0[id * gradWeightStride]; + atomicAdd(nWeight + weightOffset + 3, -gradNormBias); + } + + // Perform the regular update + Ty gradWeightVal = lr * gradWeight1[id * gradWeightStride]; + if (weightDecay == 0) { + weight[weightOffset] = weightVal - gradWeightVal; + } else { + weight[weightOffset] = weightVal * (1 - wd) - gradWeightVal; + } + } +} + +// This kernel is launched batchSize number of times. +// At each step in the iteration, the weights are updated in place in a sparse manner. +template<typename Ty> +__global__ static +void accUpdateWeight( + Ty *weight, + const long weightStride, + const Ty *gradOutput, + const long outDim, + const Ty *values, + const long *cumSumSizes, + const long *keys, + const long keysOffset, + const Ty scale, + const Ty weightDecay, + const int maxNormalize, + const long batchId) +{ + // Parallel along outDim. + long gidx = blockIdx.x * blockDim.x + threadIdx.x; + // Parallel along the sparse input size for current batch. + long gidy = blockIdx.y * blockDim.y + threadIdx.y; + + if (gidx >= outDim) return; + + // Find the limits of the work to be done. + const long batchStart = batchId == 0 ? 0 : cumSumSizes[batchId - 1]; + const long batchEnd = cumSumSizes[batchId]; + + gradOutput += batchId * outDim; + Ty gradOutVal = scale * (gidx < outDim ? gradOutput[gidx] : 0); + + // When maxNormalize is turned on, the weight tensor will contain + // an extra "maxNormalize" number of terms per output at the beginning. + // When maxNormalize is false, both will evaluate to same pointer. + // when maxNormalize is true, + // - nWeight[2] will contain the individual scaling factor. + // - nWeight[3] will contain the individual bias for the normalized input. + Ty *nWeight = weight; + weight += maxNormalize + gidx; + + for (long id = batchStart + gidy; id < batchEnd; id += blockDim.y * gridDim.y) { + Ty wd = weightDecay; + long weightOffset = (keys[id] + keysOffset) * weightStride; + Ty gradWeightVal = gradOutVal * values[id]; + Ty weightVal = weight[weightOffset]; + + if (maxNormalize) { + Ty nScale = nWeight[weightOffset + 2]; + gradWeightVal *= nScale; + wd *= nScale; + // nWeight[3] needs to be updated in the following manner for a given input. + // nWeight[3] = nWeight[3] - sum(gradOut[gidx] * weight[gidx]); + // Since problem is parallelized along gidx, use atomicAdd for the update. 
+ Ty gradNormBias = nScale * weightVal * gradOutVal; + atomicAdd(nWeight + weightOffset + 3, -gradNormBias); + } + + // Perform the regular update + if (weightDecay == 0) { + weight[weightOffset] = weightVal - gradWeightVal; + } else { + weight[weightOffset] = weightVal * (1 - wd) - gradWeightVal; + } + } +} + + +#ifdef CUDA_HALF_TENSOR +void THNN_CudaHalfIndexLinear_updateOutput( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCudaHalfTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCudaHalfTensor *output, + THCudaHalfTensor *weight, + THCudaHalfTensor *bias, + THCudaHalfTensor *normalizedValues, + int train) { + THError("THCudaHalfTensor not supported with IndexLinear"); +} + +void THNN_CudaHalfIndexLinear_accGradParameters( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCudaHalfTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCudaHalfTensor *gradOutput, + THCudaHalfTensor *gradWeight, + THCudaHalfTensor *gradBias, + THCudaHalfTensor *weight, + THCudaHalfTensor *bias, + THCudaHalfTensor* valuesBuffer, + float weightDecay, + float scale) { + THError("THCudaHalfTensor not supported with IndexLinear"); +} + +void THNN_CudaHalfIndexLinear_accUpdateGradParameters( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCudaHalfTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCudaHalfTensor *gradOutput, + THCudaHalfTensor *weight, + THCudaHalfTensor *bias, + float weightDecay, + float scale) { + THError("THCudaHalfTensor not supported with IndexLinear"); +} + +void THNN_CudaHalfIndexLinear_updateParameters( + THCState *state, + THCudaHalfTensor *gradWeight, + THCudaHalfTensor *gradBias, + THCudaHalfTensor *weight, + THCudaHalfTensor *bias, + THCudaLongTensor *runningKeys, + THCudaLongTensor *cumSumSizes, + long keysOffset, + float weightDecay, + float learningRate) { + THError("THCudaHalfTensor not supported with IndexLinear"); +} +#endif + +#include "generic/IndexLinear.cu" +#include "THCGenerateFloatType.h" +#include "generic/IndexLinear.cu" +#include "THCGenerateDoubleType.h" diff --git a/lib/THCUNN/SparseLinear.cu b/lib/THCUNN/SparseLinear.cu index f36206f..9110bbc 100644 --- a/lib/THCUNN/SparseLinear.cu +++ b/lib/THCUNN/SparseLinear.cu @@ -3,7 +3,6 @@ #include "THCHalfAutoNumerics.cuh" #include <cusparse.h> -#include <thrust/device_vector.h> static cusparseHandle_t cusparse_handle = 0; diff --git a/lib/THCUNN/generic/IndexLinear.cu b/lib/THCUNN/generic/IndexLinear.cu new file mode 100644 index 0000000..ae96148 --- /dev/null +++ b/lib/THCUNN/generic/IndexLinear.cu @@ -0,0 +1,273 @@ +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/IndexLinear.cu" +#else + +static bool THCUNN_checkKeysValues(THCState *state, THCudaLongTensor* keys, + THCTensor* values) +{ + return THCudaLongTensor_size(state, keys, 0) == THCTensor_(nElement)(state, values) + && THCTensor_(nDimension)(state, values) == 1 + && THCudaLongTensor_nDimension(state, keys) == 1; +} + +void THNN_(IndexLinear_updateOutput)( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCTensor *output, + THCTensor *weight, + THCTensor *bias, + THCTensor *normalizedValues, + int train) +{ + // Make sure these inputs are contiguous to accelerate computations + THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1, + "keys vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, 
values), 3, + "values vector must be contiguous"); + THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4, + "sizes vector must be contiguous"); + THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5, + "cumSumSizes vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, output), 6, + "output vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, weight), 7, + "weight matrix must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, bias), 8, + "bias vector must be contiguous"); + THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1, + "Keys and values should have the same number of elements"); + + long batchSize = sizes->size[0]; + long outDim = bias->size[0]; + long wDim = weight->size[1]; + long weightStride = weight->stride[0]; + int maxNormalize = wDim - outDim; + long keysSize = keys->size[0]; + long nnzPerRow = divup(keysSize, batchSize); + + THCTensor_(resize2d)(state, output, batchSize, outDim); + long *keysData = THCudaLongTensor_data (state, keys); + real *valuesData = THCTensor_(data) (state, values); + long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes); + real *biasData = THCTensor_(data) (state, bias); + real *weightData = THCTensor_(data) (state, weight); + real *outData = THCTensor_(data) (state, output); + + cudaStream_t stream = THCState_getCurrentStream(state); + dim3 threads(THREADS_X, THREADS_Y); + int blocks_x = divup(outDim, threads.x); + int blocks_y = batchSize; + int nnzPerBlock = ((outDim == 1 || batchSize == 1) ? THREADS_X : NNZ_PER_BLOCK_MAX); + int blocks_z = divup(nnzPerRow, nnzPerBlock); + + dim3 blocks(blocks_x, blocks_y, blocks_z); + + if (blocks_z > 1) { + THCudaCheck(cudaMemsetAsync(outData, 0, outDim * batchSize * sizeof(real), stream)); + } + + real *normalizedValuesData = NULL; + if (maxNormalize && train) { + THCTensor_(resize1d)(state, normalizedValues, keysSize); + normalizedValuesData = THCTensor_(data)(state, normalizedValues); + updateOutput<real, true><<<blocks, threads, 0, stream>>> + (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData, + batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock); + } else { + updateOutput<real, false><<<blocks, threads, 0, stream>>> + (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData, + batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock); + } +} + +void THNN_(IndexLinear_accGradParameters)( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCTensor *gradOutput, + THCTensor *gradWeight, + THCTensor *gradBias, + THCTensor *weight, + THCTensor *bias, + THCTensor* valuesBuffer, + accreal weightDecay, + accreal scale) +{ + long keysSize = keys->size[0]; + long batchSize = sizes->size[0]; + long outDim = bias->size[0]; + long wDim = weight->size[1]; + int maxNormalize = wDim - outDim; + + // Make sure these inputs are contiguous to accelerate computations + THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1, + "keys vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, values), 3, + "values vector must be contiguous"); + THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4, + "sizes vector must be contiguous"); + THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5, + "cumSumSizes vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 6, + "gradOutput 
vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 7, + "gradWeight matrix must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, gradBias), 8, + "gradBias vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, weight), 9, + "weight matrix must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, bias), 10, + "bias vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, valuesBuffer), 11, + "valuesBuffer vector must be contiguous"); + THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1, + "Keys and values should have the same number of elements"); + + THCTensor_(resize2d)(state, gradWeight, keysSize, outDim * (maxNormalize > 0 ? 2 : 1)); + + real *valuesData = THCTensor_(data) (state, values); + long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes); + real *gradOutputData = THCTensor_(data) (state, gradOutput); + real *gradBiasData = THCTensor_(data) (state, gradBias); + real *gradWeightData = THCTensor_(data) (state, gradWeight); + long gradWeightStride = gradWeight->stride[0]; + + cudaStream_t stream = THCState_getCurrentStream(state); + dim3 threads(THREADS_X, THREADS_Y); + int blocks_x = divup(outDim, threads.x); + accGradBias<real, false><<<blocks_x, threads, 0, stream>>> + (gradBiasData, gradOutputData, outDim, batchSize, scale, weightDecay); + + dim3 blocks(blocks_x, batchSize); + accGradWeight<real><<<blocks, threads, 0, stream>>> + (gradWeightData, gradOutputData, valuesData, cumSumSizesData, outDim, + gradWeightStride, scale, weightDecay, maxNormalize); +} + +void THNN_(IndexLinear_accUpdateGradParameters)( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCTensor *gradOutput, + THCTensor *weight, + THCTensor *bias, + accreal weightDecay, + accreal scale) +{ + // Make sure these inputs are contiguous to accelerate computations + THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1, + "keys vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, values), 3, + "values vector must be contiguous"); + THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4, + "sizes vector must be contiguous"); + THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5, + "cumSumSizes vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 6, + "gradOutput vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, weight), 7, + "weight matrix must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, bias), 8, + "bias vector must be contiguous"); + THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1, + "Keys and values should have the same number of elements"); + + long batchSize = sizes->size[0]; + long outDim = bias->size[0]; + long keysSize = keys->size[0]; + long wDim = weight->size[1]; + int maxNormalize = wDim - outDim; + + real *biasData = THCTensor_(data) (state, bias); + real *weightData = THCTensor_(data) (state, weight); + real *gradOutputData = THCTensor_(data) (state, gradOutput); + real *valuesData = THCTensor_(data) (state, values); + long *keysData = THCudaLongTensor_data (state, keys); + long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes); + long weightStride = weight->stride[0]; + + cudaStream_t stream = THCState_getCurrentStream(state); + dim3 threads(THREADS_X, THREADS_Y); + int blocks_x = divup(outDim, threads.x); + + accGradBias<real, true><<<blocks_x, threads, 0, 
stream>>> + (biasData, gradOutputData, outDim, batchSize, scale, weightDecay); + + long nnzPerRow = divup(keysSize, batchSize); + int blocks_y = divup(nnzPerRow, REPEAT * threads.y); + dim3 blocks(blocks_x, blocks_y); + + for (long batchId = 0; batchId < batchSize; batchId++) { + accUpdateWeight<real><<<blocks, threads, 0, stream>>> + (weightData, weightStride, gradOutputData, outDim, valuesData, + cumSumSizesData, keysData, keysOffset, scale, weightDecay, maxNormalize, + batchId); + } +} + +void THNN_(IndexLinear_updateParameters)( + THCState *state, + THCTensor *gradWeight, + THCTensor *gradBias, + THCTensor *weight, + THCTensor *bias, + THCudaLongTensor *runningKeys, + THCudaLongTensor *cumSumSizes, + long keysOffset, + accreal weightDecay, + accreal learningRate) +{ + // Make sure these inputs are contiguous to accelerate computations + THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 1, + "gradWeight matrix must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, gradBias), 2, + "gradBias vector must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, weight), 3, + "weight matrix must be contiguous"); + THArgCheck(THCTensor_(isContiguous)(state, bias), 4, + "bias vector must be contiguous"); + THArgCheck(THCudaLongTensor_isContiguous(state, runningKeys), 5, + "runningKeys vector must be contiguous"); + THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 6, + "cumSumSizes vector must be contiguous"); + + long outDim = bias->size[0]; + long wDim = weight->size[1]; + int maxNormalize = wDim - outDim; + long keysSize = runningKeys->size[0]; + long batchSize = cumSumSizes->size[0]; + + THCTensor_(cadd)(state, bias, bias, -learningRate, gradBias); + long gradWeightStride = gradWeight->stride[0]; + long weightStride = weight->stride[0]; + + long *keysData = THCudaLongTensor_data (state, runningKeys); + long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes); + real *gradWeightData = THCTensor_(data) (state, gradWeight); + real *weightData = THCTensor_(data) (state, weight); + + dim3 threads(THREADS_X, THREADS_Y); + long nnzPerRow = divup(keysSize, batchSize); + int blocks_x = divup(outDim, threads.x); + int blocks_y = divup(nnzPerRow, REPEAT * threads.y); + dim3 blocks(blocks_x, blocks_y); + cudaStream_t stream = THCState_getCurrentStream(state); + + for (long batchId = 0; batchId < batchSize; batchId++) { + updateWeight<real><<<blocks, threads, 0, stream>>> + (weightData, gradWeightData, keysData, cumSumSizesData, outDim, + gradWeightStride, weightStride, keysOffset, learningRate, weightDecay, + maxNormalize, batchId); + } +} +#endif diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h index a71209a..3c4c38f 100644 --- a/lib/THCUNN/generic/THCUNN.h +++ b/lib/THCUNN/generic/THCUNN.h @@ -379,6 +379,60 @@ TH_API void THNN_(SparseLinear_updateParameters)( THCTensor *lastInput, accreal learningRate); +TH_API void THNN_(IndexLinear_updateOutput)( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCTensor *output, + THCTensor *weight, + THCTensor *bias, + THCTensor *normalizedValues, + int train); + +TH_API void THNN_(IndexLinear_accGradParameters)( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCTensor *gradOutput, + THCTensor *gradWeight, + THCTensor *gradBias, + THCTensor *weight, + THCTensor *bias, + THCTensor* valuesBuffer, + 
accreal weightDecay, + accreal scale); + +TH_API void THNN_(IndexLinear_accUpdateGradParameters)( + THCState *state, + THCudaLongTensor *keys, + long keysOffset, + THCTensor *values, + THCudaLongTensor *sizes, + THCudaLongTensor *cumSumSizes, + THCTensor *gradOutput, + THCTensor *weight, + THCTensor *bias, + accreal weightDecay, + accreal scale); + +TH_API void THNN_(IndexLinear_updateParameters)( + THCState *state, + THCTensor *gradWeight, + THCTensor *gradBias, + THCTensor *weight, + THCTensor *bias, + THCudaLongTensor *runningKeys, + THCudaLongTensor *cumSumSizes, + long keysOffset, + accreal weightDecay, + accreal learningRate); + TH_API void THNN_(SpatialAdaptiveMaxPooling_updateOutput)( THCState *state, THCTensor *input, @@ -5915,6 +5915,223 @@ function cunntest.ModuleConversionFunctions() end end +function cunntest.IndexLinear() + isize = 500E3 + osize = 250 + weightDecay = 0.01 + nnzMin = 1000 + nnzMax = 1500 + idxMin = 1 + idxMax = isize + batchSize = 128 + lr = 0.01 + ntests = 1 + + local errNorm = function(a, b) + return torch.Tensor(1):fill(torch.cdiv((a - b):abs(), a:abs()):max()) + end + + local ilc = nn.IndexLinear(isize, osize):float() + local ilg = nn.IndexLinear(isize, osize):float():cuda() + + local ilc2 = nn.IndexLinear(isize, osize):float() + local ilg2 = nn.IndexLinear(isize, osize):float():cuda() + + local tot = 0 + local samples = 0 + local inputCPU = {{}, {}} + local inputGPU = {{}, {}} + local flatInputCPU = {torch.LongTensor(), torch.FloatTensor(), torch.LongTensor()} + local flatInputGPU = {torch.CudaLongTensor(), torch.CudaTensor(), torch.CudaLongTensor()} + local sizes = torch.LongTensor(batchSize) + for i=1,batchSize do + local n = torch.random(nnzMin, nnzMax) + local indices = idxMin + torch.LongTensor():randperm(idxMax - idxMin) + inputCPU[1][i] = indices[{{1,n}}] + inputCPU[2][i] = torch.FloatTensor(n):uniform() + inputGPU[1][i] = torch.CudaLongTensor(n):copy(inputCPU[1][i]) + inputGPU[2][i] = torch.CudaTensor(n):copy(inputCPU[2][i]) + sizes[i] = n + tot = tot + n + end + flatInputCPU[1]:cat(inputCPU[1], 1) + flatInputCPU[2]:cat(inputCPU[2], 1) + flatInputCPU[3] = sizes + + flatInputGPU[1]:cat(inputGPU[1], 1) + flatInputGPU[2]:cat(inputGPU[2], 1) + flatInputGPU[3] = sizes:cudaLong() + + local inputSize = #inputCPU[1] + local gradOutsCPU = torch.FloatTensor(inputSize, osize):uniform() + local gradOutsGPU = torch.CudaTensor(inputSize, osize):copy(gradOutsCPU) + + local outputCPU, outputGPU + local flatOutputCPU, flatOutputGPU + + ilc.weightDecay = weightDecay + ilg.weightDecay = weightDecay + ilc2.weightDecay = weightDecay + ilg2.weightDecay = weightDecay + + ilc.weight:uniform() + ilc.bias:fill(1) + ilc2.weight:uniform() + ilc2.bias:fill(1) + + ilg.weight:copy(ilc.weight) + ilg.bias:copy(ilc.bias) + ilg2.weight:copy(ilc2.weight) + ilg2.bias:copy(ilc2.bias) + + ilc:zeroGradParameters() + outputCPU = ilc:forward(inputCPU) + ilc:backward(inputCPU, gradOutsCPU); + ilc:updateParameters(lr) + + ilc2:zeroGradParameters() + flatOutputCPU = ilc2:forward(flatInputCPU) + ilc2:backward(flatInputCPU, gradOutsCPU); + ilc2:updateParameters(lr) + + ilg:zeroGradParameters() + outputGPU = ilg:forward(inputGPU) + ilg:backward(inputGPU, gradOutsGPU); + ilg:updateParameters(lr) + + ilg2:zeroGradParameters() + flatOutputGPU = ilg2:forward(flatInputGPU) + ilg2:backward(flatInputGPU, gradOutsGPU); + ilg2:updateParameters(lr) + + mytester:assertTensorEq(errNorm(outputCPU, outputGPU:float()), + torch.Tensor(1):fill(0), + 1E-5, "cunn.IndexLinear:forward failed for output") + + 
mytester:assertTensorEq(errNorm(flatOutputCPU, flatOutputGPU:float()), + torch.Tensor(1):fill(0), + 1E-5, "cunn.IndexLinear:forward failed for flatOutput") + + mytester:assertTensorEq(ilc.bias, + ilg.bias:float(), + 1E-5, "cunn.IndexLinear:backward+update failed for bias for tensor array") + + mytester:assertTensorEq(ilc.weight, + ilg.weight:float(), + 1E-5, "cunn.IndexLinear:backward+update failed for weight for tensor array") + + mytester:assertTensorEq(ilc2.bias, + ilg2.bias:float(), + 1E-5, "cunn.IndexLinear:backward+update failed for bias for flat input") + + mytester:assertTensorEq(ilc2.weight, + ilg2.weight:float(), + 1E-5, "cunn.IndexLinear:backward+update failed for weight for flat input") + + ilc.weight:uniform() + ilc.bias:fill(1) + + ilg.weight:copy(ilc.weight) + ilg.bias:copy(ilc.bias) + + ilc2.weight:uniform() + ilc2.bias:fill(1) + + ilg2.weight:copy(ilc2.weight) + ilg2.bias:copy(ilc2.bias) + + outputCPU = ilc:forward(inputCPU) + ilc:backwardUpdate(inputCPU, gradOutsCPU, lr); + + outputGPU = ilg:forward(inputGPU) + ilg:backwardUpdate(inputGPU, gradOutsGPU, lr); + + flatOutputCPU = ilc2:forward(flatInputCPU) + ilc2:backwardUpdate(flatInputCPU, gradOutsCPU, lr); + + flatOutputGPU = ilg2:forward(flatInputGPU) + ilg2:backwardUpdate(flatInputGPU, gradOutsGPU, lr); + + mytester:assertTensorEq(errNorm(outputCPU, outputGPU:float()), + torch.Tensor(1):fill(0), + 1E-5, "cunn.IndexLinear:forward failed for output") + + mytester:assertTensorEq(errNorm(flatOutputCPU, flatOutputGPU:float()), + torch.Tensor(1):fill(0), + 1E-5, "cunn.IndexLinear:forward failed for flatOutput") + + mytester:assertTensorEq(ilc.bias, + ilg.bias:float(), + 1E-5, "cunn.IndexLinear:backward+update failed for bias for tensor array") + + mytester:assertTensorEq(ilc.weight, + ilg.weight:float(), + 1E-5, "cunn.IndexLinear:backward+update failed for weight for tensor array") + + mytester:assertTensorEq(ilc2.bias, + ilg2.bias:float(), + 1E-5, "cunn.IndexLinear:backward+update failed for bias for flat input") + + mytester:assertTensorEq(ilc2.weight, + ilg2.weight:float(), + 1E-5, "cunn.IndexLinear:backward+update failed for weight for flat input") +end + +function cunntest.IndexLinearMaxNorm() + isize = 500E3 + osize = 250 + weightDecay = 0 + nnzMin = 1000 + nnzMax = 1500 + idxMin = 1 + idxMax = isize + batchSize = 128 + lr = 0.01 + ntests = 1 + + local errNorm = function(a, b) + return torch.Tensor(1):fill(torch.cdiv((a - b):abs(), a:abs()):max()) + end + + local ilc = nn.IndexLinear(isize, osize, nil, nil, nil, nil, 1):float() + local ilg = nn.IndexLinear(isize, osize, nil, nil, nil, nil, 1):float():cuda() + + local tot = 0 + local samples = 0 + local inputCPU = {{}, {}} + local inputGPU = {{}, {}} + for i=1,batchSize do + local n = torch.random(nnzMin, nnzMax) + local indices = idxMin + torch.LongTensor():randperm(idxMax - idxMin) + inputCPU[1][i] = indices[{{1,n}}] + inputCPU[2][i] = torch.FloatTensor(n):uniform() + inputGPU[1][i] = torch.CudaLongTensor(n):copy(inputCPU[1][i]) + inputGPU[2][i] = torch.CudaTensor(n):copy(inputCPU[2][i]) + tot = tot + n + end + + local inputSize = #inputCPU[1] + local gradOutsCPU = torch.FloatTensor(inputSize, osize):uniform() + local gradOutsGPU = torch.CudaTensor(inputSize, osize):copy(gradOutsCPU) + + ilc.weightDecay = weightDecay + ilg.weightDecay = weightDecay + + ilc.weight:uniform() + ilc.weight:narrow(2,2,1):fill(1.0):cdiv(ilc.weight:narrow(2,1,1)) + ilc.bias:fill(1) + + ilg.weight:copy(ilc.weight) + ilg.bias:copy(ilc.bias) + + outputCPU = ilc:forward(inputCPU) + outputGPU = 
ilg:forward(inputGPU)
+
+   mytester:assertTensorEq(errNorm(outputCPU, outputGPU:float()),
+                           torch.Tensor(1):fill(0),
+                           1E-5, "cunn.IndexLinear:forward failed for output")
+end
+
 function cunntest.GPU()
    local ndevice = cutorch.getDeviceCount()
    if ndevice < 2 then
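For readers tracing the kernels above: the cumSumSizes tensor passed to updateOutput/accGradWeight is an inclusive prefix sum of the per-sample sizes, and each block derives its batchStart/batchEnd range from it (0-based on the device). The short Lua sketch below, with made-up sizes, shows the same bookkeeping on the host side; it is illustrative only and not code from this patch.

-- Illustrative only: how per-sample sizes map to the CSR-style ranges the
-- kernels read via cumSumSizes (shown 1-based in Lua; device code is 0-based).
require 'torch'

local sizes = torch.LongTensor{3, 1, 2}      -- nnz per sample in a batch of 3
local cumSumSizes = torch.cumsum(sizes, 1)   -- {3, 4, 6}, inclusive prefix sum
for i = 1, sizes:size(1) do
  local first = (i == 1) and 1 or (cumSumSizes[i - 1] + 1)
  local last  = cumSumSizes[i]
  print(string.format("sample %d -> flat keys/values %d..%d", i, first, last))
end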