diff options
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 17 |
1 files changed, 17 insertions, 0 deletions
@@ -5674,6 +5674,23 @@ function nntest.Narrow() mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #4 gradInput err") mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #4 negative output err") mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #4 negative gradInput err") + + -- check narrow negative offset + local input = torch.rand(3, 10, 4) + local output = input:narrow(2, 1, 3) + local gradOutput = torch.rand(3, 3, 4) + local gradInput = torch.zeros(3, 10, 4) + gradInput:narrow(2, 1, 3):copy(gradOutput) + local module1 = nn.Narrow(2, -1, 7) + local output1 = module1:forward(input) + local gradInput1 = module1:backward(input, gradOutput) + local module2 = nn.Narrow(2, 1, 3) + local output2 = module2:forward(input) + local gradInput2 = module2:backward(input, gradOutput) + mytester:assertTensorEq(output, output1, 0.0000001, "Narrow #5 output err") + mytester:assertTensorEq(gradInput, gradInput1, 0.00001, "Narrow #5 gradInput err") + mytester:assertTensorEq(output, output2, 0.0000001, "Narrow #5 negative output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "Narrow #5 negative gradInput err") end function nntest.NarrowTable() |