diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-01-07 09:28:51 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-01-07 09:28:51 +0300 |
commit | 675507d9a1ca9c8b854a45e388499bbffc0cda61 (patch) | |
tree | ade128400a5a753cc7086afc4fd9ad2e35888f87 /test.lua | |
parent | 81d2c4215451b350404364dfc19ef5250fe6155b (diff) | |
parent | 5b198168ebaa330e0530fe67f4e08f0b8c1114ba (diff) |
Merge pull request #135 from nicholas-leonard/parallel
Parallel, Container & cie
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 34 |
1 files changed, 34 insertions, 0 deletions
@@ -2462,6 +2462,40 @@ function nntest.SpatialUpSamplingNearest() end end +function nntest.Parallel() + local input = torch.randn(3, 4, 5) + local m = nn.Parallel(1,3) + m:add(nn.View(4,5,1)) + m:add(nn.View(4,5,1)) + m:add(nn.View(4,5,1)) + + local output = m:forward(input) + local output2 = input:transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output2, output, 0.000001, 'Parallel forward err') + + local gradInput = m:backward(input, output2) + mytester:assertTensorEq(gradInput, input, 0.000001, 'Parallel backward err') +end + +function nntest.ParallelTable() + local input = torch.randn(3, 4, 5) + local p = nn.ParallelTable() + p:add(nn.View(4,5,1)) + p:add(nn.View(4,5,1)) + p:add(nn.View(4,5,1)) + m = nn.Sequential() + m:add(nn.SplitTable(1)) + m:add(p) + m:add(nn.JoinTable(3)) + + local output = m:forward(input) + local output2 = input:transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output2, output, 0.000001, 'ParallelTable forward err') + + local gradInput = m:backward(input, output2) + mytester:assertTensorEq(gradInput, input, 0.000001, 'ParallelTable backward err') +end + function nntest.ConcatTable() -- Test tensor input local input = torch.rand(5, 5, 5) |