Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-12-12 13:49:06 +0300
committersoumith <soumith@fb.com>2016-12-12 13:49:06 +0300
commitee700e25cc47f0c7433d76059fb8bb72af62fb25 (patch)
tree2459c320875bfe070228833563860b29a2895e5d
parent9b2f1ef7c45204f5c278bbdc55e367fcdd29e70e (diff)
Support negative offset in nn.Narrownarrownegative
-rw-r--r--Narrow.lua14
-rw-r--r--doc/simple.md5
-rw-r--r--test.lua17
3 files changed, 31 insertions, 5 deletions
diff --git a/Narrow.lua b/Narrow.lua
index 0754d45..a6ebaa3 100644
--- a/Narrow.lua
+++ b/Narrow.lua
@@ -16,7 +16,12 @@ function Narrow:updateOutput(input)
if length < 0 then
length = input:size(dim) - self.index + self.length + 2
end
- local output=input:narrow(dim,self.index,length)
+ local index = self.index
+ if self.index < 0 then
+ index = 1
+ length = input:size(dim) - length
+ end
+ local output=input:narrow(dim, index, length)
self.output = self.output:typeAs(output)
self.output:resizeAs(output):copy(output)
return self.output
@@ -28,8 +33,13 @@ function Narrow:updateGradInput(input, gradOutput)
if length < 0 then
length = input:size(dim) - self.index + self.length + 2
end
+ local index = self.index
+ if self.index < 0 then
+ index = 1
+ length = input:size(dim) - length
+ end
self.gradInput = self.gradInput:typeAs(input)
self.gradInput:resizeAs(input):zero()
- self.gradInput:narrow(dim,self.index,length):copy(gradOutput)
+ self.gradInput:narrow(dim,index,length):copy(gradOutput)
return self.gradInput
end
diff --git a/doc/simple.md b/doc/simple.md
index b7044ae..09c60ca 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -25,7 +25,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [MaskedSelect](#nn.MaskedSelect) : a [masked select](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-maskedselect-index) module performs the torch.maskedSelect operation ;
* [Index](#nn.Index) : a [index](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-indexdim-index) over a given dimension ;
* [Squeeze](#nn.Squeeze) : [squeezes](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-squeezedim) the input;
- * [Unsqueeze](#nn.Unsqueeze) : unsqueeze the input, i.e., insert singleton dimension;
+ * [Unsqueeze](#nn.Unsqueeze) : unsqueeze the input, i.e., insert singleton dimension;
* [Transpose](#nn.Transpose) : [transposes](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-transposedim1-dim2) the input ;
* Modules that adapt mathematical Tensor methods :
* [AddConstant](https://github.com/torch/nn/blob/master/doc/transfer.md#addconstant) : adding a constant ;
@@ -679,8 +679,7 @@ The default is false.
module = nn.Narrow(dimension, offset, length)
```
-Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a module.
-The module further supports a negative `length` in order to handle inputs with an unknown size.
+Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a module. The module further supports negative `length`, `dim` and `offset` to handle inputs of unknown size.
```lua
> x = torch.rand(4, 5)
diff --git a/test.lua b/test.lua
index 774fba1..5cb7066 100644
--- a/test.lua
+++ b/test.lua
@@ -5674,6 +5674,23 @@ function nntest.Narrow()
mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #4 gradInput err")
mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #4 negative output err")
mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #4 negative gradInput err")
+
+ -- check narrow negative offset
+ local input = torch.rand(3, 10, 4)
+ local output = input:narrow(2, 1, 3)
+ local gradOutput = torch.rand(3, 3, 4)
+ local gradInput = torch.zeros(3, 10, 4)
+ gradInput:narrow(2, 1, 3):copy(gradOutput)
+ local module1 = nn.Narrow(2, -1, 7)
+ local output1 = module1:forward(input)
+ local gradInput1 = module1:backward(input, gradOutput)
+ local module2 = nn.Narrow(2, 1, 3)
+ local output2 = module2:forward(input)
+ local gradInput2 = module2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #5 output err")
+ mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #5 gradInput err")
+ mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #5 negative output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #5 negative gradInput err")
end
function nntest.NarrowTable()