Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2015-01-07 00:37:32 +0300
committernicholas-leonard <nick@nikopia.org>2015-01-07 00:37:32 +0300
commitbd7b315d7427f8320d44d8b29077154db87df633 (patch)
treebaf59604cc58335e530b69553de893adc6747dbe /test.lua
parent4e0a96d801060121521ccc46f7294aeb3b247965 (diff)
new Copy constructor arguments
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua14
1 files changed, 14 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 3cf6a58..38a659d 100644
--- a/test.lua
+++ b/test.lua
@@ -2181,6 +2181,20 @@ function nntest.MulConstant()
mytester:assertlt(err, precision, 'bprop error ')
end
+function nntest.Copy()
+ local input = torch.randn(3,4):double()
+ local c = nn.Copy('torch.DoubleTensor', 'torch.FloatTensor')
+ local output = c:forward(input)
+ mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err')
+ mytester:assertTensorEq(output, input:float(), 0.000001, 'copy forward value err')
+ local gradInput = c:backward(input, output)
+ mytester:assert(torch.type(gradInput) == 'torch.DoubleTensor', 'copy backward type err')
+ mytester:assertTensorEq(gradInput, input, 0.000001, 'copy backward value err')
+ c.dontCast = true
+ c:double()
+ mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err')
+end
+
function nntest.JoinTable()
local tensor = torch.rand(3,4,5)
local input = {tensor, tensor}