diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2015-12-23 05:10:27 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2015-12-23 05:10:27 +0300 |
commit | afedaa89daa5218e25ad2283c5e9649539a426e1 (patch) | |
tree | b955021c1cb931125cae820f57853ac8b3bc84bc | |
parent | f0eaec22bfe897b4f38db13d09839a0b1cb9a944 (diff) |
testIO
-rw-r--r-- | test/test.lua | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 8c97ffa..6f75057 100644 --- a/test/test.lua +++ b/test/test.lua @@ -5,9 +5,11 @@ local cudnntest = {} local precision_forward = 1e-4 local precision_backward = 1e-2 local precision_jac = 1e-3 +local precision_io = 1e-5 local nloop = 1 local times = {} local mytester +local jac = nn.Jacobian function cudnntest.SpatialConvolution_forward_batch() local bs = math.random(1,32) @@ -32,6 +34,11 @@ function cudnntest.SpatialConvolution_forward_batch() cutorch.synchronize() local error = rescuda:float() - groundtruth:float() mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') end @@ -190,6 +197,11 @@ function cudnntest.VolumetricConvolution_forward_single() local error = rescuda:float() - groundtruth:float() mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') end function cudnntest.VolumetricConvolution_backward_single() @@ -285,6 +297,11 @@ function cudnntest.VolumetricMaxPooling_batch() mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') error = resgrad:float() - groundgrad:float() mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') end function cudnntest.VolumetricMaxPooling_single() @@ -363,6 +380,11 @@ function cudnntest.SpatialMaxPooling_batch() mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') error = resgrad:float() - groundgrad:float() mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') end function cudnntest.SpatialMaxPooling_single() @@ -437,6 +459,11 @@ function cudnntest.SpatialAveragePooling_batch() mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') error = resgrad:float() - groundgrad:float() mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') end function cudnntest.SpatialAveragePooling_single() @@ -667,6 +694,11 @@ function cudnntest.SoftMax_batch() error = resgrad:float() - groundgrad:float() mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') end @@ -728,6 +760,11 @@ function cudnntest.LogSoftMax_batch() error = resgrad:float() - groundgrad:float() mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') end function cudnntest.SpatialLogSoftMax() |