github.com/torch/cunn.git

author    Soumith Chintala <soumith@gmail.com>  2017-04-09 20:52:56 +0300
committer GitHub <noreply@github.com>           2017-04-09 20:52:56 +0300
commit    536f41ad8044ec61afaab9045ab8c84a4137514b
tree      9b7a664c7266a85faa88bbc428c342b8b4e61459
parent    14b181bbcf21e6cba67357eaa36a7bd611c00324
parent    97940f0a81b689657234ee456ac60e35fe72d043
Merge pull request #455 from twitter-forks/indexlinear
Adding IndexLinear
-rw-r--r--  lib/THCUNN/IndexLinear.cu          | 490
-rw-r--r--  lib/THCUNN/SparseLinear.cu         |   1
-rw-r--r--  lib/THCUNN/generic/IndexLinear.cu  | 273
-rw-r--r--  lib/THCUNN/generic/THCUNN.h        |  54
-rw-r--r--  test.lua                           | 217
5 files changed, 1034 insertions(+), 1 deletion(-)
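
Before the patch itself, a minimal usage sketch of the new module. This is only a sketch reconstructed from the calls exercised in test.lua further down in this diff: the nn.IndexLinear constructor, the two sparse input formats (a pair of tables holding one index tensor and one value tensor per sample, or flat index/value tensors plus a per-sample sizes tensor), and the tensor conversions are taken from that test; the sizes, feature ids and values below are arbitrary illustrative data.

    require 'cunn'

    local inputSize, outputSize = 1000, 8
    local il  = nn.IndexLinear(inputSize, outputSize):float():cuda()
    local il2 = nn.IndexLinear(inputSize, outputSize):float():cuda()

    -- Format 1: tables of per-sample index/value tensors (here a batch of one sample).
    local idx = torch.LongTensor{1, 42, 137}          -- sparse feature ids in [1, inputSize]
    local val = torch.FloatTensor{0.5, 1.0, 0.25}     -- matching feature values
    local indices = torch.CudaLongTensor(3):copy(idx)
    local values  = torch.CudaTensor(3):copy(val)
    local input   = {{indices}, {values}}

    local output = il:forward(input)                  -- batchSize x outputSize, here 1 x 8
    local gradOutput = torch.CudaTensor(1, outputSize):copy(
        torch.FloatTensor(1, outputSize):uniform())
    il:zeroGradParameters()
    il:backward(input, gradOutput)
    il:updateParameters(0.01)

    -- Format 2: flat index/value tensors for the whole batch plus a sizes tensor
    -- giving the number of nonzeros per sample.
    local sizes = torch.LongTensor{3}:cudaLong()
    local flatInput  = {indices, values, sizes}
    local flatOutput = il2:forward(flatInput)
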
diff --git a/lib/THCUNN/IndexLinear.cu b/lib/THCUNN/IndexLinear.cu
new file mode 100644
index 0000000..fb2dc93
--- /dev/null
+++ b/lib/THCUNN/IndexLinear.cu
@@ -0,0 +1,490 @@
+#include "THCUNN.h"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+#include "THCAtomics.cuh"
+
+#define divup(a, b) (((a) + (b) - 1) / (b))
+const int THREADS_PER_BLOCK = 256;
+const int THREADS_X = 32;
+const int THREADS_Y = THREADS_PER_BLOCK / THREADS_X;
+const int REPEAT = 32;
+const long NNZ_PER_BLOCK_MAX = 1024;
+
+/* clamp macro */
+#ifndef clamp
+#define clamp(a, low, high) max(min((a), (high)), (low))
+#endif
+
+#ifndef ATOMIC_REAL_MINMAX
+#define ATOMIC_REAL_MINMAX(func) \
+ __device__ void atomic_##func(double *address, double val) { \
+ unsigned long long int* address_as_ull = (unsigned long long int*)address; \
+ unsigned long long int old = *address_as_ull; \
+ unsigned long long int assumed; \
+ do { \
+ assumed = old; \
+ old = atomicCAS(address_as_ull, assumed, \
+ __double_as_longlong(func(val, __longlong_as_double(assumed)))); \
+ } while (assumed != old); \
+ } \
+ __device__ void atomic_##func(float *address, float val) { \
+ int* address_as_int = (int*)address; \
+ int old = *address_as_int; \
+ int assumed; \
+ do { \
+ assumed = old; \
+ old = atomicCAS(address_as_int, assumed, \
+ __float_as_int(func(val, __int_as_float(assumed)))); \
+ } while (assumed != old); \
+ } \
+
+ATOMIC_REAL_MINMAX(max)
+ATOMIC_REAL_MINMAX(min)
+#endif
+
+template<typename Ty, bool train>
+__global__ static
+void updateOutput(
+ Ty *output,
+ Ty *normalizedValues,
+ const Ty *values,
+ const long *cumSumSizes,
+ const long *keys,
+ const long batchSize,
+ const long outDim,
+ Ty *weight,
+ const Ty *bias,
+ const long weightStride,
+ const long keysOffset,
+ const int maxNormalize,
+ const int nnzPerBlock)
+{
+ /*******************************************************
+ * Adapted from the following file in arrayfire
+ * https://github.com/arrayfire/arrayfire/blob/v3.4.1/src/backend/opencl/kernel/csrmm.cl
+ *
+ *******************************************************
+ * Original copyright notice can be seen below:
+ *
+ * Copyright (c) 2016, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+
+ const long tidx = threadIdx.x;
+ const long tidy = threadIdx.y;
+ const long tid = tidy * blockDim.x + tidx;
+ const long gidx = blockIdx.x * blockDim.x + tidx;
+
+
+ Ty *nWeight = weight;
+  // Offset by the number of elements specified by maxNormalize
+ weight += gidx + maxNormalize;
+ output += gidx;
+
+ bool within_N = (gidx < outDim);
+
+ __shared__ Ty s_values[THREADS_PER_BLOCK];
+ __shared__ long s_keys[THREADS_PER_BLOCK];
+
+ const long rowId = blockIdx.y;
+ // if (rowId >= batchSize) return;
+
+  // Load the nonzero column offsets for the current row
+ const long batchStart = (rowId == 0 ? 0 : cumSumSizes[rowId - 1]) + blockIdx.z * nnzPerBlock;
+ const long batchEnd = min(batchStart + nnzPerBlock, cumSumSizes[rowId]);
+ const long batchStride = blockDim.x * blockDim.y;
+
+ Ty outVal = 0;
+  // Since the number of nonzero elements might exceed the available shared memory,
+  // load only part of the row into shared memory, perform a partial dot product, and repeat until done.
+ for (long id = batchStart; id < batchEnd; id += batchStride) {
+    // Load the current chunk of the row into shared memory
+ long lim = min(batchEnd - id, (long)batchStride);
+
+ long key = tid < lim ? keys[id + tid] + keysOffset : -1;
+ Ty val = tid < lim ? values[id + tid] : 0;
+ long nWeightOffset = key * weightStride;
+
+ if (tid < lim && maxNormalize) {
+ Ty *nWeightCurr = nWeight + nWeightOffset;
+ if (train) {
+ Ty absVal = fabs(val);
+ Ty maxVal = nWeight[key * weightStride + 0];
+ if (absVal > maxVal) {
+ // Updating maxVal and invMaxVal. Go hogwild!
+ atomic_max(nWeightCurr + 0, absVal);
+ atomic_min(nWeightCurr + 1, 1.0/absVal);
+ }
+ val = val * nWeightCurr[1] + nWeightCurr[3];
+ normalizedValues[id + tid] = val;
+ } else {
+ val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3];
+ }
+ }
+
+ s_keys[tid] = key;
+ s_values[tid] = val;
+ __syncthreads();
+
+ // Perform a single "dot" operation for each thread
+ for (long idy = tidy; within_N && idy < lim; idy += blockDim.y) {
+ outVal += s_values[idy] * weight[weightStride * s_keys[idy]];
+ }
+ __syncthreads();
+ }
+
+ // s_values is no longer used at this point. Reuse it for reducing outVal.
+ // A reduction along the y dimension now gives a single output value along x.
+ s_values[tid] = outVal;
+ for (long y = blockDim.y / 2; y >= 1; y /= 2) {
+ __syncthreads();
+ if (tidy < y) s_values[tid] = s_values[tid] + s_values[tid + y * blockDim.x];
+ }
+
+ if (within_N && tidy == 0) {
+ Ty val = s_values[tid] + (blockIdx.z == 0 ? bias[gidx] : 0);
+ if (gridDim.z == 1) {
+ output[rowId * outDim] = val;
+ } else {
+ atomicAdd(output + rowId * outDim, val);
+ }
+ }
+}
+
+// This kernel takes in the following inputs:
+// values of size [keysSize x 1] and gradOutput of size [batchSize x outDim],
+// to generate gradWeight of size [keysSize x outDim].
+// The nth block along the y dimension operates on the nonzero elements of the nth batch.
+template<typename Ty>
+__global__ static
+void accGradWeight(
+ Ty *gradWeight,
+ const Ty *gradOutput,
+ const Ty *values,
+ const long *cumSumSizes,
+ const long outDim,
+ const long gradWeightStride,
+ const Ty scale,
+ const Ty weightDecay,
+ const int maxNormalize)
+{
+ const long bidy = blockIdx.y;
+ const long tidx = threadIdx.x;
+ const long tidy = threadIdx.y;
+ const long tid = tidy * blockDim.x + tidx;
+ const long ntid = blockDim.x * blockDim.y;
+ const long gidx = blockIdx.x * blockDim.x + tidx;
+
+ // All the y threads in the block will use the same gradOutput value
+ gradOutput += bidy * outDim;
+ Ty gradOutVal = scale * (gidx < outDim ? gradOutput[gidx] : 0);
+
+ // Calculate the amount of work for the current block / batch.
+ const long batchStart = bidy == 0 ? 0 : cumSumSizes[bidy - 1];
+ const long batchEnd = cumSumSizes[bidy];
+ const long batchLimit = batchEnd - batchStart;
+
+ // Number of iterations required to finish the work for the current batch.
+ const long iters = divup(batchLimit, ntid);
+
+ // Offset the values to the current batch.
+ values += batchStart;
+
+ // When maxNormalize is enabled, gradWeight will be twice the size.
+ // The first half will contain the gradients required for maxNormalization.
+ // The second half will contain the gradients required for updating weights.
+  // If maxNormalize is false, both evaluate to the same pointer.
+ Ty *gradWeight0 = gradWeight + batchStart * gradWeightStride + gidx;
+ Ty *gradWeight1 = gradWeight0 + (maxNormalize ? outDim : 0);
+
+ __shared__ Ty s_values[THREADS_PER_BLOCK];
+
+  // Loop a fixed number of iterations (iters) so that all threads reach __syncthreads() together, avoiding divergence.
+ for (long n = 0; n < iters; n++) {
+ long off = n * ntid;
+ long id = off + tid;
+ long lim = min(ntid, batchLimit - off);
+
+ // Read the values required for the current iteration.
+ s_values[tid] = id < batchLimit ? values[id] : 0;
+ __syncthreads();
+
+ if (gidx < outDim) {
+ if (maxNormalize) {
+ for (long idy = tidy; idy < lim; idy += blockDim.y) {
+ // gradOutVal is already scaled
+ gradWeight0[(off + idy) * gradWeightStride] = gradOutVal;
+ }
+ }
+
+ for (long idy = tidy; idy < lim; idy += blockDim.y) {
+ gradWeight1[(off + idy) * gradWeightStride] = s_values[idy] * gradOutVal;
+ }
+ }
+ __syncthreads();
+ }
+}
+
+// The gradBias is just a reduction of gradOutput along the batches.
+// There is only one block along the y dimension performing the reduction.
+template<typename Ty, bool update>
+__global__ static
+void accGradBias(
+ Ty *buffer,
+ const Ty *gradOutput,
+ const long outDim,
+ const long batchSize,
+ const Ty scale,
+ const Ty weightDecay)
+{
+ const int tidx = threadIdx.x;
+ const int tidy = threadIdx.y;
+ const int tid = tidy * blockDim.x + tidx;
+ const long idx = blockIdx.x * blockDim.x + tidx;
+
+
+ Ty gradBiasVal = 0;
+ gradOutput += idx;
+ __shared__ Ty s_gradBiasVals[THREADS_PER_BLOCK];
+
+ // Each thread along y calculates the partial sum.
+ if (idx < outDim) {
+ for (long idy = tidy; idy < batchSize; idy += blockDim.y) {
+ gradBiasVal += gradOutput[idy * outDim];
+ }
+ }
+ s_gradBiasVals[tid] = gradBiasVal * scale;
+ __syncthreads();
+
+  // The reduction is performed along y.
+ for (int y = blockDim.y / 2; y >= 1; y /= 2) {
+ if (tidy < y) {
+ s_gradBiasVals[tid] += s_gradBiasVals[tid + y * blockDim.x];
+ }
+ __syncthreads();
+ }
+
+ // Write the output only from the first lane.
+ if (tidy == 0 && idx < outDim) {
+ if (update) {
+ // If performing inplace update, subtract from bias.
+ Ty *bias = buffer;
+ bias[idx] = (bias[idx] - s_gradBiasVals[tid]);
+ } else {
+ // If just accumulating gradients, write to gradBias.
+ Ty *gradBias = buffer;
+ gradBias[idx] = s_gradBiasVals[tid];
+ }
+ }
+}
+
+// Use gradWeight from accGradWeight to update the weight.
+// This kernel is launched batchSize times.
+// At each step in the iteration, the weights are updated in a sparse manner.
+template<typename Ty>
+__global__ static
+void updateWeight(
+ Ty *weight,
+ const Ty *gradWeight,
+ const long *keys,
+ const long *cumSumSizes,
+ const long outDim,
+ const long gradWeightStride,
+ const long weightStride,
+ const long keysOffset,
+ const Ty learningRate,
+ const Ty weightDecay,
+ const int maxNormalize,
+ const long batchId)
+{
+ long gidx = blockIdx.x * blockDim.x + threadIdx.x;
+ long gidy = blockIdx.y * blockDim.y + threadIdx.y;
+
+ // Find the limits of the work to be done
+ const long batchStart = batchId == 0 ? 0 : cumSumSizes[batchId - 1];
+ const long batchEnd = cumSumSizes[batchId];
+
+  // When maxNormalize is turned on, the weight tensor contains an extra
+  // "maxNormalize" number of terms per output at the beginning.
+  // When maxNormalize is false, nWeight and weight evaluate to the same pointer.
+  // When maxNormalize is true:
+  // - nWeight[2] contains the individual scaling factor.
+  // - nWeight[3] contains the individual bias for the normalized input.
+ Ty *nWeight = weight;
+ weight += maxNormalize + gidx;
+
+ // When maxNormalize is enabled, gradWeight will be twice the size.
+ // The first half will contain the gradients required for maxNormalization.
+ // The second half will contain the gradients required for updating weights.
+  // If maxNormalize is false, both evaluate to the same pointer.
+ const Ty *gradWeight0 = gradWeight + gidx;
+ const Ty *gradWeight1 = gradWeight0 + (maxNormalize ? outDim : 0);
+
+ if (gidx >= outDim) return;
+ for (long id = batchStart + gidy; id < batchEnd; id += blockDim.y * gridDim.y) {
+ Ty lr = learningRate;
+ Ty wd = weightDecay;
+ long weightOffset = (keys[id] + keysOffset) * weightStride;
+ Ty weightVal = weight[weightOffset];
+
+ if (maxNormalize) {
+ Ty scale = nWeight[weightOffset + 2];
+ lr *= scale;
+ wd *= scale;
+ // nWeight[3] needs to be updated in the following manner for a given input.
+ // nWeight[3] = nWeight[3] - sum(gradWeight0[gidx] * weight[gidx]);
+      // Since the problem is parallelized along gidx, use atomicAdd for the update.
+ Ty gradNormBias = lr * weightVal * gradWeight0[id * gradWeightStride];
+ atomicAdd(nWeight + weightOffset + 3, -gradNormBias);
+ }
+
+ // Perform the regular update
+ Ty gradWeightVal = lr * gradWeight1[id * gradWeightStride];
+ if (weightDecay == 0) {
+ weight[weightOffset] = weightVal - gradWeightVal;
+ } else {
+ weight[weightOffset] = weightVal * (1 - wd) - gradWeightVal;
+ }
+ }
+}
+
+// This kernel is launched batchSize times.
+// At each step in the iteration, the weights are updated in place in a sparse manner.
+template<typename Ty>
+__global__ static
+void accUpdateWeight(
+ Ty *weight,
+ const long weightStride,
+ const Ty *gradOutput,
+ const long outDim,
+ const Ty *values,
+ const long *cumSumSizes,
+ const long *keys,
+ const long keysOffset,
+ const Ty scale,
+ const Ty weightDecay,
+ const int maxNormalize,
+ const long batchId)
+{
+ // Parallel along outDim.
+ long gidx = blockIdx.x * blockDim.x + threadIdx.x;
+ // Parallel along the sparse input size for current batch.
+ long gidy = blockIdx.y * blockDim.y + threadIdx.y;
+
+ if (gidx >= outDim) return;
+
+ // Find the limits of the work to be done.
+ const long batchStart = batchId == 0 ? 0 : cumSumSizes[batchId - 1];
+ const long batchEnd = cumSumSizes[batchId];
+
+ gradOutput += batchId * outDim;
+ Ty gradOutVal = scale * (gidx < outDim ? gradOutput[gidx] : 0);
+
+  // When maxNormalize is turned on, the weight tensor contains an extra
+  // "maxNormalize" number of terms per output at the beginning.
+  // When maxNormalize is false, nWeight and weight evaluate to the same pointer.
+  // When maxNormalize is true:
+  // - nWeight[2] contains the individual scaling factor.
+  // - nWeight[3] contains the individual bias for the normalized input.
+ Ty *nWeight = weight;
+ weight += maxNormalize + gidx;
+
+ for (long id = batchStart + gidy; id < batchEnd; id += blockDim.y * gridDim.y) {
+ Ty wd = weightDecay;
+ long weightOffset = (keys[id] + keysOffset) * weightStride;
+ Ty gradWeightVal = gradOutVal * values[id];
+ Ty weightVal = weight[weightOffset];
+
+ if (maxNormalize) {
+ Ty nScale = nWeight[weightOffset + 2];
+ gradWeightVal *= nScale;
+ wd *= nScale;
+ // nWeight[3] needs to be updated in the following manner for a given input.
+ // nWeight[3] = nWeight[3] - sum(gradOut[gidx] * weight[gidx]);
+      // Since the problem is parallelized along gidx, use atomicAdd for the update.
+ Ty gradNormBias = nScale * weightVal * gradOutVal;
+ atomicAdd(nWeight + weightOffset + 3, -gradNormBias);
+ }
+
+ // Perform the regular update
+ if (weightDecay == 0) {
+ weight[weightOffset] = weightVal - gradWeightVal;
+ } else {
+ weight[weightOffset] = weightVal * (1 - wd) - gradWeightVal;
+ }
+ }
+}
+
+
+#ifdef CUDA_HALF_TENSOR
+void THNN_CudaHalfIndexLinear_updateOutput(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCudaHalfTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCudaHalfTensor *output,
+ THCudaHalfTensor *weight,
+ THCudaHalfTensor *bias,
+ THCudaHalfTensor *normalizedValues,
+ int train) {
+ THError("THCudaHalfTensor not supported with IndexLinear");
+}
+
+void THNN_CudaHalfIndexLinear_accGradParameters(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCudaHalfTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCudaHalfTensor *gradOutput,
+ THCudaHalfTensor *gradWeight,
+ THCudaHalfTensor *gradBias,
+ THCudaHalfTensor *weight,
+ THCudaHalfTensor *bias,
+ THCudaHalfTensor* valuesBuffer,
+ float weightDecay,
+ float scale) {
+ THError("THCudaHalfTensor not supported with IndexLinear");
+}
+
+void THNN_CudaHalfIndexLinear_accUpdateGradParameters(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCudaHalfTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCudaHalfTensor *gradOutput,
+ THCudaHalfTensor *weight,
+ THCudaHalfTensor *bias,
+ float weightDecay,
+ float scale) {
+ THError("THCudaHalfTensor not supported with IndexLinear");
+}
+
+void THNN_CudaHalfIndexLinear_updateParameters(
+ THCState *state,
+ THCudaHalfTensor *gradWeight,
+ THCudaHalfTensor *gradBias,
+ THCudaHalfTensor *weight,
+ THCudaHalfTensor *bias,
+ THCudaLongTensor *runningKeys,
+ THCudaLongTensor *cumSumSizes,
+ long keysOffset,
+ float weightDecay,
+ float learningRate) {
+ THError("THCudaHalfTensor not supported with IndexLinear");
+}
+#endif
+
+#include "generic/IndexLinear.cu"
+#include "THCGenerateFloatType.h"
+#include "generic/IndexLinear.cu"
+#include "THCGenerateDoubleType.h"
diff --git a/lib/THCUNN/SparseLinear.cu b/lib/THCUNN/SparseLinear.cu
index f36206f..9110bbc 100644
--- a/lib/THCUNN/SparseLinear.cu
+++ b/lib/THCUNN/SparseLinear.cu
@@ -3,7 +3,6 @@
#include "THCHalfAutoNumerics.cuh"
#include <cusparse.h>
-#include <thrust/device_vector.h>
static cusparseHandle_t cusparse_handle = 0;
diff --git a/lib/THCUNN/generic/IndexLinear.cu b/lib/THCUNN/generic/IndexLinear.cu
new file mode 100644
index 0000000..ae96148
--- /dev/null
+++ b/lib/THCUNN/generic/IndexLinear.cu
@@ -0,0 +1,273 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/IndexLinear.cu"
+#else
+
+static bool THCUNN_checkKeysValues(THCState *state, THCudaLongTensor* keys,
+ THCTensor* values)
+{
+ return THCudaLongTensor_size(state, keys, 0) == THCTensor_(nElement)(state, values)
+ && THCTensor_(nDimension)(state, values) == 1
+ && THCudaLongTensor_nDimension(state, keys) == 1;
+}
+
+void THNN_(IndexLinear_updateOutput)(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCTensor *output,
+ THCTensor *weight,
+ THCTensor *bias,
+ THCTensor *normalizedValues,
+ int train)
+{
+ // Make sure these inputs are contiguous to accelerate computations
+ THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
+ "keys vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, values), 3,
+ "values vector must be contiguous");
+ THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
+ "sizes vector must be contiguous");
+ THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
+ "cumSumSizes vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, output), 6,
+ "output vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 7,
+ "weight matrix must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, bias), 8,
+ "bias vector must be contiguous");
+ THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
+ "Keys and values should have the same number of elements");
+
+ long batchSize = sizes->size[0];
+ long outDim = bias->size[0];
+ long wDim = weight->size[1];
+ long weightStride = weight->stride[0];
+ int maxNormalize = wDim - outDim;
+ long keysSize = keys->size[0];
+ long nnzPerRow = divup(keysSize, batchSize);
+
+ THCTensor_(resize2d)(state, output, batchSize, outDim);
+ long *keysData = THCudaLongTensor_data (state, keys);
+ real *valuesData = THCTensor_(data) (state, values);
+ long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
+ real *biasData = THCTensor_(data) (state, bias);
+ real *weightData = THCTensor_(data) (state, weight);
+ real *outData = THCTensor_(data) (state, output);
+
+ cudaStream_t stream = THCState_getCurrentStream(state);
+ dim3 threads(THREADS_X, THREADS_Y);
+ int blocks_x = divup(outDim, threads.x);
+ int blocks_y = batchSize;
+ int nnzPerBlock = ((outDim == 1 || batchSize == 1) ? THREADS_X : NNZ_PER_BLOCK_MAX);
+ int blocks_z = divup(nnzPerRow, nnzPerBlock);
+
+ dim3 blocks(blocks_x, blocks_y, blocks_z);
+
+ if (blocks_z > 1) {
+ THCudaCheck(cudaMemsetAsync(outData, 0, outDim * batchSize * sizeof(real), stream));
+ }
+
+ real *normalizedValuesData = NULL;
+ if (maxNormalize && train) {
+ THCTensor_(resize1d)(state, normalizedValues, keysSize);
+ normalizedValuesData = THCTensor_(data)(state, normalizedValues);
+ updateOutput<real, true><<<blocks, threads, 0, stream>>>
+ (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
+ batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock);
+ } else {
+ updateOutput<real, false><<<blocks, threads, 0, stream>>>
+ (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
+ batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock);
+ }
+}
+
+void THNN_(IndexLinear_accGradParameters)(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCTensor *gradBias,
+ THCTensor *weight,
+ THCTensor *bias,
+ THCTensor* valuesBuffer,
+ accreal weightDecay,
+ accreal scale)
+{
+ long keysSize = keys->size[0];
+ long batchSize = sizes->size[0];
+ long outDim = bias->size[0];
+ long wDim = weight->size[1];
+ int maxNormalize = wDim - outDim;
+
+ // Make sure these inputs are contiguous to accelerate computations
+ THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
+ "keys vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, values), 3,
+ "values vector must be contiguous");
+ THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
+ "sizes vector must be contiguous");
+ THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
+ "cumSumSizes vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 6,
+ "gradOutput vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 7,
+ "gradWeight matrix must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 8,
+ "gradBias vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 9,
+ "weight matrix must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, bias), 10,
+ "bias vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, valuesBuffer), 11,
+ "valuesBuffer vector must be contiguous");
+ THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
+ "Keys and values should have the same number of elements");
+
+ THCTensor_(resize2d)(state, gradWeight, keysSize, outDim * (maxNormalize > 0 ? 2 : 1));
+
+ real *valuesData = THCTensor_(data) (state, values);
+ long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
+ real *gradOutputData = THCTensor_(data) (state, gradOutput);
+ real *gradBiasData = THCTensor_(data) (state, gradBias);
+ real *gradWeightData = THCTensor_(data) (state, gradWeight);
+ long gradWeightStride = gradWeight->stride[0];
+
+ cudaStream_t stream = THCState_getCurrentStream(state);
+ dim3 threads(THREADS_X, THREADS_Y);
+ int blocks_x = divup(outDim, threads.x);
+ accGradBias<real, false><<<blocks_x, threads, 0, stream>>>
+ (gradBiasData, gradOutputData, outDim, batchSize, scale, weightDecay);
+
+ dim3 blocks(blocks_x, batchSize);
+ accGradWeight<real><<<blocks, threads, 0, stream>>>
+ (gradWeightData, gradOutputData, valuesData, cumSumSizesData, outDim,
+ gradWeightStride, scale, weightDecay, maxNormalize);
+}
+
+void THNN_(IndexLinear_accUpdateGradParameters)(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCTensor *gradOutput,
+ THCTensor *weight,
+ THCTensor *bias,
+ accreal weightDecay,
+ accreal scale)
+{
+ // Make sure these inputs are contiguous to accelerate computations
+ THArgCheck(THCudaLongTensor_isContiguous(state, keys), 1,
+ "keys vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, values), 3,
+ "values vector must be contiguous");
+ THArgCheck(THCudaLongTensor_isContiguous(state, sizes), 4,
+ "sizes vector must be contiguous");
+ THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 5,
+ "cumSumSizes vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 6,
+ "gradOutput vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 7,
+ "weight matrix must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, bias), 8,
+ "bias vector must be contiguous");
+ THArgCheck(THCUNN_checkKeysValues(state, keys, values), 1,
+ "Keys and values should have the same number of elements");
+
+ long batchSize = sizes->size[0];
+ long outDim = bias->size[0];
+ long keysSize = keys->size[0];
+ long wDim = weight->size[1];
+ int maxNormalize = wDim - outDim;
+
+ real *biasData = THCTensor_(data) (state, bias);
+ real *weightData = THCTensor_(data) (state, weight);
+ real *gradOutputData = THCTensor_(data) (state, gradOutput);
+ real *valuesData = THCTensor_(data) (state, values);
+ long *keysData = THCudaLongTensor_data (state, keys);
+ long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
+ long weightStride = weight->stride[0];
+
+ cudaStream_t stream = THCState_getCurrentStream(state);
+ dim3 threads(THREADS_X, THREADS_Y);
+ int blocks_x = divup(outDim, threads.x);
+
+ accGradBias<real, true><<<blocks_x, threads, 0, stream>>>
+ (biasData, gradOutputData, outDim, batchSize, scale, weightDecay);
+
+ long nnzPerRow = divup(keysSize, batchSize);
+ int blocks_y = divup(nnzPerRow, REPEAT * threads.y);
+ dim3 blocks(blocks_x, blocks_y);
+
+ for (long batchId = 0; batchId < batchSize; batchId++) {
+ accUpdateWeight<real><<<blocks, threads, 0, stream>>>
+ (weightData, weightStride, gradOutputData, outDim, valuesData,
+ cumSumSizesData, keysData, keysOffset, scale, weightDecay, maxNormalize,
+ batchId);
+ }
+}
+
+void THNN_(IndexLinear_updateParameters)(
+ THCState *state,
+ THCTensor *gradWeight,
+ THCTensor *gradBias,
+ THCTensor *weight,
+ THCTensor *bias,
+ THCudaLongTensor *runningKeys,
+ THCudaLongTensor *cumSumSizes,
+ long keysOffset,
+ accreal weightDecay,
+ accreal learningRate)
+{
+ // Make sure these inputs are contiguous to accelerate computations
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 1,
+ "gradWeight matrix must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 2,
+ "gradBias vector must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 3,
+ "weight matrix must be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, bias), 4,
+ "bias vector must be contiguous");
+ THArgCheck(THCudaLongTensor_isContiguous(state, runningKeys), 5,
+ "runningKeys vector must be contiguous");
+ THArgCheck(THCudaLongTensor_isContiguous(state, cumSumSizes), 6,
+ "cumSumSizes vector must be contiguous");
+
+ long outDim = bias->size[0];
+ long wDim = weight->size[1];
+ int maxNormalize = wDim - outDim;
+ long keysSize = runningKeys->size[0];
+ long batchSize = cumSumSizes->size[0];
+
+ THCTensor_(cadd)(state, bias, bias, -learningRate, gradBias);
+ long gradWeightStride = gradWeight->stride[0];
+ long weightStride = weight->stride[0];
+
+ long *keysData = THCudaLongTensor_data (state, runningKeys);
+ long *cumSumSizesData = THCudaLongTensor_data (state, cumSumSizes);
+ real *gradWeightData = THCTensor_(data) (state, gradWeight);
+ real *weightData = THCTensor_(data) (state, weight);
+
+ dim3 threads(THREADS_X, THREADS_Y);
+ long nnzPerRow = divup(keysSize, batchSize);
+ int blocks_x = divup(outDim, threads.x);
+ int blocks_y = divup(nnzPerRow, REPEAT * threads.y);
+ dim3 blocks(blocks_x, blocks_y);
+ cudaStream_t stream = THCState_getCurrentStream(state);
+
+ for (long batchId = 0; batchId < batchSize; batchId++) {
+ updateWeight<real><<<blocks, threads, 0, stream>>>
+ (weightData, gradWeightData, keysData, cumSumSizesData, outDim,
+ gradWeightStride, weightStride, keysOffset, learningRate, weightDecay,
+ maxNormalize, batchId);
+ }
+}
+#endif
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index a71209a..3c4c38f 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -379,6 +379,60 @@ TH_API void THNN_(SparseLinear_updateParameters)(
THCTensor *lastInput,
accreal learningRate);
+TH_API void THNN_(IndexLinear_updateOutput)(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCTensor *output,
+ THCTensor *weight,
+ THCTensor *bias,
+ THCTensor *normalizedValues,
+ int train);
+
+TH_API void THNN_(IndexLinear_accGradParameters)(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCTensor *gradBias,
+ THCTensor *weight,
+ THCTensor *bias,
+ THCTensor* valuesBuffer,
+ accreal weightDecay,
+ accreal scale);
+
+TH_API void THNN_(IndexLinear_accUpdateGradParameters)(
+ THCState *state,
+ THCudaLongTensor *keys,
+ long keysOffset,
+ THCTensor *values,
+ THCudaLongTensor *sizes,
+ THCudaLongTensor *cumSumSizes,
+ THCTensor *gradOutput,
+ THCTensor *weight,
+ THCTensor *bias,
+ accreal weightDecay,
+ accreal scale);
+
+TH_API void THNN_(IndexLinear_updateParameters)(
+ THCState *state,
+ THCTensor *gradWeight,
+ THCTensor *gradBias,
+ THCTensor *weight,
+ THCTensor *bias,
+ THCudaLongTensor *runningKeys,
+ THCudaLongTensor *cumSumSizes,
+ long keysOffset,
+ accreal weightDecay,
+ accreal learningRate);
+
TH_API void THNN_(SpatialAdaptiveMaxPooling_updateOutput)(
THCState *state,
THCTensor *input,
diff --git a/test.lua b/test.lua
index f8c88f7..436f8b0 100644
--- a/test.lua
+++ b/test.lua
@@ -5915,6 +5915,223 @@ function cunntest.ModuleConversionFunctions()
end
end
+function cunntest.IndexLinear()
+   local isize = 500E3
+   local osize = 250
+   local weightDecay = 0.01
+   local nnzMin = 1000
+   local nnzMax = 1500
+   local idxMin = 1
+   local idxMax = isize
+   local batchSize = 128
+   local lr = 0.01
+   local ntests = 1
+
+ local errNorm = function(a, b)
+ return torch.Tensor(1):fill(torch.cdiv((a - b):abs(), a:abs()):max())
+ end
+
+ local ilc = nn.IndexLinear(isize, osize):float()
+ local ilg = nn.IndexLinear(isize, osize):float():cuda()
+
+ local ilc2 = nn.IndexLinear(isize, osize):float()
+ local ilg2 = nn.IndexLinear(isize, osize):float():cuda()
+
+ local tot = 0
+ local samples = 0
+ local inputCPU = {{}, {}}
+ local inputGPU = {{}, {}}
+ local flatInputCPU = {torch.LongTensor(), torch.FloatTensor(), torch.LongTensor()}
+ local flatInputGPU = {torch.CudaLongTensor(), torch.CudaTensor(), torch.CudaLongTensor()}
+ local sizes = torch.LongTensor(batchSize)
+ for i=1,batchSize do
+ local n = torch.random(nnzMin, nnzMax)
+ local indices = idxMin + torch.LongTensor():randperm(idxMax - idxMin)
+ inputCPU[1][i] = indices[{{1,n}}]
+ inputCPU[2][i] = torch.FloatTensor(n):uniform()
+ inputGPU[1][i] = torch.CudaLongTensor(n):copy(inputCPU[1][i])
+ inputGPU[2][i] = torch.CudaTensor(n):copy(inputCPU[2][i])
+ sizes[i] = n
+ tot = tot + n
+ end
+ flatInputCPU[1]:cat(inputCPU[1], 1)
+ flatInputCPU[2]:cat(inputCPU[2], 1)
+ flatInputCPU[3] = sizes
+
+ flatInputGPU[1]:cat(inputGPU[1], 1)
+ flatInputGPU[2]:cat(inputGPU[2], 1)
+ flatInputGPU[3] = sizes:cudaLong()
+
+ local inputSize = #inputCPU[1]
+ local gradOutsCPU = torch.FloatTensor(inputSize, osize):uniform()
+ local gradOutsGPU = torch.CudaTensor(inputSize, osize):copy(gradOutsCPU)
+
+ local outputCPU, outputGPU
+ local flatOutputCPU, flatOutputGPU
+
+ ilc.weightDecay = weightDecay
+ ilg.weightDecay = weightDecay
+ ilc2.weightDecay = weightDecay
+ ilg2.weightDecay = weightDecay
+
+ ilc.weight:uniform()
+ ilc.bias:fill(1)
+ ilc2.weight:uniform()
+ ilc2.bias:fill(1)
+
+ ilg.weight:copy(ilc.weight)
+ ilg.bias:copy(ilc.bias)
+ ilg2.weight:copy(ilc2.weight)
+ ilg2.bias:copy(ilc2.bias)
+
+ ilc:zeroGradParameters()
+ outputCPU = ilc:forward(inputCPU)
+ ilc:backward(inputCPU, gradOutsCPU);
+ ilc:updateParameters(lr)
+
+ ilc2:zeroGradParameters()
+ flatOutputCPU = ilc2:forward(flatInputCPU)
+ ilc2:backward(flatInputCPU, gradOutsCPU);
+ ilc2:updateParameters(lr)
+
+ ilg:zeroGradParameters()
+ outputGPU = ilg:forward(inputGPU)
+ ilg:backward(inputGPU, gradOutsGPU);
+ ilg:updateParameters(lr)
+
+ ilg2:zeroGradParameters()
+ flatOutputGPU = ilg2:forward(flatInputGPU)
+ ilg2:backward(flatInputGPU, gradOutsGPU);
+ ilg2:updateParameters(lr)
+
+ mytester:assertTensorEq(errNorm(outputCPU, outputGPU:float()),
+ torch.Tensor(1):fill(0),
+ 1E-5, "cunn.IndexLinear:forward failed for output")
+
+ mytester:assertTensorEq(errNorm(flatOutputCPU, flatOutputGPU:float()),
+ torch.Tensor(1):fill(0),
+ 1E-5, "cunn.IndexLinear:forward failed for flatOutput")
+
+ mytester:assertTensorEq(ilc.bias,
+ ilg.bias:float(),
+ 1E-5, "cunn.IndexLinear:backward+update failed for bias for tensor array")
+
+ mytester:assertTensorEq(ilc.weight,
+ ilg.weight:float(),
+ 1E-5, "cunn.IndexLinear:backward+update failed for weight for tensor array")
+
+ mytester:assertTensorEq(ilc2.bias,
+ ilg2.bias:float(),
+ 1E-5, "cunn.IndexLinear:backward+update failed for bias for flat input")
+
+ mytester:assertTensorEq(ilc2.weight,
+ ilg2.weight:float(),
+ 1E-5, "cunn.IndexLinear:backward+update failed for weight for flat input")
+
+ ilc.weight:uniform()
+ ilc.bias:fill(1)
+
+ ilg.weight:copy(ilc.weight)
+ ilg.bias:copy(ilc.bias)
+
+ ilc2.weight:uniform()
+ ilc2.bias:fill(1)
+
+ ilg2.weight:copy(ilc2.weight)
+ ilg2.bias:copy(ilc2.bias)
+
+ outputCPU = ilc:forward(inputCPU)
+ ilc:backwardUpdate(inputCPU, gradOutsCPU, lr);
+
+ outputGPU = ilg:forward(inputGPU)
+ ilg:backwardUpdate(inputGPU, gradOutsGPU, lr);
+
+ flatOutputCPU = ilc2:forward(flatInputCPU)
+ ilc2:backwardUpdate(flatInputCPU, gradOutsCPU, lr);
+
+ flatOutputGPU = ilg2:forward(flatInputGPU)
+ ilg2:backwardUpdate(flatInputGPU, gradOutsGPU, lr);
+
+ mytester:assertTensorEq(errNorm(outputCPU, outputGPU:float()),
+ torch.Tensor(1):fill(0),
+ 1E-5, "cunn.IndexLinear:forward failed for output")
+
+ mytester:assertTensorEq(errNorm(flatOutputCPU, flatOutputGPU:float()),
+ torch.Tensor(1):fill(0),
+ 1E-5, "cunn.IndexLinear:forward failed for flatOutput")
+
+ mytester:assertTensorEq(ilc.bias,
+ ilg.bias:float(),
+ 1E-5, "cunn.IndexLinear:backward+update failed for bias for tensor array")
+
+ mytester:assertTensorEq(ilc.weight,
+ ilg.weight:float(),
+ 1E-5, "cunn.IndexLinear:backward+update failed for weight for tensor array")
+
+ mytester:assertTensorEq(ilc2.bias,
+ ilg2.bias:float(),
+ 1E-5, "cunn.IndexLinear:backward+update failed for bias for flat input")
+
+ mytester:assertTensorEq(ilc2.weight,
+ ilg2.weight:float(),
+ 1E-5, "cunn.IndexLinear:backward+update failed for weight for flat input")
+end
+
+function cunntest.IndexLinearMaxNorm()
+   local isize = 500E3
+   local osize = 250
+   local weightDecay = 0
+   local nnzMin = 1000
+   local nnzMax = 1500
+   local idxMin = 1
+   local idxMax = isize
+   local batchSize = 128
+   local lr = 0.01
+   local ntests = 1
+
+ local errNorm = function(a, b)
+ return torch.Tensor(1):fill(torch.cdiv((a - b):abs(), a:abs()):max())
+ end
+
+ local ilc = nn.IndexLinear(isize, osize, nil, nil, nil, nil, 1):float()
+ local ilg = nn.IndexLinear(isize, osize, nil, nil, nil, nil, 1):float():cuda()
+
+ local tot = 0
+ local samples = 0
+ local inputCPU = {{}, {}}
+ local inputGPU = {{}, {}}
+ for i=1,batchSize do
+ local n = torch.random(nnzMin, nnzMax)
+ local indices = idxMin + torch.LongTensor():randperm(idxMax - idxMin)
+ inputCPU[1][i] = indices[{{1,n}}]
+ inputCPU[2][i] = torch.FloatTensor(n):uniform()
+ inputGPU[1][i] = torch.CudaLongTensor(n):copy(inputCPU[1][i])
+ inputGPU[2][i] = torch.CudaTensor(n):copy(inputCPU[2][i])
+ tot = tot + n
+ end
+
+ local inputSize = #inputCPU[1]
+ local gradOutsCPU = torch.FloatTensor(inputSize, osize):uniform()
+ local gradOutsGPU = torch.CudaTensor(inputSize, osize):copy(gradOutsCPU)
+
+ ilc.weightDecay = weightDecay
+ ilg.weightDecay = weightDecay
+
+ ilc.weight:uniform()
+ ilc.weight:narrow(2,2,1):fill(1.0):cdiv(ilc.weight:narrow(2,1,1))
+ ilc.bias:fill(1)
+
+ ilg.weight:copy(ilc.weight)
+ ilg.bias:copy(ilc.bias)
+
+   local outputCPU = ilc:forward(inputCPU)
+   local outputGPU = ilg:forward(inputGPU)
+
+ mytester:assertTensorEq(errNorm(outputCPU, outputGPU:float()),
+ torch.Tensor(1):fill(0),
+ 1E-5, "cunn.IndexLinear:forward failed for output")
+end
+
function cunntest.GPU()
local ndevice = cutorch.getDeviceCount()
if ndevice < 2 then