diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-16 01:49:22 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-16 01:49:22 +0400 |
commit | e074440aaf6f02491c9685d4e254ecb159622414 (patch) | |
tree | 7cd9ce5eb020094f38610118883de4418162175b /test | |
parent | 7f0f53416f019bb980c9e922b72048ab5ea4fd77 (diff) |
2D gater 1D experts unit test
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index b3e7ca9..8c17e90 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2062,6 +2062,27 @@ function nntest.MixtureTable() local gradInput = module:backward(input2, gradOutput:float()) mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture6B gater gradInput") mytester:assertTensorEq(gradInput[2], expertGradInput2:float(), 0.000001, "mixture6B expert gradInput") + + --[[ 2D gater, 1D expert]]-- + -- expertInput is a Table: + local expertInput = torch.randn(5,3) + local gradOutput = torch.randn(5) + local input = { + torch.rand(5,3), + {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)} + } + local module = nn.MixtureTable() + local output = module:forward(input) + local output2 = torch.cmul(input[1], expertInput):sum(2) + mytester:assertTensorEq(output, output2, 0.000001, "mixture7 output") + local gradInput = module:backward(input, gradOutput) + local gradOutput2 = torch.view(gradOutput, 5, 1):expandAs(expertInput) + local gaterGradInput2 = torch.cmul(gradOutput2, expertInput) + mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture7 gater gradInput") + local expertGradInput2 = torch.cmul(input[1], gradOutput:view(5,1):expand(5,3)) + for i, expertGradInput in ipairs(gradInput[2]) do + mytester:assertTensorEq(expertGradInput, expertGradInput2:select(2,i), 0.000001, "mixture7 expert "..i.." gradInput") + end end function nntest.View() |