diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-27 09:35:33 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-27 09:35:33 +0400 |
commit | 958e3ba2cff6450b51646bbd46878e1effe170c7 (patch) | |
tree | 1b378cf62f95f07a39d2721e2e14330097627c86 /test | |
parent | b08848cb42a08c238e1f04494e6bd5baed4cf690 (diff) |
MultiSoftMax unit tests
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 25 |
1 file changed, 25 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 33af2b7..bb36e14 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -509,6 +509,31 @@ function nnxtest.Balance()
    local gradInput = bl:backward(input, gradOutput)
 end
 
+-- MultiSoftMax applies an independent softmax over the last dimension of
+-- each (batch, softmax) slice; it must agree with nn.SoftMax run on the
+-- input collapsed to (batchSize*nSoftmax) independent rows.
+function nnxtest.MultiSoftMax()
+   local inputSize = 7
+   local nSoftmax = 5
+   local batchSize = 3
+
+   local input = torch.randn(batchSize, nSoftmax, inputSize)
+   local gradOutput = torch.randn(batchSize, nSoftmax, inputSize)
+   local msm = nn.MultiSoftMax()
+
+   local output = msm:forward(input)
+   local gradInput = msm:backward(input, gradOutput)
+   mytester:assert(output:isSameSizeAs(input))
+   mytester:assert(gradOutput:isSameSizeAs(gradInput))
+
+   -- reference: plain SoftMax on the batch/softmax dims collapsed into rows
+   local sm = nn.SoftMax()
+   local input2 = input:view(batchSize*nSoftmax, inputSize)
+   local output2 = sm:forward(input2)
+   local gradInput2 = sm:backward(input2, gradOutput:view(batchSize*nSoftmax, inputSize))
+
+   -- reshape to 2-D so both operands of assertTensorEq have matching sizes
+   mytester:assertTensorEq(output:view(batchSize*nSoftmax, inputSize), output2, 0.000001)
+   mytester:assertTensorEq(gradInput:view(batchSize*nSoftmax, inputSize), gradInput2, 0.000001)
+end
+
 function nnx.test(tests)
    xlua.require('image',true)
    mytester = torch.Tester()