diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-12-12 13:50:56 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-12-12 13:50:56 +0300 |
commit | 6691772199c4df338902aafaca67c95d8b3d6a2b (patch) | |
tree | 2459c320875bfe070228833563860b29a2895e5d | |
parent | 9b2f1ef7c45204f5c278bbdc55e367fcdd29e70e (diff) | |
parent | ee700e25cc47f0c7433d76059fb8bb72af62fb25 (diff) |
Merge pull request #1070 from torch/narrownegative
Support negative offset in nn.Narrow
-rw-r--r-- | Narrow.lua | 14 | ||||
-rw-r--r-- | doc/simple.md | 5 | ||||
-rw-r--r-- | test.lua | 17 |
3 files changed, 31 insertions, 5 deletions
@@ -16,7 +16,12 @@ function Narrow:updateOutput(input) if length < 0 then length = input:size(dim) - self.index + self.length + 2 end - local output=input:narrow(dim,self.index,length) + local index = self.index + if self.index < 0 then + index = 1 + length = input:size(dim) - length + end + local output=input:narrow(dim, index, length) self.output = self.output:typeAs(output) self.output:resizeAs(output):copy(output) return self.output @@ -28,8 +33,13 @@ function Narrow:updateGradInput(input, gradOutput) if length < 0 then length = input:size(dim) - self.index + self.length + 2 end + local index = self.index + if self.index < 0 then + index = 1 + length = input:size(dim) - length + end self.gradInput = self.gradInput:typeAs(input) self.gradInput:resizeAs(input):zero() - self.gradInput:narrow(dim,self.index,length):copy(gradOutput) + self.gradInput:narrow(dim,index,length):copy(gradOutput) return self.gradInput end diff --git a/doc/simple.md b/doc/simple.md index b7044ae..09c60ca 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -25,7 +25,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi * [MaskedSelect](#nn.MaskedSelect) : a [masked select](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-maskedselect-index) module performs the torch.maskedSelect operation ; * [Index](#nn.Index) : a [index](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-indexdim-index) over a given dimension ; * [Squeeze](#nn.Squeeze) : [squeezes](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-squeezedim) the input; - * [Unsqueeze](#nn.Unsqueeze) : unsqueeze the input, i.e., insert singleton dimension; + * [Unsqueeze](#nn.Unsqueeze) : unsqueeze the input, i.e., insert singleton dimension; * [Transpose](#nn.Transpose) : [transposes](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-transposedim1-dim2) the input ; * Modules that adapt mathematical Tensor methods : * [AddConstant](https://github.com/torch/nn/blob/master/doc/transfer.md#addconstant) : adding a constant ; @@ -679,8 +679,7 @@ The default is false. module = nn.Narrow(dimension, offset, length) ``` -Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a module. -The module further supports a negative `length` in order to handle inputs with an unknown size. +Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a module. The module further supports negative `length`, `dim` and `offset` to handle inputs of unknown size. ```lua > x = torch.rand(4, 5) @@ -5674,6 +5674,23 @@ function nntest.Narrow() mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #4 gradInput err") mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #4 negative output err") mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #4 negative gradInput err") + + -- check narrow negative offset + local input = torch.rand(3, 10, 4) + local output = input:narrow(2, 1, 3) + local gradOutput = torch.rand(3, 3, 4) + local gradInput = torch.zeros(3, 10, 4) + gradInput:narrow(2, 1, 3):copy(gradOutput) + local module1 = nn.Narrow(2, -1, 7) + local output1 = module1:forward(input) + local gradInput1 = module1:backward(input, gradOutput) + local module2 = nn.Narrow(2, 1, 3) + local output2 = module2:forward(input) + local gradInput2 = module2:backward(input, gradOutput) + mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #5 output err") + mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #5 gradInput err") + mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #5 negative output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #5 negative gradInput err") end function nntest.NarrowTable() |