Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test.lua')
-rwxr-xr-xtest.lua16
1 files changed, 16 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 16bae09..e776f26 100755
--- a/test.lua
+++ b/test.lua
@@ -8596,6 +8596,22 @@ function nntest.ZipTableOneToMany()
mytester:assertTensorEq(torch.mul(input[1], 3), gradInput[1], 0.000001, "ZipTableOneToMany gradInput21")
end
+function nntest.Collapse()
+ local c = nn.Collapse(3)
+ local input = torch.randn(8,3,4,5)
+ local output = c:forward(input)
+ mytester:assertTensorEq(input:view(8,-1), output, 0.000001, "Collapse:forward")
+ local gradInput = c:backward(input, output)
+ mytester:assertTensorEq(gradInput, input, 0.000001, "Collapse:backward")
+ mytester:assertTableEq(gradInput:size():totable(), input:size():totable(), 0.000001, "Collapse:backward size")
+ local input2 = input:transpose(1,4)
+ local output2 = c:forward(input2)
+ mytester:assertTensorEq(input2:contiguous():view(5,-1), output2, 0.000001, "Collapse:forward non-contiguous")
+ local gradInput2 = c:backward(input2, output2)
+ mytester:assertTensorEq(gradInput2, input2, 0.000001, "Collapse:backward non-contiguous")
+ mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous")
+end
+
mytester:add(nntest)