From 59e3f8cb4b6571fd46b0ba66e57626f725fbe81f Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Wed, 15 Jul 2015 12:28:48 -0400 Subject: NarrowTable --- test.lua | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) (limited to 'test.lua') diff --git a/test.lua b/test.lua index bfece0e..92b686f 100644 --- a/test.lua +++ b/test.lua @@ -3116,6 +3116,45 @@ function nntest.MixtureTable() end end +function nntest.NarrowTable() + local input = torch.randn(3,10,4) + local gradOutput = torch.randn(3,3,4) + local nt = nn.NarrowTable(5,3) + local seq = nn.Sequential() + seq:add(nn.SplitTable(1,2)) + seq:add(nt) + seq:add(nn.JoinTable(1,1)) + seq:add(nn.Reshape(3,3,4)) + local seq2 = nn.Narrow(2,5,3) + local output = seq:forward(input) + local gradInput = seq:backward(input, gradOutput) + local output2 = seq2:forward(input) + local gradInput2 = seq2:backward(input, gradOutput) + mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput err") + + -- now try it with a smaller input + local input = input:narrow(2, 1, 8) + local output = seq:forward(input) + local gradInput = seq:backward(input, gradOutput) + local output2 = seq2:forward(input) + local gradInput2 = seq2:backward(input, gradOutput) + mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable small output err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable small gradInput err") + + -- test type-cast + local input = input:float() + local gradOutput = gradOutput:float() + seq:float() + seq2:float() + local output = seq:forward(input) + local gradInput = seq:backward(input, gradOutput) + local output2 = seq2:forward(input) + local gradInput2 = seq2:backward(input, gradOutput) + mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output float err") + mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput float err") +end + function nntest.View() local input = torch.rand(10) local template = torch.rand(5,2) -- cgit v1.2.3