diff options
author | soumith <soumith@fb.com> | 2014-12-20 10:30:08 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2014-12-20 10:30:08 +0300 |
commit | 035863d94fb41b48ccd0babf7055c6bb719bbf8f (patch) | |
tree | bff149765b13aa311012e47d4b4e2613b273e3e8 /test | |
parent | a2def9658f9df262a252ff71e9dfd310ab722b13 (diff) |
lint fixes (80 columns)
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/test/test.lua b/test/test.lua index 9931431..03f7c2f 100644 --- a/test/test.lua +++ b/test/test.lua @@ -179,7 +179,7 @@ function cudnntest.VolumetricConvolution_forward_single() local inj = (outj-1)*sj+kj local ink = (outk-1)*sk+kk local input = torch.randn(from,ink,inj,ini):cuda() - local sconv = nn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):float() --:cuda() + local sconv = nn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):float() local groundtruth = sconv:forward(input:float()) cutorch.synchronize() local gconv = cudnn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):cuda() @@ -188,7 +188,8 @@ function cudnntest.VolumetricConvolution_forward_single() local rescuda = gconv:forward(input) cutorch.synchronize() local error = rescuda:float() - groundtruth:float() - mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') + mytester:assertlt(error:abs():max(), precision_forward, + 'error on state (forward) ') end function cudnntest.VolumetricConvolution_backward_single() @@ -208,8 +209,8 @@ function cudnntest.VolumetricConvolution_backward_single() local ink = (outk-1)*sk+kk local input = torch.randn(from,ink,inj,ini):cuda() local gradOutput = torch.randn(to,outk,outj,outi):cuda() - local sconv = nn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):float() --:cuda() - local groundtruth = sconv:forward(input:float()) + local sconv = nn.VolumetricConvolution(from,to,kk,ki,kj,sk,si,sj):float() + sconv:forward(input:float()) sconv:zeroGradParameters() local groundgrad = sconv:backward(input:float(), gradOutput:float()) cutorch.synchronize() |