Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2014-07-15 00:56:26 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2014-07-15 01:27:12 +0400
commitf19810729f293a0ecbd094572e87972a4adf9502 (patch)
treeecbf3569f8f17e8149509d539dd304631274f93a /test
parent437c4940455ff8723a7318b2cb39a0970edde108 (diff)
added FlattenTable module, a unit test and a element in the documentation.
fixed bad formatting in table.md
Diffstat (limited to 'test')
-rw-r--r--test/test.lua79
1 files changed, 79 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index a5550b7..b1b4f90 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2045,6 +2045,85 @@ function nntest.ConcatTable()
mytester:assertlt(err, precision, ' getParameters error ')
end
+function nntest.FlattenTable()
+ -- Create a nested table. Obviously we can't even stochastically test
+ -- the space of all possible nested tables (it's infinite), but here is a
+ -- hand-coded one that covers all the cases we need:
+ local input = {
+ torch.rand(1),
+ {
+ torch.rand(2),
+ {
+ torch.rand(3)
+ },
+ },
+ torch.rand(4)
+ }
+ local gradOutput = {
+ torch.rand(1),
+ torch.rand(2),
+ torch.rand(3),
+ torch.rand(4)
+ }
+
+ -- Check the FPROP
+ local m = nn.FlattenTable()
+ local output = m:forward(input)
+ mytester:assert(#output == 4, torch.typename(m)..' - fprop err ')
+ -- This is ugly, but check that the mapping from input to output is correct
+ mytester:assert(output[1] == input[1])
+ mytester:assert(output[2] == input[2][1])
+ mytester:assert(output[3] == input[2][2][1])
+ mytester:assert(output[4] == input[3])
+
+ -- Check the BPROP
+ local gradInput = m:backward(input, gradOutput)
+ -- Again, check that the mapping is correct
+ mytester:assert(gradOutput[1] == gradInput[1])
+ mytester:assert(gradOutput[2] == gradInput[2][1])
+ mytester:assert(gradOutput[3] == gradInput[2][2][1])
+ mytester:assert(gradOutput[4] == gradInput[3])
+
+ -- More uglyness: FlattenTable doesn't rebuild the table every updateOutput
+ -- call, so we need to make sure that modifications to the input are
+ -- detected correctly (and that the table is correctly rebuilt.
+ -- CASE 1: Nothing changes so the output table shouldn't be redefined
+ local old_input_map = m.input_map
+ local old_output = m.output
+ output = m:forward(input)
+ mytester:assert(old_input_map == m.input_map and old_output == m.output)
+
+ -- CASE 2: An element is added to the input table
+ old_input_map = m.input_map
+ old_output = m.output
+ input[2][#(input[2])+1] = torch.rand(5)
+ m:forward(input)
+ mytester:assert(old_input_map ~= m.input_map and old_output ~= m.output)
+
+ -- CASE 3: An element is removed from the input table
+ old_input_map = m.input_map
+ old_output = m.output
+ input[#input] = nil
+ m:forward(input)
+ mytester:assert(old_input_map ~= m.input_map and old_output ~= m.output)
+
+ -- At this point further testing is not necessary I think, but just to be
+ -- consistent: perform a jacobian test by using SplitTable and JointTable
+ -- elements
+ m = nn.Sequential()
+ local par = nn.ParallelTable()
+ par:add(nn.SplitTable(1))
+ par:add(nn.SplitTable(1))
+ m:add(nn.SplitTable(1))
+ m:add(par) -- this will create a nested table
+ m:add(nn.FlattenTable()) -- This will flatten the nested table
+ m:add(nn.JoinTable(1)) -- Finally, this will create a 1D tensor
+
+ input = torch.Tensor(2,2,2)
+ local err = jac.testJacobian(m, input)
+ mytester:assertlt(err, precision, 'error on bprop ')
+end
+
mytester:add(nntest)
if not nn then