Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-14 18:55:07 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-14 18:55:07 +0400
commit7f0f53416f019bb980c9e922b72048ab5ea4fd77 (patch)
treec98565dee570bb097410b47102173db07635938a /test
parentf0a5f6c8926af4d3c068f1ffbf7290c3e7b21068 (diff)
MixtureTable type-cast unit tests
Diffstat (limited to 'test')
-rw-r--r--test/test.lua21
1 files changed, 21 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 39db431..b3e7ca9 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2033,6 +2033,19 @@ function nntest.MixtureTable()
for i, expertGradInput in ipairs(gradInput[2]) do
mytester:assertTensorEq(expertGradInput, expertGradInput2:select(1,i), 0.000001, "mixture5 expert "..i.." gradInput")
end
+ -- test type-cast
+ module:float()
+ local input2 = {
+ input[1]:float(),
+ {input[2][1]:float(), input[2][2]:float(), input[2][3]:float()}
+ }
+ local output = module:forward(input2)
+ mytester:assertTensorEq(output, output2:float(), 0.000001, "mixture5B output")
+ local gradInput = module:backward(input2, gradOutput:float())
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture5B gater gradInput")
+ for i, expertGradInput in ipairs(gradInput[2]) do
+ mytester:assertTensorEq(expertGradInput, expertGradInput2:select(1,i):float(), 0.000001, "mixture5B expert "..i.." gradInput")
+ end
-- expertInput is a Tensor:
local input = {input[1], expertInput}
local module = nn.MixtureTable(1)
@@ -2041,6 +2054,14 @@ function nntest.MixtureTable()
local gradInput = module:backward(input, gradOutput)
mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture6 gater gradInput")
mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture6 expert gradInput")
+ -- test type-cast:
+ module:float()
+ local input2 = {input[1]:float(), expertInput:float()}
+ local output = module:forward(input2)
+ mytester:assertTensorEq(output, output2:float(), 0.000001, "mixture6B output")
+ local gradInput = module:backward(input2, gradOutput:float())
+ mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture6B gater gradInput")
+ mytester:assertTensorEq(gradInput[2], expertGradInput2:float(), 0.000001, "mixture6B expert gradInput")
end
function nntest.View()