diff options
author | Jongsoo Park <jongsoo@fb.com> | 2020-03-23 09:27:03 +0300 |
---|---|---|
committer | Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com> | 2020-03-23 09:28:30 +0300 |
commit | 58c002d1593f32aa420ab56b5c344e60d3fb6d05 (patch) | |
tree | b17a2342f667a0d88778259a506edeeb05f34c72 | |
parent | 38ee061060934e23970f7c978adb6caeac2a8bc2 (diff) |
clamping with 1 comparison trick by treating signed as if it's unsigned (#324)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/324
See more details in https://stackoverflow.com/a/34072155
Reviewed By: jianyuh
Differential Revision: D20588023
fbshipit-source-id: 9af9d72c03606aee753076e9f5a5fb06aed9b323
-rw-r--r-- | src/EmbeddingSpMDM.cc | 14 | ||||
-rw-r--r-- | src/EmbeddingSpMDMNBit.cc | 14 | ||||
-rw-r--r-- | src/RowWiseSparseAdagradFused.cc | 14 |
3 files changed, 15 insertions, 27 deletions
diff --git a/src/EmbeddingSpMDM.cc b/src/EmbeddingSpMDM.cc
index 4f7826d..e7c89c7 100644
--- a/src/EmbeddingSpMDM.cc
+++ b/src/EmbeddingSpMDM.cc
@@ -445,10 +445,11 @@ typename ReturnFunctionSignature<inType, indxType, ROWWISE_SPARSE>::
         } else {
           a->mov(scratchReg1_.r32(), x86::dword_ptr(indices));
         }
-        a->cmp(scratchReg1_, 0);
-        a->jl(error);
+        // A trick to check x >= data_size or x < 0 in one shot by treating
+        // scratchReg1_ as if it has unsigned value
+        // (https://stackoverflow.com/a/34072155).
         a->cmp(scratchReg1_, data_size);
-        a->jge(error);
+        a->jae(error);
 
         if (ROWWISE_SPARSE) {
           a->mov(
@@ -484,13 +485,8 @@ typename ReturnFunctionSignature<inType, indxType, ROWWISE_SPARSE>::
                 x86::dword_ptr(indices, pref_dist * sizeof(indxType)));
           }
 
-          a->cmp(scratchReg2_, 0);
-          a->jl(pref_dist_reset_start);
           a->cmp(scratchReg2_, data_size);
-          a->jge(pref_dist_reset_start);
-
-          // everything is okay, prefetch a few rows ahead
-          a->jmp(pref_dist_reset_end);
+          a->jb(pref_dist_reset_end);
 
           a->bind(pref_dist_reset_start);
           // things are not okay just get the current row
diff --git a/src/EmbeddingSpMDMNBit.cc b/src/EmbeddingSpMDMNBit.cc
index c4f0cc7..8e94404 100644
--- a/src/EmbeddingSpMDMNBit.cc
+++ b/src/EmbeddingSpMDMNBit.cc
@@ -487,10 +487,11 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(
         } else {
           a->mov(scratchReg1_.r32(), x86::dword_ptr(indices));
         }
-        a->cmp(scratchReg1_, 0);
-        a->jl(error);
+        // A trick to check x >= data_size or x < 0 in one shot by treating
+        // scratchReg1_ as if it has unsigned value
+        // (https://stackoverflow.com/a/34072155).
         a->cmp(scratchReg1_, data_size);
-        a->jge(error);
+        a->jae(error);
 
         if (ROWWISE_SPARSE) {
           a->mov(
@@ -525,13 +526,8 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(
                 x86::dword_ptr(indices, pref_dist * sizeof(indxType)));
           }
 
-          a->cmp(scratchReg2_, 0);
-          a->jl(pref_dist_reset_start);
           a->cmp(scratchReg2_, data_size);
-          a->jge(pref_dist_reset_start);
-
-          // everything is okay, prefetch a few rows ahead
-          a->jmp(pref_dist_reset_end);
+          a->jb(pref_dist_reset_end);
 
           a->bind(pref_dist_reset_start);
           // things are not okay just get the current row
diff --git a/src/RowWiseSparseAdagradFused.cc b/src/RowWiseSparseAdagradFused.cc
index 0ca8cea..6adc7ab 100644
--- a/src/RowWiseSparseAdagradFused.cc
+++ b/src/RowWiseSparseAdagradFused.cc
@@ -343,10 +343,11 @@ GenRowWiseSparseAdagradFused<indxType, instSet>::getOrCreate(
         } else {
           a->mov(scratchReg1.r32(), x86::dword_ptr(indices));
         }
-        a->cmp(scratchReg1, 0);
-        a->jl(error);
+        // A trick to check x >= data_size or x < 0 in one shot by treating
+        // scratchReg1_ as if it has unsigned value
+        // (https://stackoverflow.com/a/34072155).
         a->cmp(scratchReg1, data_size);
-        a->jge(error);
+        a->jae(error);
 
         if (prefetch) {
           asmjit::Label pref_dist_reset_start = a->newLabel();
@@ -369,13 +370,8 @@ GenRowWiseSparseAdagradFused<indxType, instSet>::getOrCreate(
                 x86::dword_ptr(indices, prefetch * sizeof(indxType)));
           }
 
-          a->cmp(scratchReg2, 0);
-          a->jl(pref_dist_reset_start);
           a->cmp(scratchReg2, data_size);
-          a->jge(pref_dist_reset_start);
-
-          // everything is okay, prefetch a few rows ahead
-          a->jmp(pref_dist_reset_end);
+          a->jb(pref_dist_reset_end);
 
           a->bind(pref_dist_reset_start);
           // things are not okay just get the current row