Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-11 20:46:31 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-11 20:46:31 +0400
commit982c68150363e938a6361780abdbcfb35b816dc2 (patch)
treed8713eb97c1c328193f0ac73c0ec65114bab8909 /test
parentbc259bf37bdf097af38b01eaed934105defe84ee (diff)
MixtureTable unit tests
Diffstat (limited to 'test')
-rw-r--r--test/test.lua21
1 files changed, 21 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 1eb571e..8eab9a1 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1954,6 +1954,27 @@ function nntest.SelectTable()
equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
end
+function nntest.MixtureTable()
+ local expertInput = torch.randn(5,3,6)
+ local input = {
+ torch.rand(5,3),
+ {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)}
+ }
+ local gradOutput = torch.randn(5,6)
+ local module = nn.MixtureTable()
+ local output = module:forward(input)
+ local output2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), expertInput):sum(2)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradOutput2 = torch.view(gradOutput, 5, 1, 6):expandAs(expertInput)
+ local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(3):select(3,1)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture gater gradInput")
+ local expertGradInput2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), gradOutput:view(5,1,6):expand(5,3,6))
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(2,i), 0.000001, "mixture expert "..i.." gradInput")
+ end
+end
+
function nntest.View()
local input = torch.rand(10)
local template = torch.rand(5,2)