Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-10-13 19:45:22 +0300
committerGregory Chanan <gchanan@fb.com>2016-10-20 00:53:19 +0300
commit915bd8711b224467262e2e7bfb1f5ace3f7b99ad (patch)
tree4015454631538c734faff0c41fd9772411bb25fd
parent09b9966cb3aafc4852806a2a4f5b50dc0711a3ea (diff)
Indices for nn.
-rw-r--r--SpatialAdaptiveMaxPooling.lua4
-rw-r--r--SpatialFractionalMaxPooling.lua7
-rw-r--r--TemporalMaxPooling.lua7
-rw-r--r--VolumetricMaxPooling.lua9
-rw-r--r--lib/THNN/generic/LookupTable.c2
-rw-r--r--lib/THNN/generic/SpatialFractionalMaxPooling.c26
-rw-r--r--lib/THNN/generic/THNN.h22
-rw-r--r--lib/THNN/generic/TemporalMaxPooling.c30
-rw-r--r--lib/THNN/generic/VolumetricDilatedMaxPooling.c32
-rw-r--r--lib/THNN/generic/VolumetricMaxPooling.c4
-rw-r--r--lib/THNN/generic/VolumetricMaxUnpooling.c40
11 files changed, 96 insertions, 87 deletions
diff --git a/SpatialAdaptiveMaxPooling.lua b/SpatialAdaptiveMaxPooling.lua
index a2cf104..b78261c 100644
--- a/SpatialAdaptiveMaxPooling.lua
+++ b/SpatialAdaptiveMaxPooling.lua
@@ -10,9 +10,9 @@ end
function SpatialAdaptiveMaxPooling:updateOutput(input)
self.indices = self.indices or torch.LongTensor()
if torch.typename(input):find('torch%.Cuda.*Tensor') then
- self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices
+ self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices
else
- self.indices = self.indices:long()
+ self.indices = self.indices:long()
end
input.THNN.SpatialAdaptiveMaxPooling_updateOutput(
input:cdata(),
diff --git a/SpatialFractionalMaxPooling.lua b/SpatialFractionalMaxPooling.lua
index f5d8076..884751d 100644
--- a/SpatialFractionalMaxPooling.lua
+++ b/SpatialFractionalMaxPooling.lua
@@ -114,7 +114,12 @@ function SpatialFractionalMaxPooling:fixPoolingRegions(val)
end
function SpatialFractionalMaxPooling:updateOutput(input)
- self.indices = self.indices or input.new()
+ self.indices = self.indices or torch.LongTensor()
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices
+ else
+ self.indices = self.indices:long()
+ end
self:initSampleBuffer_(input)
local outW, outH = self:getOutputSizes_(input)
diff --git a/TemporalMaxPooling.lua b/TemporalMaxPooling.lua
index 91723e6..894f4a9 100644
--- a/TemporalMaxPooling.lua
+++ b/TemporalMaxPooling.lua
@@ -10,7 +10,12 @@ function TemporalMaxPooling:__init(kW, dW)
end
function TemporalMaxPooling:updateOutput(input)
- self.indices = self.indices or input.new()
+ self.indices = self.indices or torch.LongTensor()
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices
+ else
+ self.indices = self.indices:long()
+ end
input.THNN.TemporalMaxPooling_updateOutput(
input:cdata(), self.output:cdata(),
self.indices:cdata(), self.kW, self.dW
diff --git a/VolumetricMaxPooling.lua b/VolumetricMaxPooling.lua
index fd65231..20733ed 100644
--- a/VolumetricMaxPooling.lua
+++ b/VolumetricMaxPooling.lua
@@ -22,7 +22,7 @@ function VolumetricMaxPooling:__init(kT, kW, kH, dT, dW, dH, padT, padW, padH)
self.ceil_mode = false
- self.indices = torch.Tensor()
+ self.indices = torch.LongTensor()
end
function VolumetricMaxPooling:ceil()
@@ -41,7 +41,12 @@ function VolumetricMaxPooling:updateOutput(input)
self.iheight = input:size(dims-1)
self.iwidth = input:size(dims)
- self.indices = self.indices or input.new()
+ self.indices = self.indices or torch.LongTensor()
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices
+ else
+ self.indices = self.indices:long()
+ end
input.THNN.VolumetricMaxPooling_updateOutput(
input:cdata(),
self.output:cdata(),
diff --git a/lib/THNN/generic/LookupTable.c b/lib/THNN/generic/LookupTable.c
index a682f5f..b460f38 100644
--- a/lib/THNN/generic/LookupTable.c
+++ b/lib/THNN/generic/LookupTable.c
@@ -29,7 +29,7 @@ void THNN_(LookupTable_accGradParameters)(
THTensor *gradWeight,
THIntegerTensor *count,
THTensor *sorted,
- THTensor *indices,
+ THIndexTensor *indices,
bool scaleGradByFreq,
int paddingValue,
real scale)
diff --git a/lib/THNN/generic/SpatialFractionalMaxPooling.c b/lib/THNN/generic/SpatialFractionalMaxPooling.c
index 0a9db40..a98954c 100644
--- a/lib/THNN/generic/SpatialFractionalMaxPooling.c
+++ b/lib/THNN/generic/SpatialFractionalMaxPooling.c
@@ -23,7 +23,7 @@ static long* THNN_(SpatialFractionalMaxPooling_generateIntervals)(
static void THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
real* input,
real* output,
- real* indices,
+ THIndex_t* indices,
real* randomSamples,
long numPlanes,
long inputW, long inputH,
@@ -48,7 +48,7 @@ static void THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
real* inputForPlane = input + plane * inputW * inputH;
real* outputForPlane = output + plane * outputW * outputH;
- real* indicesForPlane = indices + plane * outputW * outputH;
+ THIndex_t* indicesForPlane = indices + plane * outputW * outputH;
for (h = 0; h < outputH; ++h) {
long inputHStart = sequenceH[h];
@@ -79,7 +79,7 @@ static void THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
outputForPlane[h * outputW + w] = maxVal;
/* +1 to lua index */
- indicesForPlane[h * outputW + w] = (real) maxIndex + TH_INDEX_BASE;
+ indicesForPlane[h * outputW + w] = maxIndex + TH_INDEX_BASE;
}
}
@@ -94,7 +94,7 @@ void THNN_(SpatialFractionalMaxPooling_updateOutput)(
THTensor *output,
int outputW, int outputH,
int poolSizeW, int poolSizeH,
- THTensor *indices,
+ THIndexTensor *indices,
THTensor *randomSamples) {
long numBatch = 1;
@@ -132,18 +132,18 @@ void THNN_(SpatialFractionalMaxPooling_updateOutput)(
/* resize output */
THTensor_(resize3d)(output, numPlanes, outputH, outputW);
/* indices will contain the locations for each output point */
- THTensor_(resize3d)(indices, numPlanes, outputH, outputW);
+ THIndexTensor_(resize3d)(indices, numPlanes, outputH, outputW);
THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
THTensor_(data)(input),
THTensor_(data)(output),
- THTensor_(data)(indices),
+ THIndexTensor_(data)(indices),
THTensor_(data)(randomSamples),
numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
} else {
THTensor_(resize4d)(output, numBatch, numPlanes, outputH, outputW);
/* indices will contain the locations for each output point */
- THTensor_(resize4d)(indices, numBatch, numPlanes, outputH, outputW);
+ THIndexTensor_(resize4d)(indices, numBatch, numPlanes, outputH, outputW);
long batch;
#pragma omp parallel for private(batch)
@@ -151,7 +151,7 @@ void THNN_(SpatialFractionalMaxPooling_updateOutput)(
THNN_(SpatialFractionalMaxPooling_updateOutput_frame)(
THTensor_(data)(input) + batch * numPlanes * inputH * inputW,
THTensor_(data)(output) + batch * numPlanes * outputH * outputW,
- THTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
+ THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
THTensor_(data)(randomSamples) + batch * numPlanes * 2,
numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
}
@@ -164,7 +164,7 @@ void THNN_(SpatialFractionalMaxPooling_updateOutput)(
static void THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
real* gradInput,
real* gradOutput,
- real* indices,
+ THIndex_t* indices,
long numPlanes,
long inputW, long inputH,
long outputW, long outputH) {
@@ -173,7 +173,7 @@ static void THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
for (plane = 0; plane < numPlanes; plane++) {
real* gradInputForPlane = gradInput + plane * inputW * inputH;
real* gradOutputForPlane = gradOutput + plane * outputW * outputH;
- real* indicesForPlane = indices + plane * outputW * outputH;
+ THIndex_t* indicesForPlane = indices + plane * outputW * outputH;
long h, w;
for (h = 0; h < outputH; ++h) {
@@ -195,7 +195,7 @@ void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
THTensor *gradInput,
int outputW, int outputH,
int poolSizeW, int poolSizeH,
- THTensor *indices) {
+ THIndexTensor *indices) {
long numBatch = 1;
int planeDim = 0;
@@ -232,7 +232,7 @@ void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
THTensor_(data)(gradInput),
THTensor_(data)(gradOutput),
- THTensor_(data)(indices),
+ THIndexTensor_(data)(indices),
numPlanes, inputW, inputH, outputW, outputH);
} else {
long batch;
@@ -241,7 +241,7 @@ void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
THNN_(SpatialFractionalMaxPooling_updateGradInput_frame)(
THTensor_(data)(gradInput) + batch * numPlanes * inputH * inputW,
THTensor_(data)(gradOutput) + batch * numPlanes * outputH * outputW,
- THTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
+ THIndexTensor_(data)(indices) + batch * numPlanes * outputH * outputW,
numPlanes, inputW, inputH, outputW, outputH);
}
}
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index df37f62..9d8dac8 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -186,7 +186,7 @@ TH_API void THNN_(LookupTable_accGradParameters)(
THTensor *gradWeight,
THIntegerTensor *count,
THTensor *sorted, // [OPTIONAL]
- THTensor *indices, // [OPTIONAL]
+ THIndexTensor *indices, // [OPTIONAL]
bool scaleGradByFreq,
int paddingValue,
real scale);
@@ -543,14 +543,14 @@ TH_API void THNN_(TemporalMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
- THTensor *indices,
+ THIndexTensor *indices,
int kW, int dW);
TH_API void THNN_(TemporalMaxPooling_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *indices,
+ THIndexTensor *indices,
int kW, int dW);
TH_API void THNN_(TemporalSubSampling_updateOutput)(
THNNState *state,
@@ -753,7 +753,7 @@ TH_API void THNN_(SpatialFractionalMaxPooling_updateOutput)(
THTensor *output,
int outputW, int outputH,
int poolSizeW, int poolSizeH,
- THTensor *indices,
+ THIndexTensor *indices,
THTensor *randomSamples);
TH_API void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
THNNState *state,
@@ -762,7 +762,7 @@ TH_API void THNN_(SpatialFractionalMaxPooling_updateGradInput)(
THTensor *gradInput,
int outputW, int outputH,
int poolSizeW, int poolSizeH,
- THTensor *indices);
+ THIndexTensor *indices);
TH_API void THNN_(SpatialFullConvolution_updateOutput)(
THNNState *state,
@@ -1156,7 +1156,7 @@ TH_API void THNN_(VolumetricMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
- THTensor *indices,
+ THIndexTensor *indices,
int kT, int kW, int kH,
int dT, int dW, int dH,
int pT, int pW, int pH,
@@ -1166,7 +1166,7 @@ TH_API void THNN_(VolumetricMaxPooling_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *indices,
+ THIndexTensor *indices,
int dT, int dW, int dH,
int pT, int pW, int pH);
@@ -1174,7 +1174,7 @@ TH_API void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
- THTensor *indices,
+ THIndexTensor *indices,
int kT, int kW, int kH,
int dT, int dW, int dH,
int pT, int pW, int pH,
@@ -1185,7 +1185,7 @@ TH_API void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *indices,
+ THIndexTensor *indices,
int dT, int dW, int dH,
int pT, int pW, int pH,
int dilationT, int dilationW, int dilationH);
@@ -1194,7 +1194,7 @@ TH_API void THNN_(VolumetricMaxUnpooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
- THTensor *indices,
+ THIndexTensor *indices,
int oT, int oW, int oH,
int dT, int dW, int dH,
int pT, int pW, int pH);
@@ -1203,7 +1203,7 @@ TH_API void THNN_(VolumetricMaxUnpooling_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *indices,
+ THIndexTensor *indices,
int oT, int oW, int oH,
int dT, int dW, int dH,
int pT, int pW, int pH);
diff --git a/lib/THNN/generic/TemporalMaxPooling.c b/lib/THNN/generic/TemporalMaxPooling.c
index 48cbcab..0a2f004 100644
--- a/lib/THNN/generic/TemporalMaxPooling.c
+++ b/lib/THNN/generic/TemporalMaxPooling.c
@@ -6,7 +6,7 @@ void THNN_(TemporalMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
- THTensor *indices,
+ THIndexTensor *indices,
int kW,
int dW)
{
@@ -16,7 +16,7 @@ void THNN_(TemporalMaxPooling_updateOutput)(
real *input_data;
real *output_data;
- real *indices_data;
+ THIndex_t *indices_data;
long t, y;
@@ -46,18 +46,18 @@ void THNN_(TemporalMaxPooling_updateOutput)(
THTensor_(resize2d)(output, noframe, framesize);
/* indices will contain index locations for each output point */
- THTensor_(resize2d)(indices, noframe, framesize);
+ THIndexTensor_(resize2d)(indices, noframe, framesize);
/* get raw pointers */
input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
for(t = 0; t < noframe; t++)
{
real *ip = input_data + t*framesize*dW;
real *op = output_data + t*framesize;
- real *xp = indices_data + t*framesize;
+ THIndex_t *xp = indices_data + t*framesize;
#pragma omp parallel for private(y)
for(y = 0; y < framesize; y++)
{
@@ -91,24 +91,24 @@ void THNN_(TemporalMaxPooling_updateOutput)(
THTensor_(resize3d)(output, nbframe, noframe, framesize);
/* indices will contain index locations for each output point */
- THTensor_(resize3d)(indices, nbframe, noframe, framesize);
+ THIndexTensor_(resize3d)(indices, nbframe, noframe, framesize);
/* get raw pointers */
input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
for(i = 0; i < nbframe; i++)
{
real *inputSample_data = input_data + i*niframe*framesize;
real *outputSample_data = output_data + i*noframe*framesize;
- real *indicesSample_data = indices_data + i*noframe*framesize;
+ THIndex_t *indicesSample_data = indices_data + i*noframe*framesize;
for(t = 0; t < noframe; t++)
{
real *ip = inputSample_data + t*framesize*dW;
real *op = outputSample_data + t*framesize;
- real *xp = indicesSample_data + t*framesize;
+ THIndex_t *xp = indicesSample_data + t*framesize;
#pragma omp parallel for private(y)
for(y = 0; y < framesize; y++)
@@ -145,7 +145,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *indices,
+ THIndexTensor *indices,
int kW,
int dW)
{
@@ -155,7 +155,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)(
real *gradInput_data;
real *gradOutput_data;
- real *indices_data;
+ THIndex_t *indices_data;
long t, y;
@@ -182,7 +182,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)(
/* get raw pointers */
gradInput_data = THTensor_(data)(gradInput);
gradOutput_data = THTensor_(data)(gradOutput);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
if (input->nDimension == 2)
{
@@ -190,7 +190,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)(
{
real *gip = gradInput_data + t*framesize*dW;
real *gop = gradOutput_data + t*framesize;
- real *xp = indices_data + t*framesize;
+ THIndex_t *xp = indices_data + t*framesize;
#pragma omp parallel for private(y)
for(y = 0; y < framesize; y++)
{
@@ -210,13 +210,13 @@ void THNN_(TemporalMaxPooling_updateGradInput)(
{
real *gradInputSample_data = gradInput_data + i*niframe*framesize;
real *gradOutputSample_data = gradOutput_data + i*noframe*framesize;
- real *indicesSample_data = indices_data + i*noframe*framesize;
+ THIndex_t *indicesSample_data = indices_data + i*noframe*framesize;
for(t = 0; t < noframe; t++)
{
real *gip = gradInputSample_data + t*framesize*dW;
real *gop = gradOutputSample_data + t*framesize;
- real *xp = indicesSample_data + t*framesize;
+ THIndex_t *xp = indicesSample_data + t*framesize;
#pragma omp parallel for private(y)
for(y = 0; y < framesize; y++)
{
diff --git a/lib/THNN/generic/VolumetricDilatedMaxPooling.c b/lib/THNN/generic/VolumetricDilatedMaxPooling.c
index 240040a..940c6fe 100644
--- a/lib/THNN/generic/VolumetricDilatedMaxPooling.c
+++ b/lib/THNN/generic/VolumetricDilatedMaxPooling.c
@@ -5,7 +5,7 @@
static void THNN_(VolumetricDilatedMaxPooling_updateOutput_frame)(
real *input_p,
real *output_p,
- real *indz_p,
+ THIndex_t *indz_p,
long nslices,
long itime,
long iwidth,
@@ -43,7 +43,7 @@ static void THNN_(VolumetricDilatedMaxPooling_updateOutput_frame)(
long start_t = ti * dT - pT;
long start_h = i * dH - pH;
long start_w = j * dW - pW;
-
+
long kernel_t = fminf(kT, kT + start_t);
long kernel_h = fminf(kH, kH + start_h);
long kernel_w = fminf(kW, kW + start_w);
@@ -54,12 +54,12 @@ static void THNN_(VolumetricDilatedMaxPooling_updateOutput_frame)(
start_h += dilationH;
while(start_w < 0)
start_w += dilationW;
-
+
real *ip = input_p + k * itime * iwidth * iheight
+ start_t * iwidth * iheight + start_h * iwidth + start_w;
real *op = output_p + k * otime * owidth * oheight
+ ti * owidth * oheight + i * owidth + j;
- real *indzp = indz_p + k * otime * owidth * oheight
+ THIndex_t *indzp = indz_p + k * otime * owidth * oheight
+ ti * owidth * oheight + i * owidth + j;
/* compute local max: */
@@ -107,7 +107,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
- THTensor *indices,
+ THIndexTensor *indices,
int kT,
int kW,
int kH,
@@ -131,7 +131,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
long owidth;
real *input_data;
real *output_data;
- real *indices_data;
+ THIndex_t *indices_data;
THNN_ARGCHECK(input->nDimension == 4 || input->nDimension == 5, 2, input,
"4D or 5D (batch mode) tensor expected for input, but got: %s");
@@ -201,11 +201,11 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
/* resize output */
THTensor_(resize4d)(output, nslices, otime, oheight, owidth);
/* indices will contain ti,i,j uchar locations packed into float/double */
- THTensor_(resize4d)(indices, nslices, otime, oheight, owidth);
+ THIndexTensor_(resize4d)(indices, nslices, otime, oheight, owidth);
input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
THNN_(VolumetricDilatedMaxPooling_updateOutput_frame)(
input_data, output_data,
@@ -230,11 +230,11 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
/* resize output */
THTensor_(resize5d)(output, nBatch, nslices, otime, oheight, owidth);
/* indices will contain ti,i,j locations for each output point */
- THTensor_(resize5d)(indices, nBatch, nslices, otime, oheight, owidth);
+ THIndexTensor_(resize5d)(indices, nBatch, nslices, otime, oheight, owidth);
input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
#pragma omp parallel for private(p)
for (p=0; p < nBatch; p++)
@@ -261,7 +261,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
static void THNN_(VolumetricDilatedMaxPooling_updateGradInput_frame)(
real *gradInput_p,
real *gradOutput_p,
- real *indz_p,
+ THIndex_t *indz_p,
long nslices,
long itime,
long iwidth,
@@ -285,7 +285,7 @@ static void THNN_(VolumetricDilatedMaxPooling_updateGradInput_frame)(
{
real *gradInput_p_k = gradInput_p + k * itime * iwidth * iheight;
real *gradOutput_p_k = gradOutput_p + k * otime * owidth * oheight;
- real *indz_p_k = indz_p + k * otime * owidth * oheight;
+ THIndex_t *indz_p_k = indz_p + k * otime * owidth * oheight;
/* calculate max points */
long ti, i, j;
@@ -296,7 +296,7 @@ static void THNN_(VolumetricDilatedMaxPooling_updateGradInput_frame)(
for (j = 0; j < owidth; j++)
{
/* retrieve position of max */
- real * indzp = &indz_p_k[ti * oheight * owidth + i * owidth + j];
+ THIndex_t * indzp = &indz_p_k[ti * oheight * owidth + i * owidth + j];
long maxti = ((unsigned char*)(indzp))[0] * dilationT + ti * dT - pT;
long maxi = ((unsigned char*)(indzp))[1] * dilationH + i * dH - pH;
long maxj = ((unsigned char*)(indzp))[2] * dilationW + j * dW - pW;
@@ -315,7 +315,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *indices,
+ THIndexTensor *indices,
int dT,
int dW,
int dH,
@@ -335,7 +335,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
int owidth;
real *gradInput_data;
real *gradOutput_data;
- real *indices_data;
+ THIndex_t *indices_data;
int dimN = 0;
int dimt = 1;
@@ -370,7 +370,7 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
/* get raw pointers */
gradInput_data = THTensor_(data)(gradInput);
gradOutput_data = THTensor_(data)(gradOutput);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
/* backprop */
if (input->nDimension == 4) /* non-batch mode*/
diff --git a/lib/THNN/generic/VolumetricMaxPooling.c b/lib/THNN/generic/VolumetricMaxPooling.c
index dc376e6..47af4f0 100644
--- a/lib/THNN/generic/VolumetricMaxPooling.c
+++ b/lib/THNN/generic/VolumetricMaxPooling.c
@@ -6,7 +6,7 @@ void THNN_(VolumetricMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
- THTensor *indices,
+ THIndexTensor *indices,
int kT,
int kW,
int kH,
@@ -29,7 +29,7 @@ void THNN_(VolumetricMaxPooling_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *indices,
+ THIndexTensor *indices,
int dT,
int dW,
int dH,
diff --git a/lib/THNN/generic/VolumetricMaxUnpooling.c b/lib/THNN/generic/VolumetricMaxUnpooling.c
index 83f1673..2a5dcdc 100644
--- a/lib/THNN/generic/VolumetricMaxUnpooling.c
+++ b/lib/THNN/generic/VolumetricMaxUnpooling.c
@@ -5,7 +5,7 @@
static void THNN_(VolumetricMaxUnpooling_updateOutput_frame)(
real *input_p,
real *output_p,
- real *ind_p,
+ THIndex_t *ind_p,
long nslices,
long iT,
long iW,
@@ -37,7 +37,7 @@ static void THNN_(VolumetricMaxUnpooling_updateOutput_frame)(
//real *output_p_k = output_p + k*oT*oW*oH + ti*oW*oH*dT + i*oW*dH + j*dW;
real *input_p_k = input_p + k*iT*iW*iH + ti*iW*iH + i*iW + j;
- real *ind_p_k = ind_p + k*iT*iW*iH + ti*iW*iH + i*iW + j;
+ THIndex_t *ind_p_k = ind_p + k*iT*iW*iH + ti*iW*iH + i*iW + j;
maxz = ((unsigned char*)(ind_p_k))[0]; /* retrieve position of max */
maxy = ((unsigned char*)(ind_p_k))[1];
@@ -61,7 +61,7 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
- THTensor *indices,
+ THIndexTensor *indices,
int oT,
int oW,
int oH,
@@ -82,15 +82,12 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)(
int iW;
real *input_data;
real *output_data;
- real *indices_data;
+ THIndex_t *indices_data;
THNN_ARGCHECK(input->nDimension == 4 || input->nDimension == 5, 2, input,
"4D or 5D (batch mode) tensor expected for input, but got: %s");
- if (!THTensor_(isSameSizeAs)(input, indices))
- {
- THError("Invalid input size w.r.t current indices size");
- }
+ THNN_CHECK_SHAPE_INDICES(input, indices);
if (input->nDimension == 5)
{
@@ -108,7 +105,7 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)(
/* get contiguous input */
input = THTensor_(newContiguous)(input);
- indices = THTensor_(newContiguous)(indices);
+ indices = THIndexTensor_(newContiguous)(indices);
/* resize output */
if (input->nDimension == 4)
@@ -118,7 +115,7 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)(
input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
THNN_(VolumetricMaxUnpooling_updateOutput_frame)(
input_data, output_data,
@@ -138,7 +135,7 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)(
input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
#pragma omp parallel for private(p)
for (p = 0; p < nbatch; p++)
@@ -158,13 +155,13 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)(
/* cleanup */
THTensor_(free)(input);
- THTensor_(free)(indices);
+ THIndexTensor_(free)(indices);
}
static void THNN_(VolumetricMaxUnpooling_updateGradInput_frame)(
real *gradInput_p,
real *gradOutput_p,
- real *ind_p,
+ THIndex_t *ind_p,
long nslices,
long iT,
long iW,
@@ -196,7 +193,7 @@ static void THNN_(VolumetricMaxUnpooling_updateGradInput_frame)(
real *gradInput_p_k = gradInput_p + k*iT*iW*iH + ti*iW*iH + i*iW + j;
//real *gradOutput_p_k = gradOutput_p + k*oT*oW*oH + ti*oW*oH*dT + i*oW*dH + j*dW;
- real *ind_p_k = ind_p + k*iT*iW*iH + ti*iW*iH + i*iW + j;
+ THIndex_t *ind_p_k = ind_p + k*iT*iW*iH + ti*iW*iH + i*iW + j;
maxz = ((unsigned char*)(ind_p_k))[0]; /* retrieve position of max */
maxy = ((unsigned char*)(ind_p_k))[1];
@@ -221,7 +218,7 @@ void THNN_(VolumetricMaxUnpooling_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *indices,
+ THIndexTensor *indices,
int oT,
int oW,
int oH,
@@ -242,17 +239,14 @@ void THNN_(VolumetricMaxUnpooling_updateGradInput)(
int iW;
real *gradInput_data;
real *gradOutput_data;
- real *indices_data;
+ THIndex_t *indices_data;
- if (!THTensor_(isSameSizeAs)(input, indices))
- {
- THError("Invalid input size w.r.t current indices size");
- }
+ THNN_CHECK_SHAPE_INDICES(input, indices);
// TODO: check gradOutput shape
/* get contiguous gradOutput */
gradOutput = THTensor_(newContiguous)(gradOutput);
- indices = THTensor_(newContiguous)(indices);
+ indices = THIndexTensor_(newContiguous)(indices);
/* resize */
THTensor_(resizeAs)(gradInput, input);
@@ -283,7 +277,7 @@ void THNN_(VolumetricMaxUnpooling_updateGradInput)(
/* get raw pointers */
gradInput_data = THTensor_(data)(gradInput);
gradOutput_data = THTensor_(data)(gradOutput);
- indices_data = THTensor_(data)(indices);
+ indices_data = THIndexTensor_(data)(indices);
/* backprop */
if (input->nDimension == 4)
@@ -319,7 +313,7 @@ void THNN_(VolumetricMaxUnpooling_updateGradInput)(
/* cleanup */
THTensor_(free)(gradOutput);
- THTensor_(free)(indices);
+ THIndexTensor_(free)(indices);
}
#endif