Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Jongsoo Park <jongsoo@fb.com> 2020-03-23 09:27:03 +0300
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com> 2020-03-23 09:28:30 +0300
commit 58c002d1593f32aa420ab56b5c344e60d3fb6d05 (patch)
tree b17a2342f667a0d88778259a506edeeb05f34c72
parent 38ee061060934e23970f7c978adb6caeac2a8bc2 (diff)
clamping with 1 comparison trick by treating signed as if it's unsigned (#324)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/324

See more details in https://stackoverflow.com/a/34072155

Reviewed By: jianyuh

Differential Revision: D20588023

fbshipit-source-id: 9af9d72c03606aee753076e9f5a5fb06aed9b323
-rw-r--r-- src/EmbeddingSpMDM.cc 14
-rw-r--r-- src/EmbeddingSpMDMNBit.cc 14
-rw-r--r-- src/RowWiseSparseAdagradFused.cc 14
3 files changed, 15 insertions, 27 deletions
diff --git a/src/EmbeddingSpMDM.cc b/src/EmbeddingSpMDM.cc
index 4f7826d..e7c89c7 100644
--- a/src/EmbeddingSpMDM.cc
+++ b/src/EmbeddingSpMDM.cc
@@ -445,10 +445,11 @@ typename ReturnFunctionSignature<inType, indxType, ROWWISE_SPARSE>::
} else {
a->mov(scratchReg1_.r32(), x86::dword_ptr(indices));
}
- a->cmp(scratchReg1_, 0);
- a->jl(error);
+ // A trick to check x >= data_size or x < 0 in one shot by treating
+ // scratchReg1_ as if it has unsigned value
+ // (https://stackoverflow.com/a/34072155).
a->cmp(scratchReg1_, data_size);
- a->jge(error);
+ a->jae(error);
if (ROWWISE_SPARSE) {
a->mov(
@@ -484,13 +485,8 @@ typename ReturnFunctionSignature<inType, indxType, ROWWISE_SPARSE>::
x86::dword_ptr(indices, pref_dist * sizeof(indxType)));
}
- a->cmp(scratchReg2_, 0);
- a->jl(pref_dist_reset_start);
a->cmp(scratchReg2_, data_size);
- a->jge(pref_dist_reset_start);
-
- // everything is okay, prefetch a few rows ahead
- a->jmp(pref_dist_reset_end);
+ a->jb(pref_dist_reset_end);
a->bind(pref_dist_reset_start);
// things are not okay just get the current row
diff --git a/src/EmbeddingSpMDMNBit.cc b/src/EmbeddingSpMDMNBit.cc
index c4f0cc7..8e94404 100644
--- a/src/EmbeddingSpMDMNBit.cc
+++ b/src/EmbeddingSpMDMNBit.cc
@@ -487,10 +487,11 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(
} else {
a->mov(scratchReg1_.r32(), x86::dword_ptr(indices));
}
- a->cmp(scratchReg1_, 0);
- a->jl(error);
+ // A trick to check x >= data_size or x < 0 in one shot by treating
+ // scratchReg1_ as if it has unsigned value
+ // (https://stackoverflow.com/a/34072155).
a->cmp(scratchReg1_, data_size);
- a->jge(error);
+ a->jae(error);
if (ROWWISE_SPARSE) {
a->mov(
@@ -525,13 +526,8 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(
x86::dword_ptr(indices, pref_dist * sizeof(indxType)));
}
- a->cmp(scratchReg2_, 0);
- a->jl(pref_dist_reset_start);
a->cmp(scratchReg2_, data_size);
- a->jge(pref_dist_reset_start);
-
- // everything is okay, prefetch a few rows ahead
- a->jmp(pref_dist_reset_end);
+ a->jb(pref_dist_reset_end);
a->bind(pref_dist_reset_start);
// things are not okay just get the current row
diff --git a/src/RowWiseSparseAdagradFused.cc b/src/RowWiseSparseAdagradFused.cc
index 0ca8cea..6adc7ab 100644
--- a/src/RowWiseSparseAdagradFused.cc
+++ b/src/RowWiseSparseAdagradFused.cc
@@ -343,10 +343,11 @@ GenRowWiseSparseAdagradFused<indxType, instSet>::getOrCreate(
} else {
a->mov(scratchReg1.r32(), x86::dword_ptr(indices));
}
- a->cmp(scratchReg1, 0);
- a->jl(error);
+ // A trick to check x >= data_size or x < 0 in one shot by treating
+ // scratchReg1 as if it has unsigned value
+ // (https://stackoverflow.com/a/34072155).
a->cmp(scratchReg1, data_size);
- a->jge(error);
+ a->jae(error);
if (prefetch) {
asmjit::Label pref_dist_reset_start = a->newLabel();
@@ -369,13 +370,8 @@ GenRowWiseSparseAdagradFused<indxType, instSet>::getOrCreate(
x86::dword_ptr(indices, prefetch * sizeof(indxType)));
}
- a->cmp(scratchReg2, 0);
- a->jl(pref_dist_reset_start);
a->cmp(scratchReg2, data_size);
- a->jge(pref_dist_reset_start);
-
- // everything is okay, prefetch a few rows ahead
- a->jmp(pref_dist_reset_end);
+ a->jb(pref_dist_reset_end);
a->bind(pref_dist_reset_start);
// things are not okay just get the current row