diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 04:41:28 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 04:41:28 +0300 |
commit | 6714cebc861db18ede18e3d9d56e05340669998c (patch) | |
tree | b69b6d90025f719df32948972841c55af437507f /test.lua | |
parent | eb6548a0c30db70465de4779d866bfac781ec0b1 (diff) |
nn.Convert
Diffstat (limited to 'test.lua')
-rwxr-xr-x | test.lua | 37 |
1 files changed, 37 insertions, 0 deletions
@@ -8612,6 +8612,43 @@ function nntest.Collapse() mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous") end +function nntest.Convert() + -- batch mode + local c = nn.Convert('bchw', 'chwb') + local input = torch.randn(8,3,5,5) + local output = c:forward(input) + local output2 = input:transpose(1,4):transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd bchw->chwb") + local gradInput = c:backward(input, output) + mytester:assertTensorEq(gradInput, input, 0.000001, "Convert bwd bchw->chwb") + local c = nn.Convert('bchw', 'bf') + local output = c:forward(input) + local output2 = input:view(8,-1) + mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd bchw->bf") + c:float() + local output = c:forward(input:float()) + mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type()") + local output = c:forward(input) + mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float") + -- non-batch mode + local c = nn.Convert('chw', 'hwc') + local input = torch.randn(3,5,5) + local output = c:forward(input) + local output2 = input:transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd chw->hwc non-batch") + local gradInput = c:backward(input, output) + mytester:assertTensorEq(gradInput, input, 0.000001, "Convert bwd chw->hwc non-batch") + local c = nn.Convert('chw', 'f') + local output = c:forward(input) + local output2 = input:view(-1) + mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd chw->bf non-batch") + c:float() + local output = c:forward(input:float()) + mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() non-batch") + local output = c:forward(input) + mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float non-batch") +end + mytester:add(nntest) |