Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-07-15 19:28:48 +0300
committerNicholas Leonard <nick@nikopia.org>2015-07-15 21:24:04 +0300
commit59e3f8cb4b6571fd46b0ba66e57626f725fbe81f (patch)
treee0478aa9067338664b40a734f08c13bc23c49f7c /test.lua
parent514a093f5e9a76a04f3252d58259673ef4ff71bb (diff)
NarrowTable
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua39
1 files changed, 39 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index bfece0e..92b686f 100644
--- a/test.lua
+++ b/test.lua
@@ -3116,6 +3116,45 @@ function nntest.MixtureTable()
end
end
+function nntest.NarrowTable()
+ local input = torch.randn(3,10,4)
+ local gradOutput = torch.randn(3,3,4)
+ local nt = nn.NarrowTable(5,3)
+ local seq = nn.Sequential()
+ seq:add(nn.SplitTable(1,2))
+ seq:add(nt)
+ seq:add(nn.JoinTable(1,1))
+ seq:add(nn.Reshape(3,3,4))
+ local seq2 = nn.Narrow(2,5,3)
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput err")
+
+ -- now try it with a smaller input
+ local input = input:narrow(2, 1, 8)
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable small output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable small gradInput err")
+
+ -- test type-cast
+ local input = input:float()
+ local gradOutput = gradOutput:float()
+ seq:float()
+ seq2:float()
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output float err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput float err")
+end
+
function nntest.View()
local input = torch.rand(10)
local template = torch.rand(5,2)