Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-27 09:35:33 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-27 09:35:33 +0400
commit958e3ba2cff6450b51646bbd46878e1effe170c7 (patch)
tree1b378cf62f95f07a39d2721e2e14330097627c86 /test
parentb08848cb42a08c238e1f04494e6bd5baed4cf690 (diff)
MultiSoftMax unit tests
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua25
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 33af2b7..bb36e14 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -509,6 +509,31 @@ function nnxtest.Balance()
local gradInput = bl:backward(input, gradOutput)
end
+function nnxtest.MultiSoftMax()
+ local inputSize = 7
+ local nSoftmax = 5
+ local batchSize = 3
+ local nBatch = 4
+
+ local input = torch.randn(batchSize, nSoftmax, inputSize)
+ local gradOutput = torch.randn(batchSize, nSoftmax, inputSize)
+ local msm = nn.MultiSoftMax()
+
+ local output = msm:forward(input)
+ local gradInput = msm:backward(input, gradOutput)
+ mytester:assert(output:isSameSizeAs(input))
+ mytester:assert(gradOutput:isSameSizeAs(gradInput))
+
+ local sm = nn.SoftMax()
+ local input2 = input:view(nBatch*nSoftMax, inputSize)
+ local output2 = sm:forward(input2)
+ local gradInput2 = sm:backward(input2, gradOutput:view(nBatch*nSoftMax, inputSize))
+
+ mytester:assertTensorEq(output, output2, 0.000001)
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001)
+end
+
+
function nnx.test(tests)
xlua.require('image',true)
mytester = torch.Tester()