diff options
-rw-r--r-- | TemporalMaxPooling.lua | 31 | ||||
-rw-r--r-- | generic/TemporalMaxPooling.c | 127 | ||||
-rw-r--r-- | init.c | 5 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test/test.lua | 17 |
5 files changed, 181 insertions, 0 deletions
diff --git a/TemporalMaxPooling.lua b/TemporalMaxPooling.lua new file mode 100644 index 0000000..b8fdd3e --- /dev/null +++ b/TemporalMaxPooling.lua @@ -0,0 +1,31 @@ +local TemporalMaxPooling, parent = torch.class('nn.TemporalMaxPooling', 'nn.Module') + +function TemporalMaxPooling:__init(kW, dW) + parent.__init(self) + + dW = dW or kW + + self.kW = kW + self.dW = dW + + self.indices = torch.Tensor() +end + +function TemporalMaxPooling:updateOutput(input) + input.nn.TemporalMaxPooling_updateOutput(self, input) + return self.output +end + +function TemporalMaxPooling:updateGradInput(input, gradOutput) + input.nn.TemporalMaxPooling_updateGradInput(self, input, gradOutput) + return self.gradInput +end + +function TemporalMaxPooling:empty() + self.gradInput:resize() + self.gradInput:storage():resize(0) + self.output:resize() + self.output:storage():resize(0) + self.indices:resize() + self.indices:storage():resize(0) +end diff --git a/generic/TemporalMaxPooling.c b/generic/TemporalMaxPooling.c new file mode 100644 index 0000000..56d0ef6 --- /dev/null +++ b/generic/TemporalMaxPooling.c @@ -0,0 +1,127 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/TemporalMaxPooling.c" +#else + +static int nn_(TemporalMaxPooling_updateOutput)(lua_State *L) +{ + THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); + int kW = luaT_getfieldcheckint(L, 1, "kW"); + int dW = luaT_getfieldcheckint(L, 1, "dW"); + THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id)); + THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id)); + + luaL_argcheck(L, input->nDimension == 2, 2, "2D tensor expected"); + luaL_argcheck(L, input->size[0] >= kW, 2, "input sequence smaller than kernel size"); + + // sizes + long niframe = input->size[0]; + long framesize = input->size[1]; + long noframe = (niframe - kW) / dW + 1; + + // get contiguous input + input = THTensor_(newContiguous)(input); + + // resize output + THTensor_(resize2d)(output, noframe, framesize); + + // indices will contain index locations for each output point + THTensor_(resize2d)(indices, noframe, framesize); + + // get raw pointers + real *input_data = THTensor_(data)(input); + real *output_data = THTensor_(data)(output); + real *indices_data = THTensor_(data)(indices); + + long t, x, y; + for(t = 0; t < noframe; t++) + { + real *ip = input_data + t*framesize*dW; + real *op = output_data + t*framesize; + real *xp = indices_data + t*framesize; +#pragma omp parallel for private(y) + for(y = 0; y < framesize; y++) + { + // compute local max: + long maxindex = -1; + real maxval = -THInf; + for(x = 0; x < kW; x++) + { + real val = ip[x*framesize+y]; + if (val > maxval) + { + maxval = val; + maxindex = x; + } + } + + // set output to local max + op[y] = maxval; + xp[y] = (real)maxindex; + } + } + + // cleanup + THTensor_(free)(input); + + return 1; +} + +static int nn_(TemporalMaxPooling_updateGradInput)(lua_State *L) +{ + THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); + THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id)); + int dW = luaT_getfieldcheckint(L, 1, "dW"); + THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id)); + THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id)); + + // get contiguous gradOutput + gradOutput = THTensor_(newContiguous)(gradOutput); + + // resize and zero + THTensor_(resizeAs)(gradInput, input); + THTensor_(zero)(gradInput); + + // sizes + int noframe = gradOutput->size[0]; + long framesize = gradOutput->size[1]; + + // get raw pointers + real *gradInput_data = THTensor_(data)(gradInput); + real *gradOutput_data = THTensor_(data)(gradOutput); + real *indices_data = THTensor_(data)(indices); + + long t, y; + for(t = 0; t < noframe; t++) + { + real *gip = gradInput_data + t*framesize*dW; + real *gop = gradOutput_data + t*framesize; + real *xp = indices_data + t*framesize; +#pragma omp parallel for private(y) + for(y = 0; y < framesize; y++) + { + // compute local max: + long maxindex = (long)xp[y]; + gip[maxindex*framesize+y] += gop[y]; + } + } + + // cleanup + THTensor_(free)(gradOutput); + + return 1; +} + +static const struct luaL_Reg nn_(TemporalMaxPooling__) [] = { + {"TemporalMaxPooling_updateOutput", nn_(TemporalMaxPooling_updateOutput)}, + {"TemporalMaxPooling_updateGradInput", nn_(TemporalMaxPooling_updateGradInput)}, + {NULL, NULL} +}; + +static void nn_(TemporalMaxPooling_init)(lua_State *L) +{ + luaT_pushmetaclass(L, torch_(Tensor_id)); + luaT_registeratname(L, nn_(TemporalMaxPooling__), "nn"); + lua_pop(L,1); +} + +#endif @@ -71,6 +71,9 @@ static const void* torch_DoubleTensor_id = NULL; #include "generic/TemporalSubSampling.c" #include "THGenerateFloatTypes.h" +#include "generic/TemporalMaxPooling.c" +#include "THGenerateFloatTypes.h" + #include "generic/SpatialConvolution.c" #include "THGenerateFloatTypes.h" @@ -122,6 +125,7 @@ DLL_EXPORT int luaopen_libnn(lua_State *L) nn_FloatSparseLinear_init(L); nn_FloatTemporalConvolution_init(L); nn_FloatTemporalSubSampling_init(L); + nn_FloatTemporalMaxPooling_init(L); nn_FloatSpatialConvolution_init(L); nn_FloatSpatialConvolutionMap_init(L); nn_FloatSpatialSubSampling_init(L); @@ -151,6 +155,7 @@ DLL_EXPORT int luaopen_libnn(lua_State *L) nn_DoubleSparseLinear_init(L); nn_DoubleTemporalConvolution_init(L); nn_DoubleTemporalSubSampling_init(L); + nn_DoubleTemporalMaxPooling_init(L); nn_DoubleSpatialConvolution_init(L); nn_DoubleSpatialConvolutionMap_init(L); nn_DoubleSpatialSubSampling_init(L); @@ -61,6 +61,7 @@ torch.include('nn', 'SpatialMaxPooling.lua') torch.include('nn', 'SpatialLPPooling.lua') torch.include('nn', 'TemporalConvolution.lua') torch.include('nn', 'TemporalSubSampling.lua') +torch.include('nn', 'TemporalMaxPooling.lua') torch.include('nn', 'SpatialSubtractiveNormalization.lua') torch.include('nn', 'SpatialDivisiveNormalization.lua') torch.include('nn', 'SpatialContrastiveNormalization.lua') diff --git a/test/test.lua b/test/test.lua index b8536ca..4d4383d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1013,6 +1013,23 @@ function nntest.TemporalSubSampling() mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') end +function nntestx.TemporalMaxPooling() + local from = math.random(1,10) + local ki = math.random(1,10) + local si = math.random(1,4) + local outi = math.random(10,20) + local ini = (outi-1)*si+ki + local module = nn.TemporalMaxPooling(ki, si) + local input = torch.Tensor(ini, from):zero() + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state ') + + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') +end + function nntest.VolumetricConvolution() local from = math.random(2,5) local to = math.random(2,5) |