diff options
Diffstat (limited to 'test.lua')
-rwxr-xr-x | test.lua | 16 |
1 files changed, 16 insertions, 0 deletions
@@ -8596,6 +8596,22 @@ function nntest.ZipTableOneToMany() mytester:assertTensorEq(torch.mul(input[1], 3), gradInput[1], 0.000001, "ZipTableOneToMany gradInput21") end +function nntest.Collapse() + local c = nn.Collapse(3) + local input = torch.randn(8,3,4,5) + local output = c:forward(input) + mytester:assertTensorEq(input:view(8,-1), output, 0.000001, "Collapse:forward") + local gradInput = c:backward(input, output) + mytester:assertTensorEq(gradInput, input, 0.000001, "Collapse:backward") + mytester:assertTableEq(gradInput:size():totable(), input:size():totable(), 0.000001, "Collapse:backward size") + local input2 = input:transpose(1,4) + local output2 = c:forward(input2) + mytester:assertTensorEq(input2:contiguous():view(5,-1), output2, 0.000001, "Collapse:forward non-contiguous") + local gradInput2 = c:backward(input2, output2) + mytester:assertTensorEq(gradInput2, input2, 0.000001, "Collapse:backward non-contiguous") + mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous") +end + mytester:add(nntest) |