diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-11 22:13:15 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-11 22:13:15 +0400 |
commit | c40edcad5cb0bff5aeec53ebff6634887978225f (patch) | |
tree | 98c796396620128bc73ac45b158a2d8fde552e6e /test | |
parent | 2d56aa7e6dbc33c49664118437afef0c0882b6d7 (diff) |
MixtureTable unit tests (expertInput is a Tensor)
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/test/test.lua b/test/test.lua index 8eab9a1..802e553 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1956,11 +1956,12 @@ end function nntest.MixtureTable() local expertInput = torch.randn(5,3,6) + local gradOutput = torch.randn(5,6) + -- expertInput is a Table: local input = { torch.rand(5,3), {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)} } - local gradOutput = torch.randn(5,6) local module = nn.MixtureTable() local output = module:forward(input) local output2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), expertInput):sum(2) @@ -1973,6 +1974,18 @@ function nntest.MixtureTable() for i, expertGradInput in ipairs(gradInput[2]) do mytester:assertTensorEq(expertGradInput, expertGradInput2:select(2,i), 0.000001, "mixture expert "..i.." gradInput") end + -- expertInput is a Tensor: + local input = {input[1], expertInput} + local module = nn.MixtureTable(2) + local output = module:forward(input) + local output2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), expertInput):sum(2) + mytester:assertTensorEq(output, output2, 0.000001, "mixture2 output") + local gradInput = module:backward(input, gradOutput) + local gradOutput2 = torch.view(gradOutput, 5, 1, 6):expandAs(expertInput) + local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(3):select(3,1) + mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture2 gater gradInput") + local expertGradInput2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), gradOutput:view(5,1,6):expand(5,3,6)) + mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture2 expert gradInput") end function nntest.View() |