diff options
author | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-07-15 00:56:26 +0400 |
---|---|---|
committer | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-07-15 01:27:12 +0400 |
commit | f19810729f293a0ecbd094572e87972a4adf9502 (patch) | |
tree | ecbf3569f8f17e8149509d539dd304631274f93a /test | |
parent | 437c4940455ff8723a7318b2cb39a0970edde108 (diff) |
added FlattenTable module, a unit test and a element in the documentation.
fixed bad formatting in table.md
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index a5550b7..b1b4f90 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2045,6 +2045,85 @@ function nntest.ConcatTable() mytester:assertlt(err, precision, ' getParameters error ') end +function nntest.FlattenTable() + -- Create a nested table. Obviously we can't even stochastically test + -- the space of all possible nested tables (it's infinite), but here is a + -- hand-coded one that covers all the cases we need: + local input = { + torch.rand(1), + { + torch.rand(2), + { + torch.rand(3) + }, + }, + torch.rand(4) + } + local gradOutput = { + torch.rand(1), + torch.rand(2), + torch.rand(3), + torch.rand(4) + } + + -- Check the FPROP + local m = nn.FlattenTable() + local output = m:forward(input) + mytester:assert(#output == 4, torch.typename(m)..' - fprop err ') + -- This is ugly, but check that the mapping from input to output is correct + mytester:assert(output[1] == input[1]) + mytester:assert(output[2] == input[2][1]) + mytester:assert(output[3] == input[2][2][1]) + mytester:assert(output[4] == input[3]) + + -- Check the BPROP + local gradInput = m:backward(input, gradOutput) + -- Again, check that the mapping is correct + mytester:assert(gradOutput[1] == gradInput[1]) + mytester:assert(gradOutput[2] == gradInput[2][1]) + mytester:assert(gradOutput[3] == gradInput[2][2][1]) + mytester:assert(gradOutput[4] == gradInput[3]) + + -- More uglyness: FlattenTable doesn't rebuild the table every updateOutput + -- call, so we need to make sure that modifications to the input are + -- detected correctly (and that the table is correctly rebuilt. + -- CASE 1: Nothing changes so the output table shouldn't be redefined + local old_input_map = m.input_map + local old_output = m.output + output = m:forward(input) + mytester:assert(old_input_map == m.input_map and old_output == m.output) + + -- CASE 2: An element is added to the input table + old_input_map = m.input_map + old_output = m.output + input[2][#(input[2])+1] = torch.rand(5) + m:forward(input) + mytester:assert(old_input_map ~= m.input_map and old_output ~= m.output) + + -- CASE 3: An element is removed from the input table + old_input_map = m.input_map + old_output = m.output + input[#input] = nil + m:forward(input) + mytester:assert(old_input_map ~= m.input_map and old_output ~= m.output) + + -- At this point further testing is not necessary I think, but just to be + -- consistent: perform a jacobian test by using SplitTable and JointTable + -- elements + m = nn.Sequential() + local par = nn.ParallelTable() + par:add(nn.SplitTable(1)) + par:add(nn.SplitTable(1)) + m:add(nn.SplitTable(1)) + m:add(par) -- this will create a nested table + m:add(nn.FlattenTable()) -- This will flatten the nested table + m:add(nn.JoinTable(1)) -- Finally, this will create a 1D tensor + + input = torch.Tensor(2,2,2) + local err = jac.testJacobian(m, input) + mytester:assertlt(err, precision, 'error on bprop ') +end + mytester:add(nntest) if not nn then |