diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-14 18:55:07 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-14 18:55:07 +0400 |
commit | 7f0f53416f019bb980c9e922b72048ab5ea4fd77 (patch) | |
tree | c98565dee570bb097410b47102173db07635938a /test | |
parent | f0a5f6c8926af4d3c068f1ffbf7290c3e7b21068 (diff) |
MixtureTable type-cast unit tests
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 39db431..b3e7ca9 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2033,6 +2033,19 @@ function nntest.MixtureTable() for i, expertGradInput in ipairs(gradInput[2]) do mytester:assertTensorEq(expertGradInput, expertGradInput2:select(1,i), 0.000001, "mixture5 expert "..i.." gradInput") end + -- test type-cast + module:float() + local input2 = { + input[1]:float(), + {input[2][1]:float(), input[2][2]:float(), input[2][3]:float()} + } + local output = module:forward(input2) + mytester:assertTensorEq(output, output2:float(), 0.000001, "mixture5B output") + local gradInput = module:backward(input2, gradOutput:float()) + mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture5B gater gradInput") + for i, expertGradInput in ipairs(gradInput[2]) do + mytester:assertTensorEq(expertGradInput, expertGradInput2:select(1,i):float(), 0.000001, "mixture5B expert "..i.." gradInput") + end -- expertInput is a Tensor: local input = {input[1], expertInput} local module = nn.MixtureTable(1) @@ -2041,6 +2054,14 @@ function nntest.MixtureTable() local gradInput = module:backward(input, gradOutput) mytester:assertTensorEq(gradInput[1], gaterGradInput2, 0.000001, "mixture6 gater gradInput") mytester:assertTensorEq(gradInput[2], expertGradInput2, 0.000001, "mixture6 expert gradInput") + -- test type-cast: + module:float() + local input2 = {input[1]:float(), expertInput:float()} + local output = module:forward(input2) + mytester:assertTensorEq(output, output2:float(), 0.000001, "mixture6B output") + local gradInput = module:backward(input2, gradOutput:float()) + mytester:assertTensorEq(gradInput[1], gaterGradInput2:float(), 0.000001, "mixture6B gater gradInput") + mytester:assertTensorEq(gradInput[2], expertGradInput2:float(), 0.000001, "mixture6B expert gradInput") end function nntest.View() |