Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2015-01-06 23:07:53 +0300
committernicholas-leonard <nick@nikopia.org>2015-01-06 23:25:47 +0300
commit3dcecf7c22745b13fd5ba85848423a346ec42ad2 (patch)
tree26a17b3482d624bc80734056556fbf4ae6fd4958 /test.lua
parent4e0a96d801060121521ccc46f7294aeb3b247965 (diff)
Parallel optimization. ParallelTable inherits Container. unit tests
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua34
1 files changed, 34 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 3cf6a58..eb2f3ea 100644
--- a/test.lua
+++ b/test.lua
@@ -2452,6 +2452,40 @@ function nntest.SpatialUpSamplingNearest()
end
end
+function nntest.Parallel()
+ local input = torch.randn(3, 4, 5)
+ local m = nn.Parallel(1,3)
+ m:add(nn.View(4,5,1))
+ m:add(nn.View(4,5,1))
+ m:add(nn.View(4,5,1))
+
+ local output = m:forward(input)
+ local output2 = input:transpose(1,3):transpose(1,2)
+ mytester:assertTensorEq(output2, output, 0.000001, 'Parallel forward err')
+
+ local gradInput = m:backward(input, output2)
+ mytester:assertTensorEq(gradInput, input, 0.000001, 'Parallel backward err')
+end
+
+function nntest.ParallelTable()
+ local input = torch.randn(3, 4, 5)
+ local p = nn.ParallelTable()
+ p:add(nn.View(4,5,1))
+ p:add(nn.View(4,5,1))
+ p:add(nn.View(4,5,1))
+ m = nn.Sequential()
+ m:add(nn.SplitTable(1))
+ m:add(p)
+ m:add(nn.JoinTable(3))
+
+ local output = m:forward(input)
+ local output2 = input:transpose(1,3):transpose(1,2)
+ mytester:assertTensorEq(output2, output, 0.000001, 'ParallelTable forward err')
+
+ local gradInput = m:backward(input, output2)
+ mytester:assertTensorEq(gradInput, input, 0.000001, 'ParallelTable backward err')
+end
+
function nntest.ConcatTable()
-- Test tensor input
local input = torch.rand(5, 5, 5)