Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
path: root/test
diff options
context:
space:
mode:
author nicholas-leonard <nick@nikopia.org> 2014-09-23 04:52:29 +0400
committer Soumith Chintala <soumith@gmail.com> 2014-10-27 17:05:15 +0300
commit 8d7d03ebe72ae507ae716f292778123dc34e04b1 (patch)
tree b6f2516e6855793d7e06484f0a443948320d8551 /test
parent c721165632345794bb3f5faf6a3502d830c207b6 (diff)
DepthConcat
Diffstat (limited to 'test')
-rw-r--r-- test/test.lua | 36
1 file changed, 32 insertions, 4 deletions
diff --git a/test/test.lua b/test/test.lua
index 593e659..b035928 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1,8 +1,6 @@
-require 'torch'
-
-- you can easily test specific units like this:
--- luajit -lnn -e "nn.test{'LookupTable'}"
--- luajit -lnn -e "nn.test{'LookupTable', 'Add'}"
+-- th -lnn -e "nn.test{'LookupTable'}"
+-- th -lnn -e "nn.test{'LookupTable', 'Add'}"
local mytester = torch.Tester()
local jac
@@ -2311,6 +2309,36 @@ function nntest.L1Penalty()
-- during BPROP is not included in the FPROP output)
end
+function nntest.DepthConcat()
+ local outputSize = torch.IntTensor{5,6,7,8}
+ local input = torch.randn(2,3,12,12)
+ local gradOutput = torch.randn(2, outputSize:sum(), 12, 12)
+ local concat = nn.DepthConcat(2)
+ concat:add(nn.SpatialConvolutionMM(3, outputSize[1], 1, 1, 1, 1)) --> 2, 5, 12, 12
+ concat:add(nn.SpatialConvolutionMM(3, outputSize[2], 3, 3, 1, 1)) --> 2, 6, 10, 10
+ concat:add(nn.SpatialConvolutionMM(3, outputSize[3], 4, 4, 1, 1)) --> 2, 7, 9, 9
+ concat:add(nn.SpatialConvolutionMM(3, outputSize[4], 5, 5, 1, 1)) --> 2, 8, 8, 8
+ concat:zeroGradParameters()
+ -- forward/backward
+ local outputConcat = concat:forward(input)
+ local gradInputConcat = concat:backward(input, gradOutput)
+ -- the spatial dims are the largest, the nFilters is the sum
+ local output = torch.Tensor(2, outputSize:sum(), 12, 12):zero() -- zero for padding
+ local narrows = { {{},{1,5},{},{}}, {{},{6,11},{2,11},{2,11}}, {{},{12,18},{2,10},{2,10}}, {{},{19,26},{3,10},{3,10}} }
+ local gradInput = input:clone():zero()
+ local gradWeights = {}
+ for i=1,4 do
+ local conv = concat:get(i)
+ local gradWeight = conv.gradWeight:clone()
+ conv:zeroGradParameters()
+ output[narrows[i]]:copy(conv:forward(input))
+ gradInput:add(conv:backward(input, gradOutput[narrows[i]]))
+ mytester:assertTensorEq(gradWeight, conv.gradWeight, 0.000001, "Error in SpatialConcat:accGradParameters for conv "..i)
+ end
+ mytester:assertTensorEq(output, outputConcat, 0.000001, "Error in SpatialConcat:updateOutput")
+ mytester:assertTensorEq(gradInput, gradInputConcat, 0.000001, "Error in SpatialConcat:updateGradInput")
+end
+
mytester:add(nntest)
if not nn then