diff options
author | Nicholas Leonard <nick@nikopia.org> | 2015-07-15 19:28:48 +0300 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2015-07-15 21:24:04 +0300 |
commit | 59e3f8cb4b6571fd46b0ba66e57626f725fbe81f (patch) | |
tree | e0478aa9067338664b40a734f08c13bc23c49f7c /test.lua | |
parent | 514a093f5e9a76a04f3252d58259673ef4ff71bb (diff) |
NarrowTable
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 39 |
1 files changed, 39 insertions, 0 deletions
@@ -3116,6 +3116,45 @@ function nntest.MixtureTable() end end +function nntest.NarrowTable() + local input = torch.randn(3,10,4) + local gradOutput = torch.randn(3,3,4) + local nt = nn.NarrowTable(5,3) + local seq = nn.Sequential() + seq:add(nn.SplitTable(1,2)) + seq:add(nt) + seq:add(nn.JoinTable(1,1)) + seq:add(nn.Reshape(3,3,4)) + local seq2 = nn.Narrow(2,5,3) + local output = seq:forward(input) + local gradInput = seq:backward(input, gradOutput) + local output2 = seq2:forward(input) + local gradInput2 = seq2:backward(input, gradOutput) + mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput err") + + -- now try it with a smaller input + local input = input:narrow(2, 1, 8) + local output = seq:forward(input) + local gradInput = seq:backward(input, gradOutput) + local output2 = seq2:forward(input) + local gradInput2 = seq2:backward(input, gradOutput) + mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable small output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable small gradInput err") + + -- test type-cast + local input = input:float() + local gradOutput = gradOutput:float() + seq:float() + seq2:float() + local output = seq:forward(input) + local gradInput = seq:backward(input, gradOutput) + local output2 = seq2:forward(input) + local gradInput2 = seq2:backward(input, gradOutput) + mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output float err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput float err") +end + function nntest.View() local input = torch.rand(10) local template = torch.rand(5,2) |