diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-10-17 07:46:21 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-17 07:46:21 +0300 |
commit | a8e63f2da3d3d84a7e1eed917572901a9ffba5d9 (patch) | |
tree | 721c68a6fd4cfd50895228992782ba1694580ae4 /test.lua | |
parent | 7c17669434bda5741b108e023fe37fa7d7854fef (diff) | |
parent | fcdf644d0d2986932ce38149b05778225b2c9b5d (diff) |
Merge pull request #1006 from torch/errorsimprovements
more improvments on error messages and shape checks
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 13 |
1 files changed, 2 insertions, 11 deletions
@@ -2482,10 +2482,6 @@ function nntest.SpatialConvolution() module = nn.SpatialConvolution(from, to, ki, kj, si, sj) input = torch.Tensor(batch,from,inj,ini):zero() - -- print(from, to, ki, kj, si, sj, batch, ini, inj) - -- print(module.weight:size()) - -- print(module.gradWeight:size()) - local err = jac.testJacobian(module, input) mytester:assertlt(err, precision, 'batch error on state ') @@ -2709,11 +2705,7 @@ function nntest.SpatialConvolutionLocal() ini = (outi-1)*si+ki inj = (outj-1)*sj+kj module = nn.SpatialConvolutionLocal(from, to, ini, inj, ki, kj, si, sj) - input = torch.Tensor(batch,from,inj,ini):zero() - --- print(from, to, ki, kj, si, sj, batch, ini, inj) --- print(module.weight:size()) --- print(module.gradWeight:size()) + input = torch.Tensor(batch, from, inj, ini):zero() local err = jac.testJacobian(module, input) mytester:assertlt(err, precision, 'batch error on state ') @@ -2755,14 +2747,13 @@ function nntest.SpatialConvolutionLocal() -- check against nn.SpatialConvolution local conv = nn.SpatialConvolution(from, to, ki, kj, si, sj) - torch.repeatTensor(module.bias, conv.bias:view(to, 1, 1), 1, outi, outj) + torch.repeatTensor(module.bias, conv.bias:view(to, 1, 1), 1, outj, outi) torch.repeatTensor(module.weight, conv.weight:view(1, 1, from, to, ki, kj), outi, outj, 1, 1, 1, 1) local input = torch.rand(batch, from, inj, ini) local output = module:forward(input) local outputConv = conv:forward(input) local err = torch.dist(output, outputConv) mytester:assertlt(err, precision, 'error checking against nn.SpatialConvolution') - end function nntest.SpatialFullConvolution() |