Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author Simon Niklaus <CodeRect@users.noreply.github.com> 2016-06-05 00:48:16 +0300
committer Soumith Chintala <soumith@gmail.com> 2016-06-05 00:48:16 +0300
commit 5a4e825131c2011e961f80ceeaa889609bb32b16 (patch)
tree b12b8896b5f0a4c52f410a45816f019dfe757639
parent 009549385527de79a3afabec2b49ac73fe019d77 (diff)
extended documentation of / added a test case for Narrow (#843)
extended documentation and test of Narrow for negative indices
-rw-r--r-- doc/simple.md 49
-rw-r--r-- test.lua 53
2 files changed, 101 insertions, 1 deletion
diff --git a/doc/simple.md b/doc/simple.md
index e29813c..50e5c9f 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -607,7 +607,54 @@ When `dontCast` is true, a call to `nn.Copy:type(type)` will not cast the module
module = nn.Narrow(dimension, offset, length)
```
-Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a module.
+Narrow is an application of the [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a module. The module further supports a negative `length` in order to handle inputs with an unknown size.
+
+```lua
+> x = torch.rand(4, 5)
+
+> x
+ 0.3695 0.2017 0.4485 0.4638 0.0513
+ 0.9222 0.1877 0.3388 0.6265 0.5659
+ 0.8785 0.7394 0.8265 0.9212 0.0129
+ 0.2290 0.7971 0.2113 0.1097 0.3166
+[torch.DoubleTensor of size 4x5]
+
+> nn.Narrow(1, 2, 3):forward(x)
+ 0.9222 0.1877 0.3388 0.6265 0.5659
+ 0.8785 0.7394 0.8265 0.9212 0.0129
+ 0.2290 0.7971 0.2113 0.1097 0.3166
+[torch.DoubleTensor of size 3x5]
+
+> nn.Narrow(1, 2, -1):forward(x)
+ 0.9222 0.1877 0.3388 0.6265 0.5659
+ 0.8785 0.7394 0.8265 0.9212 0.0129
+ 0.2290 0.7971 0.2113 0.1097 0.3166
+[torch.DoubleTensor of size 3x5]
+
+> nn.Narrow(1, 2, 2):forward(x)
+ 0.9222 0.1877 0.3388 0.6265 0.5659
+ 0.8785 0.7394 0.8265 0.9212 0.0129
+[torch.DoubleTensor of size 2x5]
+
+> nn.Narrow(1, 2, -2):forward(x)
+ 0.9222 0.1877 0.3388 0.6265 0.5659
+ 0.8785 0.7394 0.8265 0.9212 0.0129
+[torch.DoubleTensor of size 2x5]
+
+> nn.Narrow(2, 2, 3):forward(x)
+ 0.2017 0.4485 0.4638
+ 0.1877 0.3388 0.6265
+ 0.7394 0.8265 0.9212
+ 0.7971 0.2113 0.1097
+[torch.DoubleTensor of size 4x3]
+
+> nn.Narrow(2, 2, -2):forward(x)
+ 0.2017 0.4485 0.4638
+ 0.1877 0.3388 0.6265
+ 0.7394 0.8265 0.9212
+ 0.7971 0.2113 0.1097
+[torch.DoubleTensor of size 4x3]
+```
<a name="nn.Replicate"></a>
## Replicate ##
diff --git a/test.lua b/test.lua
index 8c0fffd..8bf98ec 100644
--- a/test.lua
+++ b/test.lua
@@ -4967,6 +4967,59 @@ function nntest.MixtureTable()
end
end
+function nntest.Narrow()
+ -- check basic narrow functionality #1
+ local input = torch.rand(9, 4, 14)
+ local output = input:narrow(1, 3, 5)
+ local gradOutput = torch.rand(5, 4, 14)
+ local gradInput = torch.zeros(9, 4, 14)
+ gradInput:narrow(1, 3, 5):copy(gradOutput)
+ local module1 = nn.Narrow(1, 3, 5)
+ local output1 = module1:forward(input)
+ local gradInput1 = module1:backward(input, gradOutput)
+ local module2 = nn.Narrow(1, 3, -3)
+ local output2 = module2:forward(input)
+ local gradInput2 = module2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #1 output err")
+ mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #1 gradInput err")
+ mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #1 negative output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #1 negative gradInput err")
+
+ -- check basic narrow functionality #2
+ local input = torch.rand(3, 10, 4)
+ local output = input:narrow(2, 5, 3)
+ local gradOutput = torch.rand(3, 3, 4)
+ local gradInput = torch.zeros(3, 10, 4)
+ gradInput:narrow(2, 5, 3):copy(gradOutput)
+ local module1 = nn.Narrow(2, 5, 3)
+ local output1 = module1:forward(input)
+ local gradInput1 = module1:backward(input, gradOutput)
+ local module2 = nn.Narrow(2, 5, -4)
+ local output2 = module2:forward(input)
+ local gradInput2 = module2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #2 output err")
+ mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #2 gradInput err")
+ mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #2 negative output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #2 negative gradInput err")
+
+ -- check basic narrow functionality #3
+ local input = torch.rand(6, 11, 7)
+ local output = input:narrow(3, 1, 1)
+ local gradOutput = torch.rand(6, 11, 1)
+ local gradInput = torch.zeros(6, 11, 7)
+ gradInput:narrow(3, 1, 1):copy(gradOutput)
+ local module1 = nn.Narrow(3, 1, 1)
+ local output1 = module1:forward(input)
+ local gradInput1 = module1:backward(input, gradOutput)
+ local module2 = nn.Narrow(3, 1, -7)
+ local output2 = module2:forward(input)
+ local gradInput2 = module2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #3 output err")
+ mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #3 gradInput err")
+ mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #3 negative output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #3 negative gradInput err")
+end
+
function nntest.NarrowTable()
local input = torch.randn(3,10,4)
local gradOutput = torch.randn(3,3,4)