Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-14 03:42:14 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-14 03:42:14 +0400
commit920c2d78bcb5aeb20dfef4b9b20335a5d1bcbc6e (patch)
tree390c6058b138340aee71702a806858a920e9d7c5 /test
parentba82d19f47ed2e377594f72fad60498efda4b5c8 (diff)
MixtureTable works with 1D inputs (unit tested)
Diffstat (limited to 'test')
-rw-r--r--test/test.lua30
1 files changed, 29 insertions, 1 deletions
diff --git a/test/test.lua b/test/test.lua
index 00d78b1..39db431 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2010,9 +2010,37 @@ function nntest.MixtureTable()
local output = module:forward(input)
mytester:assertTensorEq(output, output2, 0.000001, "mixture4 output")
local gradInput = module:backward(input, gradOutput)
- print(gradInput[1], gaterGradInput2)
mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture4 gater gradInput")
mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture4 expert gradInput")
+
+ --[[ 1D ]]--
+ -- expertInput is a Table:
+ local expertInput = torch.randn(3,6)
+ local gradOutput = torch.randn(6)
+ local input = {
+ torch.rand(3),
+ {expertInput:select(1,1), expertInput:select(1,2), expertInput:select(1,3)}
+ }
+ local module = nn.MixtureTable()
+ local output = module:forward(input)
+ local output2 = torch.cmul(input[1]:view(3,1):expand(3,6), expertInput):sum(1)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture5 output")
+ local gradInput = module:backward(input, gradOutput)
+ local gradOutput2 = torch.view(gradOutput, 1, 6):expandAs(expertInput)
+ local gaterGradInput2 = torch.cmul(gradOutput2, expertInput):sum(2):select(2,1)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture5 gater gradInput")
+ local expertGradInput2 = torch.cmul(input[1]:view(3,1):expand(3,6), gradOutput:view(1,6):expand(3,6))
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(1,i), 0.000001, "mixture5 expert "..i.." gradInput")
+ end
+ -- expertInput is a Tensor:
+ local input = {input[1], expertInput}
+ local module = nn.MixtureTable(1)
+ local output = module:forward(input)
+ mytester:assertTensorEq(output, output2, 0.000001, "mixture6 output")
+ local gradInput = module:backward(input, gradOutput)
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture6 gater gradInput")
+ mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture6 expert gradInput")
end
function nntest.View()