diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-01-06 23:07:53 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-01-06 23:25:47 +0300 |
commit | 3dcecf7c22745b13fd5ba85848423a346ec42ad2 (patch) | |
tree | 26a17b3482d624bc80734056556fbf4ae6fd4958 /test.lua | |
parent | 4e0a96d801060121521ccc46f7294aeb3b247965 (diff) |
Parallel optimization. ParallelTable inherits Container. unit tests
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 34 |
1 files changed, 34 insertions, 0 deletions
@@ -2452,6 +2452,40 @@ function nntest.SpatialUpSamplingNearest() end end +function nntest.Parallel() + local input = torch.randn(3, 4, 5) + local m = nn.Parallel(1,3) + m:add(nn.View(4,5,1)) + m:add(nn.View(4,5,1)) + m:add(nn.View(4,5,1)) + + local output = m:forward(input) + local output2 = input:transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output2, output, 0.000001, 'Parallel forward err') + + local gradInput = m:backward(input, output2) + mytester:assertTensorEq(gradInput, input, 0.000001, 'Parallel backward err') +end + +function nntest.ParallelTable() + local input = torch.randn(3, 4, 5) + local p = nn.ParallelTable() + p:add(nn.View(4,5,1)) + p:add(nn.View(4,5,1)) + p:add(nn.View(4,5,1)) + m = nn.Sequential() + m:add(nn.SplitTable(1)) + m:add(p) + m:add(nn.JoinTable(3)) + + local output = m:forward(input) + local output2 = input:transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output2, output, 0.000001, 'ParallelTable forward err') + + local gradInput = m:backward(input, output2) + mytester:assertTensorEq(gradInput, input, 0.000001, 'ParallelTable backward err') +end + function nntest.ConcatTable() -- Test tensor input local input = torch.rand(5, 5, 5) |