Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-11 22:13:15 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-11 22:13:15 +0400
commitc40edcad5cb0bff5aeec53ebff6634887978225f (patch)
tree98c796396620128bc73ac45b158a2d8fde552e6e /test
parent2d56aa7e6dbc33c49664118437afef0c0882b6d7 (diff)
MixtureTable unit tests (expertInput is a Tensor)
Diffstat (limited to 'test')
-rw-r--r--test/test.lua15
1 files changed, 14 insertions, 1 deletions
diff --git a/test/test.lua b/test/test.lua
index 8eab9a1..802e553 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1956,11 +1956,12 @@ end
function nntest.MixtureTable()
local expertInput = torch.randn(5,3,6)
+ local gradOutput = torch.randn(5,6)
+ -- expertInput is a Table:
local input = {
torch.rand(5,3),
{expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)}
}
- local gradOutput = torch.randn(5,6)
local module = nn.MixtureTable()
local output = module:forward(input)
local output2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), expertInput):sum(2)
@@ -1973,6 +1974,18 @@ function nntest.MixtureTable()
for i, expertGradInput in ipairs(gradInput[2]) do
mytester:assertTensorEq(expertGradInput, expertGradInput2:select(2,i), 0.000001, "mixture expert "..i.." gradInput")
end
+ -- expertInput is a Tensor:
+ local input = {input[1], expertInput}
+ local module = nn.MixtureTable(2)
+ local output = module:forward(input)
+ local output2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), expertInput):sum(2)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture2 output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradOutput2 = torch.view(gradOutput, 5, 1, 6):expandAs(expertInput)
+ local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(3):select(3,1)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture2 gater gradInput")
+ local expertGradInput2 = torch.cmul(input[1]:view(5,3,1):expand(5,3,6), gradOutput:view(5,1,6):expand(5,3,6))
+ mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture2 expert gradInput")
end
function nntest.View()