diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-01-24 16:50:11 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-24 16:50:11 +0300 |
commit | 207e130221c6b85be94a2a75a8b211799cdc8f46 (patch) | |
tree | df1e8bc9c983781d7d3a91879fb57aab123da787 | |
parent | 7742eba2218446057f0414c9df4879c5e14481a8 (diff) | |
parent | 59e08237f3aa26df80a46331f742c28b9a93e752 (diff) |
Merge pull request #1113 from huihuifan/narrow_edit
changed narrow to standardize negative length and negative offset beh…
-rw-r--r-- | Narrow.lua | 28 | ||||
-rw-r--r-- | doc/simple.md | 52 | ||||
-rw-r--r-- | test.lua | 47 |
3 files changed, 68 insertions, 59 deletions
@@ -13,14 +13,16 @@ end function Narrow:updateOutput(input) local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension local length = self.length - if length < 0 then - length = input:size(dim) - self.index + self.length + 2 - end local index = self.index - if self.index < 0 then - index = 1 - length = input:size(dim) - length + + if index < 0 then + index = input:size(dim) + self.index + 1 end + + if length < 0 then + length = input:size(dim) - index + 1 - torch.abs(length) + end + local output=input:narrow(dim, index, length) self.output = self.output:typeAs(output) self.output:resizeAs(output):copy(output) @@ -30,14 +32,16 @@ end function Narrow:updateGradInput(input, gradOutput) local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension local length = self.length - if length < 0 then - length = input:size(dim) - self.index + self.length + 2 - end local index = self.index - if self.index < 0 then - index = 1 - length = input:size(dim) - length + + if index < 0 then + index = input:size(dim) + self.index + 1 end + + if length < 0 then + length = input:size(dim) - index + 1 - torch.abs(length) + end + self.gradInput = self.gradInput:typeAs(input) self.gradInput:resizeAs(input):zero() self.gradInput:narrow(dim,index,length):copy(gradOutput) diff --git a/doc/simple.md b/doc/simple.md index 09c60ca..6b0e4d5 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -685,47 +685,45 @@ Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/do > x = torch.rand(4, 5) > x - 0.3695 0.2017 0.4485 0.4638 0.0513 - 0.9222 0.1877 0.3388 0.6265 0.5659 - 0.8785 0.7394 0.8265 0.9212 0.0129 - 0.2290 0.7971 0.2113 0.1097 0.3166 + 0.2746 0.8704 0.6839 0.9137 0.5994 + 0.6099 0.6365 0.0923 0.0795 0.4404 + 0.3270 0.9202 0.6142 0.8548 0.8239 + 0.7058 0.6300 0.8553 0.7736 0.3567 [torch.DoubleTensor of size 4x5] > nn.Narrow(1, 2, 3):forward(x) - 0.9222 0.1877 0.3388 0.6265 0.5659 - 0.8785 0.7394 0.8265 0.9212 0.0129 - 0.2290 0.7971 0.2113 0.1097 0.3166 + 0.6099 0.6365 0.0923 0.0795 0.4404 + 0.3270 0.9202 0.6142 0.8548 0.8239 + 0.7058 0.6300 0.8553 0.7736 0.3567 [torch.DoubleTensor of size 3x5] > nn.Narrow(1, 2, -1):forward(x) - 0.9222 0.1877 0.3388 0.6265 0.5659 - 0.8785 0.7394 0.8265 0.9212 0.0129 - 0.2290 0.7971 0.2113 0.1097 0.3166 -[torch.DoubleTensor of size 3x5] + 0.6099 0.6365 0.0923 0.0795 0.4404 + 0.3270 0.9202 0.6142 0.8548 0.8239 +[torch.DoubleTensor of size 2x5] > nn.Narrow(1, 2, 2):forward(x) - 0.9222 0.1877 0.3388 0.6265 0.5659 - 0.8785 0.7394 0.8265 0.9212 0.0129 + 0.6099 0.6365 0.0923 0.0795 0.4404 + 0.3270 0.9202 0.6142 0.8548 0.8239 [torch.DoubleTensor of size 2x5] > nn.Narrow(1, 2, -2):forward(x) - 0.9222 0.1877 0.3388 0.6265 0.5659 - 0.8785 0.7394 0.8265 0.9212 0.0129 -[torch.DoubleTensor of size 2x5] - -> nn.Narrow(2, 2, 3):forward(x) - 0.2017 0.4485 0.4638 - 0.1877 0.3388 0.6265 - 0.7394 0.8265 0.9212 - 0.7971 0.2113 0.1097 + 0.6099 0.6365 0.0923 0.0795 0.4404 +[torch.DoubleTensor of size 1x5] + +> nn.Narrow(2,2,3):forward(x) + 0.8704 0.6839 0.9137 + 0.6365 0.0923 0.0795 + 0.9202 0.6142 0.8548 + 0.6300 0.8553 0.7736 [torch.DoubleTensor of size 4x3] > nn.Narrow(2, 2, -2):forward(x) - 0.2017 0.4485 0.4638 - 0.1877 0.3388 0.6265 - 0.7394 0.8265 0.9212 - 0.7971 0.2113 0.1097 -[torch.DoubleTensor of size 4x3] + 0.8704 0.6839 + 0.6365 0.0923 + 0.9202 0.6142 + 0.6300 0.8553 +[torch.DoubleTensor of size 4x2] ``` <a name="nn.Replicate"></a> @@ -5854,9 +5854,10 @@ function nntest.Narrow() local module1 = nn.Narrow(1, 3, 5) local output1 = module1:forward(input) local gradInput1 = module1:backward(input, gradOutput) - local module2 = nn.Narrow(1, 3, -3) + local module2 = nn.Narrow(1, 3, -2) local output2 = module2:forward(input) local gradInput2 = module2:backward(input, gradOutput) + mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #1 output err") mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #1 gradInput err") mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #1 negative output err") @@ -5871,7 +5872,7 @@ function nntest.Narrow() local module1 = nn.Narrow(2, 5, 3) local output1 = module1:forward(input) local gradInput1 = module1:backward(input, gradOutput) - local module2 = nn.Narrow(2, 5, -4) + local module2 = nn.Narrow(2, 5, -3) local output2 = module2:forward(input) local gradInput2 = module2:backward(input, gradOutput) mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #2 output err") @@ -5888,7 +5889,7 @@ function nntest.Narrow() local module1 = nn.Narrow(3, 1, 1) local output1 = module1:forward(input) local gradInput1 = module1:backward(input, gradOutput) - local module2 = nn.Narrow(3, 1, -7) + local module2 = nn.Narrow(3, 1, -6) local output2 = module2:forward(input) local gradInput2 = module2:backward(input, gradOutput) mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #3 output err") @@ -5896,22 +5897,28 @@ function nntest.Narrow() mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #3 negative output err") mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #3 negative gradInput err") - -- check basic narrow functionality #4 - local input = torch.rand(3, 10, 4) - local output = input:narrow(2, 5, 3) - local gradOutput = torch.rand(3, 3, 4) - local gradInput = torch.zeros(3, 10, 4) - gradInput:narrow(2, 5, 3):copy(gradOutput) - local module1 = nn.Narrow(-2, 5, 3) - local output1 = module1:forward(input) - local gradInput1 = module1:backward(input, gradOutput) - local module2 = nn.Narrow(-2, 5, -4) - local output2 = module2:forward(input) - local gradInput2 = module2:backward(input, gradOutput) - mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #4 output err") - mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #4 gradInput err") - mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #4 negative output err") - mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #4 negative gradInput err") + -- check narrow negative offset and negative length + local input = torch.Tensor({1, 2, 3, 4, 5}) + + local output = torch.Tensor({2}) + local modelOutput = nn.Narrow(1, 2, 1):forward(input) + mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.1 output err") + + local output = torch.Tensor({4}) + local modelOutput = nn.Narrow(1, -2, 1):forward(input) + mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.2 output err") + + local output = torch.Tensor({5}) + local modelOutput = nn.Narrow(1, -1, 1):forward(input) + mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.3 output err") + + local output = torch.Tensor({4}) + local modelOutput = nn.Narrow(1, -2, -1):forward(input) + mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.4 output err") + + local output = torch.Tensor({2, 3, 4}) + local modelOutput = nn.Narrow(1, 2, -1):forward(input) + mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.5 output err") -- check narrow negative offset local input = torch.rand(3, 10, 4) @@ -5919,7 +5926,7 @@ function nntest.Narrow() local gradOutput = torch.rand(3, 3, 4) local gradInput = torch.zeros(3, 10, 4) gradInput:narrow(2, 1, 3):copy(gradOutput) - local module1 = nn.Narrow(2, -1, 7) + local module1 = nn.Narrow(2, -10, 3) local output1 = module1:forward(input) local gradInput1 = module1:backward(input, gradOutput) local module2 = nn.Narrow(2, 1, 3) |