diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-01-07 00:37:32 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-01-07 00:37:32 +0300 |
commit | bd7b315d7427f8320d44d8b29077154db87df633 (patch) | |
tree | baf59604cc58335e530b69553de893adc6747dbe /test.lua | |
parent | 4e0a96d801060121521ccc46f7294aeb3b247965 (diff) |
new Copy constructor arguments
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 14 |
1 files changed, 14 insertions, 0 deletions
@@ -2181,6 +2181,20 @@ function nntest.MulConstant() mytester:assertlt(err, precision, 'bprop error ') end +function nntest.Copy() + local input = torch.randn(3,4):double() + local c = nn.Copy('torch.DoubleTensor', 'torch.FloatTensor') + local output = c:forward(input) + mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err') + mytester:assertTensorEq(output, input:float(), 0.000001, 'copy forward value err') + local gradInput = c:backward(input, output) + mytester:assert(torch.type(gradInput) == 'torch.DoubleTensor', 'copy backward type err') + mytester:assertTensorEq(gradInput, input, 0.000001, 'copy backward value err') + c.dontCast = true + c:double() + mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err') +end + function nntest.JoinTable() local tensor = torch.rand(3,4,5) local input = {tensor, tensor} |