diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-09-23 04:52:29 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-10-27 17:05:15 +0300 |
commit | 8d7d03ebe72ae507ae716f292778123dc34e04b1 (patch) | |
tree | b6f2516e6855793d7e06484f0a443948320d8551 /test | |
parent | c721165632345794bb3f5faf6a3502d830c207b6 (diff) |
DepthConcat
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/test/test.lua b/test/test.lua index 593e659..b035928 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1,8 +1,6 @@ -require 'torch' - -- you can easily test specific units like this: --- luajit -lnn -e "nn.test{'LookupTable'}" --- luajit -lnn -e "nn.test{'LookupTable', 'Add'}" +-- th -lnn -e "nn.test{'LookupTable'}" +-- th -lnn -e "nn.test{'LookupTable', 'Add'}" local mytester = torch.Tester() local jac @@ -2311,6 +2309,36 @@ function nntest.L1Penalty() -- during BPROP is not included in the FPROP output) end +function nntest.DepthConcat() + local outputSize = torch.IntTensor{5,6,7,8} + local input = torch.randn(2,3,12,12) + local gradOutput = torch.randn(2, outputSize:sum(), 12, 12) + local concat = nn.DepthConcat(2) + concat:add(nn.SpatialConvolutionMM(3, outputSize[1], 1, 1, 1, 1)) --> 2, 5, 12, 12 + concat:add(nn.SpatialConvolutionMM(3, outputSize[2], 3, 3, 1, 1)) --> 2, 6, 10, 10 + concat:add(nn.SpatialConvolutionMM(3, outputSize[3], 4, 4, 1, 1)) --> 2, 7, 9, 9 + concat:add(nn.SpatialConvolutionMM(3, outputSize[4], 5, 5, 1, 1)) --> 2, 8, 8, 8 + concat:zeroGradParameters() + -- forward/backward + local outputConcat = concat:forward(input) + local gradInputConcat = concat:backward(input, gradOutput) + -- the spatial dims are the largest, the nFilters is the sum + local output = torch.Tensor(2, outputSize:sum(), 12, 12):zero() -- zero for padding + local narrows = { {{},{1,5},{},{}}, {{},{6,11},{2,11},{2,11}}, {{},{12,18},{2,10},{2,10}}, {{},{19,26},{3,10},{3,10}} } + local gradInput = input:clone():zero() + local gradWeights = {} + for i=1,4 do + local conv = concat:get(i) + local gradWeight = conv.gradWeight:clone() + conv:zeroGradParameters() + output[narrows[i]]:copy(conv:forward(input)) + gradInput:add(conv:backward(input, gradOutput[narrows[i]])) + mytester:assertTensorEq(gradWeight, conv.gradWeight, 0.000001, "Error in SpatialConcat:accGradParameters for conv "..i) + end + mytester:assertTensorEq(output, outputConcat, 0.000001, "Error in SpatialConcat:updateOutput") + mytester:assertTensorEq(gradInput, gradInputConcat, 0.000001, "Error in SpatialConcat:updateGradInput") +end + mytester:add(nntest) if not nn then |