Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--LookupTable.lua1
-rw-r--r--README.md47
-rw-r--r--generic/TemporalMaxPooling.c195
-rw-r--r--test/test.lua24
4 files changed, 207 insertions, 60 deletions
diff --git a/LookupTable.lua b/LookupTable.lua
index ca8477c..be9467a 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -30,6 +30,7 @@ function LookupTable:__init(nIndex, ...)
self.gradWeight = torch.Tensor(self.size):zero()
self.inputs = {}
+ self.nBackward = 0
self:reset()
end
diff --git a/README.md b/README.md
index ec0e0a4..4ee736e 100644
--- a/README.md
+++ b/README.md
@@ -1594,13 +1594,13 @@ Note that depending of the size of your kernel, several (of the last)
frames of the sequence might be lost. It is up to the user to add proper padding frames in the input
sequences.
-If the input sequence is a 2D tensor of dimension `inputFrameSize x nInputFrame`, the output sequence will be
+If the input sequence is a 2D tensor of dimension `nInputFrame x inputFrameSize`, the output sequence will be
`nOutputFrame x outputFrameSize` where
```lua
nOutputFrame = (nInputFrame - kW) / dW + 1
```
-If the input sequence is a 3D tensor of dimension `nBatchFrame x inputFrameSize x nInputFrame`, the output sequence will be
+If the input sequence is a 3D tensor of dimension `nBatchFrame x nInputFrame x inputFrameSize`, the output sequence will be
`nBatchFrame x nOutputFrame x outputFrameSize`.
The parameters of the convolution can be found in `self.weight` (Tensor of
@@ -1610,9 +1610,9 @@ size `outputFrameSize`). The corresponding gradients can be found in
For a 2D input, the output value of the layer can be precisely described as:
```lua
-output[i][t] = bias[i]
- + sum_j sum_{k=1}^kW weight[j][k][i]
- * input[j][dW*(t-1)+k)]
+output[t][i] = bias[i]
+ + sum_j sum_{k=1}^kW weight[i][j][k]
+ * input[dW*(t-1)+k)][j]
```
Here is a simple example:
@@ -1620,7 +1620,7 @@ Here is a simple example:
```lua
inp=5; -- dimensionality of one sequence element
outp=1; -- number of derived features for one sequence element
-kw=1; -- kernel only operates on one sequence element at once
+kw=1; -- kernel only operates on one sequence element per step
dw=1; -- we step once and go on to the next sequence element
mlp=nn.TemporalConvolution(inp,outp,kw,dw)
@@ -1660,6 +1660,23 @@ which gives:
-0.63871422284166
```
+<a name="nn.TemporalMaxPooling"/>
+### TemporalMaxPooling ###
+
+```lua
+module = nn.TemporalMaxPooling(kW, [dW])
+```
+
+Applies 1D max-pooling operation in `kW` regions by step size
+`dW` steps. Input sequence composed of `nInputFrame` frames. The `input` tensor in
+`forward(input)` is expected to be a 2D tensor (`nInputFrame x inputFrameSize`)
+or a 3D tensor (`nBatchFrame x nInputFrame x inputFrameSize`).
+
+If the input sequence is a 2D tensor of dimension `nInputFrame x inputFrameSize`, the output sequence will be
+`nOutputFrame x inputFrameSize` where
+```lua
+nOutputFrame = (nInputFrame - kW) / dW + 1
+```
<a name="nn.TemporalSubSampling"/>
### TemporalSubSampling ###
@@ -1715,7 +1732,7 @@ at `1` and can go up to `nIndex`. For each index, it outputs a corresponding `Te
specified by `sizes` (a `LongStorage`) or `size1 x size2 x...`.
Given a 1D input, the output tensors are concatenated,
-generating a `size1 x size2 x ... x sizeN x n` tensor, where `n`
+generating a `n x size1 x size2 x ... x sizeN` tensor, where `n`
is the size of a 1D `input` tensor.
Again with a 1D input, when only `size1` is provided, the `forward(input)` is equivalent to
@@ -1731,21 +1748,21 @@ where `M` is a 2D matrix `size1 x nIndex` containing the parameters of the looku
-- a lookup table containing 10 tensors of size 3
module = nn.LookupTable(10, 3)
- input = torch.Tensor(4)
- input[1] = 1; input[2] = 2; input[3] = 1; input[4] = 10;
+ input = torch.Tensor{1,2,1,10}
print(module:forward(input))
```
Outputs something like:
```lua
--0.1784 2.2045 -0.1784 -0.2475
--1.0120 0.0537 -1.0120 -0.2148
--1.2840 0.8685 -1.2840 -0.2792
-[torch.Tensor of dimension 3x4]
+-1.4415 -0.1001 -0.1708
+-0.6945 -0.4350 0.7977
+-1.4415 -0.1001 -0.1708
+-0.0745 1.9275 1.0915
+[torch.DoubleTensor of dimension 4x3]
```
-Note that the first column vector is the same than the 3rd one!
+Note that the first row vector is the same as the 3rd one!
-Given a 2D input tensor of size `m x n`, the output is a `m x size1 x size2 x ... x sizeN x n`
+Given a 2D input tensor of size `m x n`, the output is a `m x n x size1 x size2 x ... x sizeN`
tensor, where `m` is the number of samples in
the batch and `n` is the number of indices per sample.
diff --git a/generic/TemporalMaxPooling.c b/generic/TemporalMaxPooling.c
index 3c0384d..f55eea3 100644
--- a/generic/TemporalMaxPooling.c
+++ b/generic/TemporalMaxPooling.c
@@ -19,54 +19,119 @@ static int nn_(TemporalMaxPooling_updateOutput)(lua_State *L)
real *indices_data;
long t, y;
+
+ int dimS = 0; // sequence dimension
+ int dimF = 1; // feature dimension
- luaL_argcheck(L, input->nDimension == 2, 2, "2D tensor expected");
- luaL_argcheck(L, input->size[0] >= kW, 2, "input sequence smaller than kernel size");
-
+ luaL_argcheck(L, input->nDimension == 2 || input->nDimension == 3, 2, "2D or 3D(batch mode) tensor expected");
+
+ if (input->nDimension == 3)
+ {
+ dimS = 1;
+ dimF = 2;
+ }
+ luaL_argcheck(L, input->size[dimS] >= kW, 2, "input sequence smaller than kernel size");
+
/* sizes */
- niframe = input->size[0];
- framesize = input->size[1];
+ niframe = input->size[dimS];
+ framesize = input->size[dimF];
noframe = (niframe - kW) / dW + 1;
-
+
/* get contiguous input */
input = THTensor_(newContiguous)(input);
- /* resize output */
- THTensor_(resize2d)(output, noframe, framesize);
-
- /* indices will contain index locations for each output point */
- THTensor_(resize2d)(indices, noframe, framesize);
-
- /* get raw pointers */
- input_data = THTensor_(data)(input);
- output_data = THTensor_(data)(output);
- indices_data = THTensor_(data)(indices);
-
- for(t = 0; t < noframe; t++)
+ if (input->nDimension == 2)
{
- real *ip = input_data + t*framesize*dW;
- real *op = output_data + t*framesize;
- real *xp = indices_data + t*framesize;
+ /* resize output */
+ THTensor_(resize2d)(output, noframe, framesize);
+
+ /* indices will contain index locations for each output point */
+ THTensor_(resize2d)(indices, noframe, framesize);
+
+ /* get raw pointers */
+ input_data = THTensor_(data)(input);
+ output_data = THTensor_(data)(output);
+ indices_data = THTensor_(data)(indices);
+
+ for(t = 0; t < noframe; t++)
+ {
+ real *ip = input_data + t*framesize*dW;
+ real *op = output_data + t*framesize;
+ real *xp = indices_data + t*framesize;
#pragma omp parallel for private(y)
- for(y = 0; y < framesize; y++)
+ for(y = 0; y < framesize; y++)
+ {
+ /* compute local max: */
+ long maxindex = -1;
+ real maxval = -THInf;
+ long x;
+ for(x = 0; x < kW; x++)
+ {
+ real val = ip[x*framesize+y];
+ if (val > maxval)
+ {
+ maxval = val;
+ maxindex = x;
+ }
+ }
+
+ /* set output to local max */
+ op[y] = maxval;
+ xp[y] = (real)maxindex;
+ }
+ }
+ }
+ else
+ {
+ /* number of batch frames */
+ long nbframe = input->size[0];
+ long i;
+
+ /* resize output */
+ THTensor_(resize3d)(output, nbframe, noframe, framesize);
+
+ /* indices will contain index locations for each output point */
+ THTensor_(resize3d)(indices, nbframe, noframe, framesize);
+
+ /* get raw pointers */
+ input_data = THTensor_(data)(input);
+ output_data = THTensor_(data)(output);
+ indices_data = THTensor_(data)(indices);
+
+ for(i = 0; i < nbframe; i++)
{
- /* compute local max: */
- long maxindex = -1;
- real maxval = -THInf;
- long x;
- for(x = 0; x < kW; x++)
+ real *inputSample_data = input_data + i*niframe*framesize;
+ real *outputSample_data = output_data + i*noframe*framesize;
+ real *indicesSample_data = indices_data + i*noframe*framesize;
+
+ for(t = 0; t < noframe; t++)
{
- real val = ip[x*framesize+y];
- if (val > maxval)
+ real *ip = inputSample_data + t*framesize*dW;
+ real *op = outputSample_data + t*framesize;
+ real *xp = indicesSample_data + t*framesize;
+
+#pragma omp parallel for private(y)
+ for(y = 0; y < framesize; y++)
{
- maxval = val;
- maxindex = x;
+ /* compute local max: */
+ long maxindex = -1;
+ real maxval = -THInf;
+ long x;
+ for(x = 0; x < kW; x++)
+ {
+ real val = ip[x*framesize+y];
+ if (val > maxval)
+ {
+ maxval = val;
+ maxindex = x;
+ }
+ }
+
+ /* set output to local max */
+ op[y] = maxval;
+ xp[y] = (real)maxindex;
}
}
-
- /* set output to local max */
- op[y] = maxval;
- xp[y] = (real)maxindex;
}
}
@@ -84,6 +149,7 @@ static int nn_(TemporalMaxPooling_updateGradInput)(lua_State *L)
THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_Tensor);
THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+ long niframe;
int noframe;
long framesize;
@@ -100,26 +166,65 @@ static int nn_(TemporalMaxPooling_updateGradInput)(lua_State *L)
THTensor_(resizeAs)(gradInput, input);
THTensor_(zero)(gradInput);
+ int dimS = 0; // sequence dimension
+ int dimF = 1; // feature dimension
+
+ if (input->nDimension == 3)
+ {
+ dimS = 1;
+ dimF = 2;
+ }
/* sizes */
- noframe = gradOutput->size[0];
- framesize = gradOutput->size[1];
+ niframe = input->size[dimS];
+ noframe = gradOutput->size[dimS];
+ framesize = gradOutput->size[dimF];
/* get raw pointers */
gradInput_data = THTensor_(data)(gradInput);
gradOutput_data = THTensor_(data)(gradOutput);
indices_data = THTensor_(data)(indices);
- for(t = 0; t < noframe; t++)
+ if (input->nDimension == 2)
{
- real *gip = gradInput_data + t*framesize*dW;
- real *gop = gradOutput_data + t*framesize;
- real *xp = indices_data + t*framesize;
+ for(t = 0; t < noframe; t++)
+ {
+ real *gip = gradInput_data + t*framesize*dW;
+ real *gop = gradOutput_data + t*framesize;
+ real *xp = indices_data + t*framesize;
#pragma omp parallel for private(y)
- for(y = 0; y < framesize; y++)
+ for(y = 0; y < framesize; y++)
+ {
+ /* compute local max: */
+ long maxindex = (long)xp[y];
+ gip[maxindex*framesize+y] += gop[y];
+ }
+ }
+ }
+ else
+ {
+ /* number of batch frames */
+ long nbframe = input->size[0];
+ long i;
+
+ for(i = 0; i < nbframe; i++)
{
- /* compute local max: */
- long maxindex = (long)xp[y];
- gip[maxindex*framesize+y] += gop[y];
+ real *gradInputSample_data = gradInput_data + i*niframe*framesize;
+ real *gradOutputSample_data = gradOutput_data + i*noframe*framesize;
+ real *indicesSample_data = indices_data + i*noframe*framesize;
+
+ for(t = 0; t < noframe; t++)
+ {
+ real *gip = gradInputSample_data + t*framesize*dW;
+ real *gop = gradOutputSample_data + t*framesize;
+ real *xp = indicesSample_data + t*framesize;
+#pragma omp parallel for private(y)
+ for(y = 0; y < framesize; y++)
+ {
+ /* compute local max: */
+ long maxindex = (long)xp[y];
+ gip[maxindex*framesize+y] += gop[y];
+ }
+ }
}
}
diff --git a/test/test.lua b/test/test.lua
index 0e85578..7eb2d44 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1438,12 +1438,36 @@ function nntest.TemporalMaxPooling()
local module = nn.TemporalMaxPooling(ki, si)
local input = torch.Tensor(ini, from):zero()
+ -- 1D
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
+
+ -- 2D
+ local nBatchFrame = 2
+ local input = torch.Tensor(nBatchFrame, ini, from):zero()
local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state ')
local ferr, berr = jac.testIO(module, input)
mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
+
+ -- 2D matches 1D
+ local output = module:forward(input):clone()
+ local outputGrad = torch.randn(output:size())
+ local inputGrad = module:backward(input, outputGrad):clone()
+
+ local input1D = input:select(1, 1)
+ local output1D = module:forward(input1D)
+ local outputGrad1D = outputGrad:select(1, 1)
+ local inputGrad1D = module:backward(input1D, outputGrad1D)
+
+ mytester:assertTensorEq(output:select(1,1), output1D, 0.000001, 'error on 2D vs 1D forward)')
+ mytester:assertTensorEq(inputGrad:select(1,1), inputGrad1D, 0.000001, 'error on 2D vs 1D backward)')
end
function nntest.VolumetricConvolution()