diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-03-04 19:50:12 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-04 19:50:12 +0300 |
commit | 6f16be18fabd81e563924c9c7e4203026258be8d (patch) | |
tree | ba683ff581f8cfea570eeae00d24bbc96a530270 | |
parent | 06678ddcf82067ae621bc5e16292b82166172658 (diff) | |
parent | d10c2cfd7eb074a05621304cb672b6ca78693c2d (diff) |
Merge pull request #963 from torch/ssefix
Fix critical bug in SSE scalar add + improve tests
-rw-r--r-- | lib/TH/CMakeLists.txt | 10 |
-rw-r--r-- | lib/TH/vector/SSE.c | 2 |
-rw-r--r-- | test/test.lua | 431 |
3 files changed, 256 insertions, 187 deletions
diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt index 351594d..d26d4c2 100644 --- a/lib/TH/CMakeLists.txt +++ b/lib/TH/CMakeLists.txt @@ -209,18 +209,20 @@ ENDIF(C_SSE4_1_FOUND AND C_SSE4_2_FOUND) # IF AVX FOUND IF(C_AVX_FOUND) IF(MSVC) - SET_SOURCE_FILES_PROPERTIES(vector/AVX.c generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "/Ox /fp:fast ${C_AVX_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "/Ox /fp:fast ${C_AVX_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(vector/AVX.c PROPERTIES COMPILE_FLAGS "/Ox ${C_AVX_FLAGS}") ELSE(MSVC) - SET_SOURCE_FILES_PROPERTIES(vector/AVX.c generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math ${C_AVX_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math ${C_AVX_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(vector/AVX.c PROPERTIES COMPILE_FLAGS "-O3 ${C_AVX_FLAGS}") ENDIF(MSVC) SET(simd ${simd} vector/AVX.c generic/simd/convolve5x5_avx.c) ENDIF(C_AVX_FOUND) IF(C_AVX2_FOUND) IF(MSVC) - SET_SOURCE_FILES_PROPERTIES(vector/AVX2.c PROPERTIES COMPILE_FLAGS "/Ox /fp:fast ${C_AVX2_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(vector/AVX2.c PROPERTIES COMPILE_FLAGS "/Ox ${C_AVX2_FLAGS}") ELSE(MSVC) - SET_SOURCE_FILES_PROPERTIES(vector/AVX2.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math ${C_AVX2_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(vector/AVX2.c PROPERTIES COMPILE_FLAGS "-O3 ${C_AVX2_FLAGS}") ENDIF(MSVC) SET(simd ${simd} vector/AVX2.c) ENDIF(C_AVX2_FOUND) diff --git a/lib/TH/vector/SSE.c b/lib/TH/vector/SSE.c index 01ac789..d026935 100644 --- a/lib/TH/vector/SSE.c +++ b/lib/TH/vector/SSE.c @@ -172,7 +172,7 @@ static void THFloatVector_adds_SSE(float *y, const float *x, const float c, cons ptrdiff_t i; __m128 XMM7 = _mm_set1_ps(c); __m128 XMM0, XMM2; - for (i=8; i<=((n)-8); i+=8) { + for (i=0; i<=((n)-8); i+=8) { XMM0 = _mm_loadu_ps((x)+i); XMM2 = _mm_loadu_ps((x)+i+4); XMM0 = _mm_add_ps(XMM0, XMM7); diff --git 
a/test/test.lua b/test/test.lua index 47e0a0f..6221854 100644 --- a/test/test.lua +++ b/test/test.lua @@ -630,63 +630,75 @@ function torchtest.fill() end function torchtest.add() - -- [res] torch.add([res,] tensor1, tensor2) - local m1 = torch.randn(100,100) - local v1 = torch.randn(100) + local types = { + 'torch.ByteTensor', + 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.LongTensor', + } - local res1 = torch.add(m1[{ 4,{} }],v1) + for k,t in ipairs(types) do + -- [res] torch.add([res,] tensor1, tensor2) + local m1 = torch.randn(100,100):type(t) + local v1 = torch.randn(100):type(t) - local res2 = res1:clone():zero() - for i = 1,m1:size(2) do - res2[i] = m1[4][i] + v1[i] - end + local res1 = torch.add(m1[{ 4,{} }],v1) - local err = (res1-res2):abs():max() + local res2 = res1:clone():zero() + for i = 1,m1:size(2) do + res2[i] = m1[4][i] + v1[i] + end - mytester:assertlt(err, precision, 'error in torch.add - contiguous') + local err = (res1-res2):double():abs():max() - local m1 = torch.randn(100,100) - local v1 = torch.randn(100) + mytester:assertlt(err, precision, 'error in torch.add - contiguous' .. ' ' .. t) - local res1 = torch.add(m1[{ {},4 }],v1) + local m1 = torch.randn(100,100):type(t) + local v1 = torch.randn(100):type(t) - local res2 = res1:clone():zero() - for i = 1,m1:size(1) do - res2[i] = m1[i][4] + v1[i] - end + local res1 = torch.add(m1[{ {},4 }],v1) - local err = (res1-res2):abs():max() + local res2 = res1:clone():zero() + for i = 1,m1:size(1) do + res2[i] = m1[i][4] + v1[i] + end - mytester:assertlt(err, precision, 'error in torch.add - non contiguous') + local err = (res1-res2):double():abs():max() - -- [res] torch.add([res,] tensor, value) - local m1 = torch.randn(10,10) - local res1 = m1:clone() - res1[{ 3,{} }]:add(2) + mytester:assertlt(err, precision, 'error in torch.add - non contiguous' .. ' ' .. 
t) - local res2 = m1:clone() - for i = 1,m1:size(1) do - res2[{ 3,i }] = res2[{ 3,i }] + 2 - end + -- [res] torch.add([res,] tensor, value) + local m1 = torch.randn(10,10):type(t) + local res1 = m1:clone() + res1[{ 3,{} }]:add(2) - local err = (res1-res2):abs():max() + local res2 = m1:clone() + for i = 1,m1:size(1) do + res2[{ 3,i }] = res2[{ 3,i }] + 2 + end - mytester:assertlt(err, precision, 'error in torch.add - scalar, contiguous') + local err = (res1-res2):double():abs():max() - local m1 = torch.randn(10,10) - local res1 = m1:clone() - res1[{ {},3 }]:add(2) + mytester:assertlt(err, precision, 'error in torch.add - scalar, contiguous' .. ' ' .. t) - local res2 = m1:clone() - for i = 1,m1:size(1) do - res2[{ i,3 }] = res2[{ i,3 }] + 2 - end + local m1 = torch.randn(10,10) + local res1 = m1:clone() + res1[{ {},3 }]:add(2) - local err = (res1-res2):abs():max() + local res2 = m1:clone() + for i = 1,m1:size(1) do + res2[{ i,3 }] = res2[{ i,3 }] + 2 + end + + local err = (res1-res2):abs():max() - mytester:assertlt(err, precision, 'error in torch.add - scalar, non contiguous') + mytester:assertlt(err, precision, 'error in torch.add - scalar, non contiguous' .. ' ' .. 
t) - -- [res] torch.add([res,] tensor1, value, tensor2) + -- [res] torch.add([res,] tensor1, value, tensor2) + end end function torchtest.csub() @@ -754,35 +766,60 @@ function torchtest.cinv() end function torchtest.mul() - local m1 = torch.randn(10,10) - local res1 = m1:clone() + local types = { + 'torch.ByteTensor', + 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.LongTensor', + } - res1[{ {},3 }]:mul(2) + for k,t in ipairs(types) do + local m1 = torch.randn(10,10):type(t) + local res1 = m1:clone() - local res2 = m1:clone() - for i = 1,m1:size(1) do - res2[{ i,3 }] = res2[{ i,3 }] * 2 - end + res1[{ {},3 }]:mul(2) - local err = (res1-res2):abs():max() + local res2 = m1:clone() + for i = 1,m1:size(1) do + res2[{ i,3 }] = res2[{ i,3 }] * 2 + end + + local err = (res1-res2):double():abs():max() - mytester:assertlt(err, precision, 'error in torch.mul - scalar, non contiguous') + mytester:assertlt(err, precision, 'error in torch.mul - scalar, non contiguous' .. ' ' .. t) + end end function torchtest.div() - local m1 = torch.randn(10,10) - local res1 = m1:clone() - - res1[{ {},3 }]:div(2) - - local res2 = m1:clone() - for i = 1,m1:size(1) do - res2[{ i,3 }] = res2[{ i,3 }] / 2 - end + local types = { + 'torch.ByteTensor', + 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.LongTensor', + } + + for k,t in ipairs(types) do + + local m1 = torch.randn(10,10):type(t) + local res1 = m1:clone() + + res1[{ {},3 }]:div(2) + + local res2 = m1:clone() + for i = 1,m1:size(1) do + res2[{ i,3 }] = res2[{ i,3 }] / 2 + end - local err = (res1-res2):abs():max() + local err = (res1-res2):double():abs():max() - mytester:assertlt(err, precision, 'error in torch.div - scalar, non contiguous') + mytester:assertlt(err, precision, 'error in torch.div - scalar, non contiguous' .. ' ' .. 
t) + end end function torchtest.lshift() @@ -1221,68 +1258,84 @@ function torchtest.pow() -- [res] torch.pow([res,] x) mytester:assertlt(maxerr, precision, 'error in torch.pow - non-contiguous') end -function torchtest.cdiv() -- [res] torch.cdiv([res,] tensor1, tensor2) - -- contiguous - local m1 = torch.randn(10, 10, 10) - local m2 = torch.randn(10, 10 * 10) - local sm1 = m1[{4, {}, {}}] - local sm2 = m2[{4, {}}] - local res1 = torch.cdiv(sm1, sm2) - local res2 = res1:clone():zero() - for i = 1,sm1:size(1) do - for j = 1, sm1:size(2) do - local idx1d = (((i-1)*sm1:size(1)))+j - res2[i][j] = sm1[i][j] / sm2[idx1d] - end - end - local err = res1:clone():zero() - -- find absolute error - for i = 1, res1:size(1) do - for j = 1, res1:size(2) do - err[i][j] = math.abs(res1[i][j] - res2[i][j]) - end - end - -- find maximum element of error - local maxerr = 0 - for i = 1, err:size(1) do - for j = 1, err:size(2) do - if err[i][j] > maxerr then - maxerr = err[i][j] - end - end - end - mytester:assertlt(maxerr, precision, 'error in torch.cdiv - contiguous') - - -- non-contiguous - local m1 = torch.randn(10, 10, 10) - local m2 = torch.randn(10 * 10, 10 * 10) - local sm1 = m1[{{}, 4, {}}] - local sm2 = m2[{{}, 4}] - local res1 = torch.cdiv(sm1, sm2) - local res2 = res1:clone():zero() - for i = 1,sm1:size(1) do - for j = 1, sm1:size(2) do - local idx1d = (((i-1)*sm1:size(1)))+j - res2[i][j] = sm1[i][j] / sm2[idx1d] - end - end - local err = res1:clone():zero() - -- find absolute error - for i = 1, res1:size(1) do - for j = 1, res1:size(2) do - err[i][j] = math.abs(res1[i][j] - res2[i][j]) - end - end - -- find maximum element of error - local maxerr = 0 - for i = 1, err:size(1) do - for j = 1, err:size(2) do - if err[i][j] > maxerr then - maxerr = err[i][j] - end - end +function torchtest.cdiv() + local types = { + 'torch.ByteTensor', + 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.LongTensor', + } + + 
for k,t in ipairs(types) do + + -- [res] torch.cdiv([res,] tensor1, tensor2) + -- contiguous + local m1 = torch.randn(10, 10, 10):type(t) + local m2 = torch.randn(10, 10 * 10):type(t) + m2[m2:eq(0)] = 2 + local sm1 = m1[{4, {}, {}}] + local sm2 = m2[{4, {}}] + local res1 = torch.cdiv(sm1, sm2) + local res2 = res1:clone():zero() + for i = 1,sm1:size(1) do + for j = 1, sm1:size(2) do + local idx1d = (((i-1)*sm1:size(1)))+j + res2[i][j] = sm1[i][j] / sm2[idx1d] + end + end + local err = res1:clone():zero() + -- find absolute error + for i = 1, res1:size(1) do + for j = 1, res1:size(2) do + err[i][j] = math.abs(res1[i][j] - res2[i][j]) + end + end + -- find maximum element of error + local maxerr = 0 + for i = 1, err:size(1) do + for j = 1, err:size(2) do + if err[i][j] > maxerr then + maxerr = err[i][j] + end + end + end + mytester:assertlt(maxerr, precision, 'error in torch.cdiv - contiguous' .. ' ' .. t) + + -- non-contiguous + local m1 = torch.randn(10, 10, 10):type(t) + local m2 = torch.randn(10 * 10, 10 * 10):type(t) + m2[m2:eq(0)] = 2 + local sm1 = m1[{{}, 4, {}}] + local sm2 = m2[{{}, 4}] + local res1 = torch.cdiv(sm1, sm2) + local res2 = res1:clone():zero() + for i = 1,sm1:size(1) do + for j = 1, sm1:size(2) do + local idx1d = (((i-1)*sm1:size(1)))+j + res2[i][j] = sm1[i][j] / sm2[idx1d] + end + end + local err = res1:clone():zero() + -- find absolute error + for i = 1, res1:size(1) do + for j = 1, res1:size(2) do + err[i][j] = math.abs(res1[i][j] - res2[i][j]) + end + end + -- find maximum element of error + local maxerr = 0 + for i = 1, err:size(1) do + for j = 1, err:size(2) do + if err[i][j] > maxerr then + maxerr = err[i][j] + end + end + end + mytester:assertlt(maxerr, precision, 'error in torch.cdiv - non-contiguous' .. ' ' .. 
t) end - mytester:assertlt(maxerr, precision, 'error in torch.cdiv - non-contiguous') end function torchtest.cfmod() @@ -1413,68 +1466,82 @@ function torchtest.cremainder() mytester:assertlt(maxerr, precision, 'error in torch.cremainder - non-contiguous') end -function torchtest.cmul() -- [res] torch.cmul([res,] tensor1, tensor2) - -- contiguous - local m1 = torch.randn(10, 10, 10) - local m2 = torch.randn(10, 10 * 10) - local sm1 = m1[{4, {}, {}}] - local sm2 = m2[{4, {}}] - local res1 = torch.cmul(sm1, sm2) - local res2 = res1:clone():zero() - for i = 1,sm1:size(1) do - for j = 1, sm1:size(2) do - local idx1d = (((i-1)*sm1:size(1)))+j - res2[i][j] = sm1[i][j] * sm2[idx1d] - end - end - local err = res1:clone():zero() - -- find absolute error - for i = 1, res1:size(1) do - for j = 1, res1:size(2) do - err[i][j] = math.abs(res1[i][j] - res2[i][j]) - end - end - -- find maximum element of error - local maxerr = 0 - for i = 1, err:size(1) do - for j = 1, err:size(2) do - if err[i][j] > maxerr then - maxerr = err[i][j] - end - end - end - mytester:assertlt(maxerr, precision, 'error in torch.cmul - contiguous') - - -- non-contiguous - local m1 = torch.randn(10, 10, 10) - local m2 = torch.randn(10 * 10, 10 * 10) - local sm1 = m1[{{}, 4, {}}] - local sm2 = m2[{{}, 4}] - local res1 = torch.cmul(sm1, sm2) - local res2 = res1:clone():zero() - for i = 1,sm1:size(1) do - for j = 1, sm1:size(2) do - local idx1d = (((i-1)*sm1:size(1)))+j - res2[i][j] = sm1[i][j] * sm2[idx1d] - end - end - local err = res1:clone():zero() - -- find absolute error - for i = 1, res1:size(1) do - for j = 1, res1:size(2) do - err[i][j] = math.abs(res1[i][j] - res2[i][j]) - end - end - -- find maximum element of error - local maxerr = 0 - for i = 1, err:size(1) do - for j = 1, err:size(2) do - if err[i][j] > maxerr then - maxerr = err[i][j] - end - end - end - mytester:assertlt(maxerr, precision, 'error in torch.cmul - non-contiguous') +function torchtest.cmul() + local types = { + 'torch.ByteTensor', 
+ 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.LongTensor', + } + + for k,t in ipairs(types) do + + -- [res] torch.cmul([res,] tensor1, tensor2) + -- contiguous + local m1 = torch.randn(10, 10, 10):type(t) + local m2 = torch.randn(10, 10 * 10):type(t) + local sm1 = m1[{4, {}, {}}] + local sm2 = m2[{4, {}}] + local res1 = torch.cmul(sm1, sm2) + local res2 = res1:clone():zero() + for i = 1,sm1:size(1) do + for j = 1, sm1:size(2) do + local idx1d = (((i-1)*sm1:size(1)))+j + res2[i][j] = sm1[i][j] * sm2[idx1d] + end + end + local err = res1:clone():zero() + -- find absolute error + for i = 1, res1:size(1) do + for j = 1, res1:size(2) do + err[i][j] = math.abs(res1[i][j] - res2[i][j]) + end + end + -- find maximum element of error + local maxerr = 0 + for i = 1, err:size(1) do + for j = 1, err:size(2) do + if err[i][j] > maxerr then + maxerr = err[i][j] + end + end + end + mytester:assertlt(maxerr, precision, 'error in torch.cmul - contiguous' .. ' ' .. t) + + -- non-contiguous + local m1 = torch.randn(10, 10, 10):type(t) + local m2 = torch.randn(10 * 10, 10 * 10):type(t) + local sm1 = m1[{{}, 4, {}}] + local sm2 = m2[{{}, 4}] + local res1 = torch.cmul(sm1, sm2) + local res2 = res1:clone():zero() + for i = 1,sm1:size(1) do + for j = 1, sm1:size(2) do + local idx1d = (((i-1)*sm1:size(1)))+j + res2[i][j] = sm1[i][j] * sm2[idx1d] + end + end + local err = res1:clone():zero() + -- find absolute error + for i = 1, res1:size(1) do + for j = 1, res1:size(2) do + err[i][j] = math.abs(res1[i][j] - res2[i][j]) + end + end + -- find maximum element of error + local maxerr = 0 + for i = 1, err:size(1) do + for j = 1, err:size(2) do + if err[i][j] > maxerr then + maxerr = err[i][j] + end + end + end + mytester:assertlt(maxerr, precision, 'error in torch.cmul - non-contiguous' .. ' ' .. t) + end end function torchtest.cpow() -- [res] torch.cpow([res,] tensor1, tensor2) |