diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-14 03:42:14 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-14 03:42:14 +0400 |
commit | 920c2d78bcb5aeb20dfef4b9b20335a5d1bcbc6e (patch) | |
tree | 390c6058b138340aee71702a806858a920e9d7c5 /test | |
parent | ba82d19f47ed2e377594f72fad60498efda4b5c8 (diff) |
MixtureTable works with 1D inputs (unit tested)
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/test/test.lua b/test/test.lua index 00d78b1..39db431 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2010,9 +2010,37 @@ function nntest.MixtureTable() local output = module:forward(input) mytester:assertTensorEq(output, output2, 0.000001, "mixture4 output") local gradInput = module:backward(input, gradOutput) - print(gradInput[1], gaterGradInput2) mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture4 gater gradInput") mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture4 expert gradInput") + + --[[ 1D ]]-- + -- expertInput is a Table: + local expertInput = torch.randn(3,6) + local gradOutput = torch.randn(6) + local input = { + torch.rand(3), + {expertInput:select(1,1), expertInput:select(1,2), expertInput:select(1,3)} + } + local module = nn.MixtureTable() + local output = module:forward(input) + local output2 = torch.cmul(input[1]:view(3,1):expand(3,6), expertInput):sum(1) + mytester:assertTensorEq(output, output2, 0.000001, "mixture5 output") + local gradInput = module:backward(input, gradOutput) + local gradOutput2 = torch.view(gradOutput, 1, 6):expandAs(expertInput) + local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(2):select(2,1) + mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture5 gater gradInput") + local expertGradInput2 = torch.cmul(input[1]:view(3,1):expand(3,6), gradOutput:view(1,6):expand(3,6)) + for i, expertGradInput in ipairs(gradInput[2]) do + mytester:assertTensorEq(expertGradInput, expertGradInput2:select(1,i), 0.000001, "mixture5 expert "..i.." gradInput") + end + -- expertInput is a Tensor: + local input = {input[1], expertInput} + local module = nn.MixtureTable(1) + local output = module:forward(input) + mytester:assertTensorEq(output, output2, 0.000001, "mixture6 output") + local gradInput = module:backward(input, gradOutput) + mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture6 gater gradInput") + mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture6 expert gradInput") end function nntest.View() |