diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-11 20:46:31 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-11 20:46:31 +0400 |
commit | 982c68150363e938a6361780abdbcfb35b816dc2 (patch) | |
tree | d8713eb97c1c328193f0ac73c0ec65114bab8909 /test | |
parent | bc259bf37bdf097af38b01eaed934105defe84ee (diff) |
MixtureTable unit tests
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 1eb571e..8eab9a1 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1954,6 +1954,27 @@ function nntest.SelectTable() equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) end +function nntest.MixtureTable() + local expertInput = torch.randn(5,3,6) + local input = { + torch.rand(5,3), + {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)} + } + local gradOutput = torch.randn(5,6) + local module = nn.MixtureTable() + local output = module:forward(input) + local output2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), expertInput):sum(2) + mytester:assertTensorEq(output, output2, 0.000001, "mixture output") + local gradInput = module:backward(input, gradOutput) + local gradOutput2 = torch.view(gradOutput, 5, 1, 6):expandAs(expertInput) + local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(3):select(3,1) + mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture gater gradInput") + local expertGradInput2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), gradOutput:view(5,1,6):expand(5,3,6)) + for i, expertGradInput in ipairs(gradInput[2]) do + mytester:assertTensorEq(expertGradInput, expertGradInput2:select(2,i), 0.000001, "mixture expert "..i.." gradInput") + end +end + function nntest.View() local input = torch.rand(10) local template = torch.rand(5,2) |