Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/FbgemmI8DepthwiseAvx2.cc')
-rw-r--r--src/FbgemmI8DepthwiseAvx2.cc36
1 files changed, 36 insertions, 0 deletions
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index 4c8d71d..7454ef4 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -2538,6 +2538,14 @@ void depthwise_3x3_pad_1(
bool fuse_relu,
int thread_id,
int num_threads) {
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
@@ -2958,6 +2966,16 @@ void depthwise_3x3x3_pad_1(
bool fuse_relu,
int thread_id,
int num_threads) {
+ if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(
+ 0 &&
+ "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
if (fuse_relu) {
depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/>(
N,
@@ -3158,6 +3176,14 @@ void depthwise_3x3_per_channel_quantization_pad_1(
bool fuse_relu,
int thread_id,
int num_threads) {
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
@@ -3524,6 +3550,16 @@ void depthwise_3x3x3_per_channel_quantization_pad_1(
bool fuse_relu,
int thread_id,
int num_threads) {
+ if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(
+ 0 &&
+ "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
if (fuse_relu) {
depthwise_3x3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
N,