Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-01-24 16:50:11 +0300
committerGitHub <noreply@github.com>2017-01-24 16:50:11 +0300
commit207e130221c6b85be94a2a75a8b211799cdc8f46 (patch)
treedf1e8bc9c983781d7d3a91879fb57aab123da787
parent7742eba2218446057f0414c9df4879c5e14481a8 (diff)
parent59e08237f3aa26df80a46331f742c28b9a93e752 (diff)
Merge pull request #1113 from huihuifan/narrow_edit
changed narrow to standardize negative length and negative offset beh…
-rw-r--r--Narrow.lua28
-rw-r--r--doc/simple.md52
-rw-r--r--test.lua47
3 files changed, 68 insertions, 59 deletions
diff --git a/Narrow.lua b/Narrow.lua
index a6ebaa3..01be934 100644
--- a/Narrow.lua
+++ b/Narrow.lua
@@ -13,14 +13,16 @@ end
function Narrow:updateOutput(input)
local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension
local length = self.length
- if length < 0 then
- length = input:size(dim) - self.index + self.length + 2
- end
local index = self.index
- if self.index < 0 then
- index = 1
- length = input:size(dim) - length
+
+ if index < 0 then
+ index = input:size(dim) + self.index + 1
end
+
+ if length < 0 then
+ length = input:size(dim) - index + 1 - torch.abs(length)
+ end
+
local output=input:narrow(dim, index, length)
self.output = self.output:typeAs(output)
self.output:resizeAs(output):copy(output)
@@ -30,14 +32,16 @@ end
function Narrow:updateGradInput(input, gradOutput)
local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension
local length = self.length
- if length < 0 then
- length = input:size(dim) - self.index + self.length + 2
- end
local index = self.index
- if self.index < 0 then
- index = 1
- length = input:size(dim) - length
+
+ if index < 0 then
+ index = input:size(dim) + self.index + 1
end
+
+ if length < 0 then
+ length = input:size(dim) - index + 1 - torch.abs(length)
+ end
+
self.gradInput = self.gradInput:typeAs(input)
self.gradInput:resizeAs(input):zero()
self.gradInput:narrow(dim,index,length):copy(gradOutput)
diff --git a/doc/simple.md b/doc/simple.md
index 09c60ca..6b0e4d5 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -685,47 +685,45 @@ Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/do
> x = torch.rand(4, 5)
> x
- 0.3695 0.2017 0.4485 0.4638 0.0513
- 0.9222 0.1877 0.3388 0.6265 0.5659
- 0.8785 0.7394 0.8265 0.9212 0.0129
- 0.2290 0.7971 0.2113 0.1097 0.3166
+ 0.2746 0.8704 0.6839 0.9137 0.5994
+ 0.6099 0.6365 0.0923 0.0795 0.4404
+ 0.3270 0.9202 0.6142 0.8548 0.8239
+ 0.7058 0.6300 0.8553 0.7736 0.3567
[torch.DoubleTensor of size 4x5]
> nn.Narrow(1, 2, 3):forward(x)
- 0.9222 0.1877 0.3388 0.6265 0.5659
- 0.8785 0.7394 0.8265 0.9212 0.0129
- 0.2290 0.7971 0.2113 0.1097 0.3166
+ 0.6099 0.6365 0.0923 0.0795 0.4404
+ 0.3270 0.9202 0.6142 0.8548 0.8239
+ 0.7058 0.6300 0.8553 0.7736 0.3567
[torch.DoubleTensor of size 3x5]
> nn.Narrow(1, 2, -1):forward(x)
- 0.9222 0.1877 0.3388 0.6265 0.5659
- 0.8785 0.7394 0.8265 0.9212 0.0129
- 0.2290 0.7971 0.2113 0.1097 0.3166
-[torch.DoubleTensor of size 3x5]
+ 0.6099 0.6365 0.0923 0.0795 0.4404
+ 0.3270 0.9202 0.6142 0.8548 0.8239
+[torch.DoubleTensor of size 2x5]
> nn.Narrow(1, 2, 2):forward(x)
- 0.9222 0.1877 0.3388 0.6265 0.5659
- 0.8785 0.7394 0.8265 0.9212 0.0129
+ 0.6099 0.6365 0.0923 0.0795 0.4404
+ 0.3270 0.9202 0.6142 0.8548 0.8239
[torch.DoubleTensor of size 2x5]
> nn.Narrow(1, 2, -2):forward(x)
- 0.9222 0.1877 0.3388 0.6265 0.5659
- 0.8785 0.7394 0.8265 0.9212 0.0129
-[torch.DoubleTensor of size 2x5]
-
-> nn.Narrow(2, 2, 3):forward(x)
- 0.2017 0.4485 0.4638
- 0.1877 0.3388 0.6265
- 0.7394 0.8265 0.9212
- 0.7971 0.2113 0.1097
+ 0.6099 0.6365 0.0923 0.0795 0.4404
+[torch.DoubleTensor of size 1x5]
+
+> nn.Narrow(2,2,3):forward(x)
+ 0.8704 0.6839 0.9137
+ 0.6365 0.0923 0.0795
+ 0.9202 0.6142 0.8548
+ 0.6300 0.8553 0.7736
[torch.DoubleTensor of size 4x3]
> nn.Narrow(2, 2, -2):forward(x)
- 0.2017 0.4485 0.4638
- 0.1877 0.3388 0.6265
- 0.7394 0.8265 0.9212
- 0.7971 0.2113 0.1097
-[torch.DoubleTensor of size 4x3]
+ 0.8704 0.6839
+ 0.6365 0.0923
+ 0.9202 0.6142
+ 0.6300 0.8553
+[torch.DoubleTensor of size 4x2]
```
<a name="nn.Replicate"></a>
diff --git a/test.lua b/test.lua
index b19d6b3..f985f2d 100644
--- a/test.lua
+++ b/test.lua
@@ -5854,9 +5854,10 @@ function nntest.Narrow()
local module1 = nn.Narrow(1, 3, 5)
local output1 = module1:forward(input)
local gradInput1 = module1:backward(input, gradOutput)
- local module2 = nn.Narrow(1, 3, -3)
+ local module2 = nn.Narrow(1, 3, -2)
local output2 = module2:forward(input)
local gradInput2 = module2:backward(input, gradOutput)
+
mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #1 output err")
mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #1 gradInput err")
mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #1 negative output err")
@@ -5871,7 +5872,7 @@ function nntest.Narrow()
local module1 = nn.Narrow(2, 5, 3)
local output1 = module1:forward(input)
local gradInput1 = module1:backward(input, gradOutput)
- local module2 = nn.Narrow(2, 5, -4)
+ local module2 = nn.Narrow(2, 5, -3)
local output2 = module2:forward(input)
local gradInput2 = module2:backward(input, gradOutput)
mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #2 output err")
@@ -5888,7 +5889,7 @@ function nntest.Narrow()
local module1 = nn.Narrow(3, 1, 1)
local output1 = module1:forward(input)
local gradInput1 = module1:backward(input, gradOutput)
- local module2 = nn.Narrow(3, 1, -7)
+ local module2 = nn.Narrow(3, 1, -6)
local output2 = module2:forward(input)
local gradInput2 = module2:backward(input, gradOutput)
mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #3 output err")
@@ -5896,22 +5897,28 @@ function nntest.Narrow()
mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #3 negative output err")
mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #3 negative gradInput err")
- -- check basic narrow functionality #4
- local input = torch.rand(3, 10, 4)
- local output = input:narrow(2, 5, 3)
- local gradOutput = torch.rand(3, 3, 4)
- local gradInput = torch.zeros(3, 10, 4)
- gradInput:narrow(2, 5, 3):copy(gradOutput)
- local module1 = nn.Narrow(-2, 5, 3)
- local output1 = module1:forward(input)
- local gradInput1 = module1:backward(input, gradOutput)
- local module2 = nn.Narrow(-2, 5, -4)
- local output2 = module2:forward(input)
- local gradInput2 = module2:backward(input, gradOutput)
- mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #4 output err")
- mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #4 gradInput err")
- mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #4 negative output err")
- mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #4 negative gradInput err")
+ -- check narrow negative offset and negative length
+ local input = torch.Tensor({1, 2, 3, 4, 5})
+
+ local output = torch.Tensor({2})
+ local modelOutput = nn.Narrow(1, 2, 1):forward(input)
+ mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.1 output err")
+
+ local output = torch.Tensor({4})
+ local modelOutput = nn.Narrow(1, -2, 1):forward(input)
+ mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.2 output err")
+
+ local output = torch.Tensor({5})
+ local modelOutput = nn.Narrow(1, -1, 1):forward(input)
+ mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.3 output err")
+
+ local output = torch.Tensor({4})
+ local modelOutput = nn.Narrow(1, -2, -1):forward(input)
+ mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.4 output err")
+
+ local output = torch.Tensor({2, 3, 4})
+ local modelOutput = nn.Narrow(1, 2, -1):forward(input)
+ mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.5 output err")
-- check narrow negative offset
local input = torch.rand(3, 10, 4)
@@ -5919,7 +5926,7 @@ function nntest.Narrow()
local gradOutput = torch.rand(3, 3, 4)
local gradInput = torch.zeros(3, 10, 4)
gradInput:narrow(2, 1, 3):copy(gradOutput)
- local module1 = nn.Narrow(2, -1, 7)
+ local module1 = nn.Narrow(2, -10, 3)
local output1 = module1:forward(input)
local gradInput1 = module1:backward(input, gradOutput)
local module2 = nn.Narrow(2, 1, 3)