Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua17
1 files changed, 17 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 774fba1..5cb7066 100644
--- a/test.lua
+++ b/test.lua
@@ -5674,6 +5674,23 @@ function nntest.Narrow()
mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #4 gradInput err")
mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #4 negative output err")
mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #4 negative gradInput err")
+
+ -- check narrow negative offset
+ local input = torch.rand(3, 10, 4)
+ local output = input:narrow(2, 1, 3)
+ local gradOutput = torch.rand(3, 3, 4)
+ local gradInput = torch.zeros(3, 10, 4)
+ gradInput:narrow(2, 1, 3):copy(gradOutput)
+ local module1 = nn.Narrow(2, -1, 7)
+ local output1 = module1:forward(input)
+ local gradInput1 = module1:backward(input, gradOutput)
+ local module2 = nn.Narrow(2, 1, 3)
+ local output2 = module2:forward(input)
+ local gradInput2 = module2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #5 output err")
+ mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #5 gradInput err")
+ mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #5 negative output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #5 negative gradInput err")
end
function nntest.NarrowTable()