Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-05-25 04:41:28 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-05-25 04:41:28 +0300
commit6714cebc861db18ede18e3d9d56e05340669998c (patch)
treeb69b6d90025f719df32948972841c55af437507f /test.lua
parenteb6548a0c30db70465de4779d866bfac781ec0b1 (diff)
nn.Convert
Diffstat (limited to 'test.lua')
-rwxr-xr-xtest.lua37
1 files changed, 37 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index e776f26..1b1b2c1 100755
--- a/test.lua
+++ b/test.lua
@@ -8612,6 +8612,43 @@ function nntest.Collapse()
mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous")
end
+function nntest.Convert()
+ -- batch mode
+ local c = nn.Convert('bchw', 'chwb')
+ local input = torch.randn(8,3,5,5)
+ local output = c:forward(input)
+ local output2 = input:transpose(1,4):transpose(1,3):transpose(1,2)
+ mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd bchw->chwb")
+ local gradInput = c:backward(input, output)
+ mytester:assertTensorEq(gradInput, input, 0.000001, "Convert bwd bchw->chwb")
+ local c = nn.Convert('bchw', 'bf')
+ local output = c:forward(input)
+ local output2 = input:view(8,-1)
+ mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd bchw->bf")
+ c:float()
+ local output = c:forward(input:float())
+ mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type()")
+ local output = c:forward(input)
+ mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float")
+ -- non-batch mode
+ local c = nn.Convert('chw', 'hwc')
+ local input = torch.randn(3,5,5)
+ local output = c:forward(input)
+ local output2 = input:transpose(1,3):transpose(1,2)
+ mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd chw->hwc non-batch")
+ local gradInput = c:backward(input, output)
+ mytester:assertTensorEq(gradInput, input, 0.000001, "Convert bwd chw->hwc non-batch")
+ local c = nn.Convert('chw', 'f')
+ local output = c:forward(input)
+ local output2 = input:view(-1)
+ mytester:assertTensorEq(output, output2, 0.000001, "Convert fwd chw->bf non-batch")
+ c:float()
+ local output = c:forward(input:float())
+ mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() non-batch")
+ local output = c:forward(input)
+ mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float non-batch")
+end
+
mytester:add(nntest)