diff options
author | Simon Niklaus <CodeRect@users.noreply.github.com> | 2016-06-05 00:48:16 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-06-05 00:48:16 +0300 |
commit | 5a4e825131c2011e961f80ceeaa889609bb32b16 (patch) | |
tree | b12b8896b5f0a4c52f410a45816f019dfe757639 | |
parent | 009549385527de79a3afabec2b49ac73fe019d77 (diff) |
extended documentation of / added a test case for Narrow (#843)
extended documentation and test of Narrow for negative indices
-rw-r--r-- | doc/simple.md | 49 | ||||
-rw-r--r-- | test.lua | 53 |
2 files changed, 101 insertions, 1 deletions
diff --git a/doc/simple.md b/doc/simple.md index e29813c..50e5c9f 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -607,7 +607,54 @@ When `dontCast` is true, a call to `nn.Copy:type(type)` will not cast the module module = nn.Narrow(dimension, offset, length) ``` -Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a module. +Narrow is application of [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation in a module. The module further supports a negative `length` in order to handle inputs with an unknown size. + +```lua +> x = torch.rand(4, 5) + +> x + 0.3695 0.2017 0.4485 0.4638 0.0513 + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 + 0.2290 0.7971 0.2113 0.1097 0.3166 +[torch.DoubleTensor of size 4x5] + +> nn.Narrow(1, 2, 3):forward(x) + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 + 0.2290 0.7971 0.2113 0.1097 0.3166 +[torch.DoubleTensor of size 3x5] + +> nn.Narrow(1, 2, -1):forward(x) + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 + 0.2290 0.7971 0.2113 0.1097 0.3166 +[torch.DoubleTensor of size 3x5] + +> nn.Narrow(1, 2, 2):forward(x) + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 +[torch.DoubleTensor of size 2x5] + +> nn.Narrow(1, 2, -2):forward(x) + 0.9222 0.1877 0.3388 0.6265 0.5659 + 0.8785 0.7394 0.8265 0.9212 0.0129 +[torch.DoubleTensor of size 2x5] + +> nn.Narrow(2, 2, 3):forward(x) + 0.2017 0.4485 0.4638 + 0.1877 0.3388 0.6265 + 0.7394 0.8265 0.9212 + 0.7971 0.2113 0.1097 +[torch.DoubleTensor of size 4x3] + +> nn.Narrow(2, 2, -2):forward(x) + 0.2017 0.4485 0.4638 + 0.1877 0.3388 0.6265 + 0.7394 0.8265 0.9212 + 0.7971 0.2113 0.1097 +[torch.DoubleTensor of size 4x3] +``` <a name="nn.Replicate"></a> ## Replicate ## @@ -4967,6 +4967,59 @@ function nntest.MixtureTable() end end +function nntest.Narrow() + -- check basic narrow functionality #1 + local input = torch.rand(9, 4, 14) + local output = input:narrow(1, 3, 5) + local gradOutput = torch.rand(5, 4, 14) + local gradInput = torch.zeros(9, 4, 14) + gradInput:narrow(1, 3, 5):copy(gradOutput) + local module1 = nn.Narrow(1, 3, 5) + local output1 = module1:forward(input) + local gradInput1 = module1:backward(input, gradOutput) + local module2 = nn.Narrow(1, 3, -3) + local output2 = module2:forward(input) + local gradInput2 = module2:backward(input, gradOutput) + mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #1 output err") + mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #1 gradInput err") + mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #1 negative output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #1 negative gradInput err") + + -- check basic narrow functionality #2 + local input = torch.rand(3, 10, 4) + local output = input:narrow(2, 5, 3) + local gradOutput = torch.rand(3, 3, 4) + local gradInput = torch.zeros(3, 10, 4) + gradInput:narrow(2, 5, 3):copy(gradOutput) + local module1 = nn.Narrow(2, 5, 3) + local output1 = module1:forward(input) + local gradInput1 = module1:backward(input, gradOutput) + local module2 = nn.Narrow(2, 5, -4) + local output2 = module2:forward(input) + local gradInput2 = module2:backward(input, gradOutput) + mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #2 output err") + mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #2 gradInput err") + mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #2 negative output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #2 negative gradInput err") + + -- check basic narrow functionality #3 + local input = torch.rand(6, 11, 7) + local output = input:narrow(3, 1, 1) + local gradOutput = torch.rand(6, 11, 1) + local gradInput = torch.zeros(6, 11, 7) + gradInput:narrow(3, 1, 1):copy(gradOutput) + local module1 = nn.Narrow(3, 1, 1) + local output1 = module1:forward(input) + local gradInput1 = module1:backward(input, gradOutput) + local module2 = nn.Narrow(3, 1, -7) + local output2 = module2:forward(input) + local gradInput2 = module2:backward(input, gradOutput) + mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #3 output err") + mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #3 gradInput err") + mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #3 negative output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #3 negative gradInput err") +end + function nntest.NarrowTable() local input = torch.randn(3,10,4) local gradOutput = torch.randn(3,3,4) |