diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-01-27 18:19:04 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-27 18:19:04 +0300 |
commit | 5efb612f312ea8d3996251ab9f45472ea90485c8 (patch) | |
tree | 23c987c475a1ea08a5a207d66b122ff281193490 | |
parent | 479bf761ec38ae686d0a45c45744ca0b35839178 (diff) |
Revert "changed narrow to standardize negative length and negative offset beh…"revert-1113-narrow_edit
-rw-r--r-- | Narrow.lua | 28 | ||||
-rw-r--r-- | doc/simple.md | 52 | ||||
-rw-r--r-- | test.lua | 47 |
3 files changed, 59 insertions, 68 deletions
@@ -13,16 +13,14 @@ end function Narrow:updateOutput(input) local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension local length = self.length - local index = self.index - - if index < 0 then - index = input:size(dim) + self.index + 1 - end - if length < 0 then - length = input:size(dim) - index + 1 - torch.abs(length) + length = input:size(dim) - self.index + self.length + 2 + end + local index = self.index + if self.index < 0 then + index = 1 + length = input:size(dim) - length end - local output=input:narrow(dim, index, length) self.output = self.output:typeAs(output) self.output:resizeAs(output):copy(output) @@ -32,16 +30,14 @@ end function Narrow:updateGradInput(input, gradOutput) local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension local length = self.length - local index = self.index - - if index < 0 then - index = input:size(dim) + self.index + 1 - end - if length < 0 then - length = input:size(dim) - index + 1 - torch.abs(length) + length = input:size(dim) - self.index + self.length + 2 + end + local index = self.index + if self.index < 0 then + index = 1 + length = input:size(dim) - length end - self.gradInput = self.gradInput:typeAs(input) self.gradInput:resizeAs(input):zero() self.gradInput:narrow(dim,index,length):copy(gradOutput) diff --git a/doc/simple.md b/doc/simple.md index 6b0e4d5..09c60ca 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -685,45 +685,47 @@ Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/do > x = torch.rand(4, 5) > x - 0.2746 0.8704 0.6839 0.9137 0.5994 - 0.6099 0.6365 0.0923 0.0795 0.4404 - 0.3270 0.9202 0.6142 0.8548 0.8239 - 0.7058 0.6300 0.8553 0.7736 0.3567 + 0.3695 0.2017 0.4485 0.4638 0.0513 + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 + 0.2290 0.7971 0.2113 0.1097 0.3166 [torch.DoubleTensor of size 4x5] > nn.Narrow(1, 2, 3):forward(x) - 0.6099 0.6365 0.0923 0.0795 0.4404 - 0.3270 0.9202 0.6142 0.8548 0.8239 - 0.7058 0.6300 0.8553 0.7736 0.3567 + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 + 0.2290 0.7971 0.2113 0.1097 0.3166 [torch.DoubleTensor of size 3x5] > nn.Narrow(1, 2, -1):forward(x) - 0.6099 0.6365 0.0923 0.0795 0.4404 - 0.3270 0.9202 0.6142 0.8548 0.8239 -[torch.DoubleTensor of size 2x5] + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 + 0.2290 0.7971 0.2113 0.1097 0.3166 +[torch.DoubleTensor of size 3x5] > nn.Narrow(1, 2, 2):forward(x) - 0.6099 0.6365 0.0923 0.0795 0.4404 - 0.3270 0.9202 0.6142 0.8548 0.8239 + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 [torch.DoubleTensor of size 2x5] > nn.Narrow(1, 2, -2):forward(x) - 0.6099 0.6365 0.0923 0.0795 0.4404 -[torch.DoubleTensor of size 1x5] - -> nn.Narrow(2,2,3):forward(x) - 0.8704 0.6839 0.9137 - 0.6365 0.0923 0.0795 - 0.9202 0.6142 0.8548 - 0.6300 0.8553 0.7736 + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 +[torch.DoubleTensor of size 2x5] + +> nn.Narrow(2, 2, 3):forward(x) + 0.2017 0.4485 0.4638 + 0.1877 0.3388 0.6265 + 0.7394 0.8265 0.9212 + 0.7971 0.2113 0.1097 [torch.DoubleTensor of size 4x3] > nn.Narrow(2, 2, -2):forward(x) - 0.8704 0.6839 - 0.6365 0.0923 - 0.9202 0.6142 - 0.6300 0.8553 -[torch.DoubleTensor of size 4x2] + 0.2017 0.4485 0.4638 + 0.1877 0.3388 0.6265 + 0.7394 0.8265 0.9212 + 0.7971 0.2113 0.1097 +[torch.DoubleTensor of size 4x3] ``` <a name="nn.Replicate"></a> @@ -5854,10 +5854,9 @@ function nntest.Narrow() local module1 = nn.Narrow(1, 3, 5) local output1 = module1:forward(input) local gradInput1 = module1:backward(input, gradOutput) - local module2 = nn.Narrow(1, 3, -2) + local module2 = nn.Narrow(1, 3, -3) local output2 = module2:forward(input) local gradInput2 = module2:backward(input, gradOutput) - mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #1 output err") mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #1 gradInput err") mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #1 negative output err") @@ -5872,7 +5871,7 @@ function nntest.Narrow() local module1 = nn.Narrow(2, 5, 3) local output1 = module1:forward(input) local gradInput1 = module1:backward(input, gradOutput) - local module2 = nn.Narrow(2, 5, -3) + local module2 = nn.Narrow(2, 5, -4) local output2 = module2:forward(input) local gradInput2 = module2:backward(input, gradOutput) mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #2 output err") @@ -5889,7 +5888,7 @@ function nntest.Narrow() local module1 = nn.Narrow(3, 1, 1) local output1 = module1:forward(input) local gradInput1 = module1:backward(input, gradOutput) - local module2 = nn.Narrow(3, 1, -6) + local module2 = nn.Narrow(3, 1, -7) local output2 = module2:forward(input) local gradInput2 = module2:backward(input, gradOutput) mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #3 output err") @@ -5897,28 +5896,22 @@ function nntest.Narrow() mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #3 negative output err") mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #3 negative gradInput err") - -- check narrow negative offset and negative length - local input = torch.Tensor({1, 2, 3, 4, 5}) - - local output = torch.Tensor({2}) - local modelOutput = nn.Narrow(1, 2, 1):forward(input) - mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.1 output err") - - local output = torch.Tensor({4}) - local modelOutput = nn.Narrow(1, -2, 1):forward(input) - mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.2 output err") - - local output = torch.Tensor({5}) - local modelOutput = nn.Narrow(1, -1, 1):forward(input) - mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.3 output err") - - local output = torch.Tensor({4}) - local modelOutput = nn.Narrow(1, -2, -1):forward(input) - mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.4 output err") - - local output = torch.Tensor({2, 3, 4}) - local modelOutput = nn.Narrow(1, 2, -1):forward(input) - mytester:assertTensorEq(output, modelOutput, 0.0000001, "Narrow #4.5 output err") + -- check basic narrow functionality #4 + local input = torch.rand(3, 10, 4) + local output = input:narrow(2, 5, 3) + local gradOutput = torch.rand(3, 3, 4) + local gradInput = torch.zeros(3, 10, 4) + gradInput:narrow(2, 5, 3):copy(gradOutput) + local module1 = nn.Narrow(-2, 5, 3) + local output1 = module1:forward(input) + local gradInput1 = module1:backward(input, gradOutput) + local module2 = nn.Narrow(-2, 5, -4) + local output2 = module2:forward(input) + local gradInput2 = module2:backward(input, gradOutput) + mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #4 output err") + mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #4 gradInput err") + mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #4 negative output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #4 negative gradInput err") -- check narrow negative offset local input = torch.rand(3, 10, 4) @@ -5926,7 +5919,7 @@ function nntest.Narrow() local gradOutput = torch.rand(3, 3, 4) local gradInput = torch.zeros(3, 10, 4) gradInput:narrow(2, 1, 3):copy(gradOutput) - local module1 = nn.Narrow(2, -10, 3) + local module1 = nn.Narrow(2, -1, 7) local output1 = module1:forward(input) local gradInput1 = module1:backward(input, gradOutput) local module2 = nn.Narrow(2, 1, 3) |