Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-16 01:49:22 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-16 01:49:22 +0400
commite074440aaf6f02491c9685d4e254ecb159622414 (patch)
tree7cd9ce5eb020094f38610118883de4418162175b /test
parent7f0f53416f019bb980c9e922b72048ab5ea4fd77 (diff)
2D gater 1D experts unit test
Diffstat (limited to 'test')
-rw-r--r--test/test.lua21
1 files changed, 21 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index b3e7ca9..8c17e90 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2062,6 +2062,27 @@ function nntest.MixtureTable()
local gradInput = module:backward(input2, gradOutput:float())
mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture6B gater gradInput")
mytester:assertTensorEq(gradInput[2], expertGradInput2:float(), 0.000001, "mixture6B expert gradInput")
+
+ --[[ 2D gater, 1D expert]]--
+ -- expertInput is a Table:
+ local expertInput = torch.randn(5,3)
+ local gradOutput = torch.randn(5)
+ local input = {
+ torch.rand(5,3),
+ {expertInput:select(2,1), expertInput:select(2,2), expertInput:select(2,3)}
+ }
+ local module = nn.MixtureTable()
+ local output = module:forward(input)
+ local output2 = torch.cmul(input[1], expertInput):sum(2)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture7 output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradOutput2 = torch.view(gradOutput, 5, 1):expandAs(expertInput)
+ local gaterGradInput2 = torch.cmul(gradOutput2, expertInput)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture7 gater gradInput")
+ local expertGradInput2 = torch.cmul(input[1], gradOutput:view(5,1):expand(5,3))
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(2,i), 0.000001, "mixture7 expert "..i.." gradInput")
+ end
end
function nntest.View()