diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-06-20 04:49:17 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-22 19:38:10 +0300 |
commit | 3864f3782a74715e35d98176f9014a74524a868a (patch) | |
tree | dd49dc9ae1950ceab4cb579a26225bc8cc10191a | |
parent | 6ec27dc9d941d90eea421e1bb1d0e953d68279b6 (diff) |
improving TH error messages in Apply macros
-rw-r--r-- | lib/TH/THGeneral.c | 23 | ||||
-rw-r--r-- | lib/TH/THGeneral.h.in | 6 | ||||
-rw-r--r-- | lib/TH/THStorage.c | 20 | ||||
-rw-r--r-- | lib/TH/THStorage.h | 5 | ||||
-rw-r--r-- | lib/TH/THTensorApply.h | 32 | ||||
-rw-r--r-- | lib/TH/THTensorDimApply.h | 52 |
6 files changed, 95 insertions, 43 deletions
diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c index d44e762..ac032b9 100644 --- a/lib/TH/THGeneral.c +++ b/lib/TH/THGeneral.c @@ -358,3 +358,26 @@ TH_API void THInferNumThreads(void) omp_set_num_threads(mkl_get_max_threads()); #endif } + +TH_API THDescBuff _THSizeDesc(const long *size, const long ndim) { + const int L = TH_DESC_BUFF_LEN; + THDescBuff buf; + char *str = buf.str; + int n = 0; + n += snprintf(str, L-n, "["); + int i; + for(i = 0; i < ndim; i++) { + if(n >= L) break; + n += snprintf(str+n, L-n, "%ld", size[i]); + if(i < ndim-1) { + n += snprintf(str+n, L-n, " x "); + } + } + if(n < L - 2) { + snprintf(str+n, L-n, "]"); + } else { + snprintf(str+L-5, 5, "...]"); + } + return buf; +} + diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in index 0621c7a..88a3934 100644 --- a/lib/TH/THGeneral.h.in +++ b/lib/TH/THGeneral.h.in @@ -42,8 +42,14 @@ typedef void (*THErrorHandlerFunction)(const char *msg, void *data); typedef void (*THArgErrorHandlerFunction)(int argNumber, const char *msg, void *data); +#define TH_DESC_BUFF_LEN 64 +typedef struct { + char str[TH_DESC_BUFF_LEN]; +} THDescBuff; + TH_API double THLog1p(const double x); +TH_API THDescBuff _THSizeDesc(const long *size, const long ndim); TH_API void _THError(const char *file, const int line, const char *fmt, ...); TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...); TH_API void THSetErrorHandler(THErrorHandlerFunction new_handler, void *data); diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c index f850600..f6b63f4 100644 --- a/lib/TH/THStorage.c +++ b/lib/TH/THStorage.c @@ -15,25 +15,7 @@ THDescBuff THLongStorage_sizeDesc(const THLongStorage *size) { - const int L = TH_DESC_BUFF_LEN; - THDescBuff buf; - char *str = buf.str; - int n = 0; - n += snprintf(str, L-n, "["); - int i; - for(i = 0; i < size->size; i++) { - if(n >= L) break; - n += snprintf(str+n, L-n, "%ld", size->data[i]); - if(i < size->size-1) { - n += snprintf(str+n, L-n, " x "); - } - } - if(n < L - 2) { - snprintf(str+n, L-n, "]"); - } else { - snprintf(str+L-5, 5, "...]"); - } - return buf; + return _THSizeDesc(size->data, size->size); } THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement) diff --git a/lib/TH/THStorage.h b/lib/TH/THStorage.h index e141e4c..fb7946b 100644 --- a/lib/TH/THStorage.h +++ b/lib/TH/THStorage.h @@ -7,11 +7,6 @@ #define THStorage TH_CONCAT_3(TH,Real,Storage) #define THStorage_(NAME) TH_CONCAT_4(TH,Real,Storage_,NAME) -#define TH_DESC_BUFF_LEN 64 -typedef struct { - char str[TH_DESC_BUFF_LEN]; -} THDescBuff; - /* fast access methods */ #define TH_STORAGE_GET(storage, idx) ((storage)->data[(idx)]) #define TH_STORAGE_SET(storage, idx, value) ((storage)->data[(idx)] = (value)) diff --git a/lib/TH/THTensorApply.h b/lib/TH/THTensorApply.h index 86672b9..7f48da4 100644 --- a/lib/TH/THTensorApply.h +++ b/lib/TH/THTensorApply.h @@ -141,10 +141,24 @@ __TH_TENSOR_APPLYX_PREAMBLE(TYPE1, TENSOR1, DIM, 1) \ __TH_TENSOR_APPLYX_PREAMBLE(TYPE2, TENSOR2, DIM, 1) \ __TH_TENSOR_APPLYX_PREAMBLE(TYPE3, TENSOR3, DIM, 1) \ -\ - if(TENSOR1##_n != TENSOR2##_n || TENSOR1##_n != TENSOR3##_n) /* should we do the check in the function instead? i think so */ \ - THError("inconsistent tensor size"); \ -\ + \ + int elements_equal = 1; \ + if(TENSOR1##_n != TENSOR2##_n) { \ + elements_equal = 0; \ + } \ + else if(TENSOR1##_n != TENSOR3##_n) { \ + elements_equal = 0; \ + } \ + if (elements_equal == 0) { \ + THDescBuff T1buff = _THSizeDesc(TENSOR1->size, TENSOR1->nDimension); \ + THDescBuff T2buff = _THSizeDesc(TENSOR2->size, TENSOR2->nDimension); \ + THDescBuff T3buff = _THSizeDesc(TENSOR3->size, TENSOR3->nDimension); \ + THError("inconsistent tensor size, expected %s %s, %s %s and %s %s to have the same " \ + "number of elements, but got %d, %d and %d elements respectively", \ + #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, #TENSOR3, T3buff.str, \ + TENSOR1##_n, TENSOR2##_n, TENSOR3##_n); \ + } \ + \ while(!TH_TENSOR_APPLY_hasFinished) \ { \ /* Loop through the inner most region of the Tensor */ \ @@ -174,9 +188,13 @@ __TH_TENSOR_APPLYX_PREAMBLE(TYPE1, TENSOR1, DIM, 1) \ __TH_TENSOR_APPLYX_PREAMBLE(TYPE2, TENSOR2, DIM, 1) \ \ - if(TENSOR1##_n != TENSOR2##_n) /* should we do the check in the function instead? i think so */ \ - THError("inconsistent tensor size"); \ -\ + if(TENSOR1##_n != TENSOR2##_n) { \ + THDescBuff T1buff = _THSizeDesc(TENSOR1->size, TENSOR1->nDimension); \ + THDescBuff T2buff = _THSizeDesc(TENSOR2->size, TENSOR2->nDimension); \ + THError("inconsistent tensor size, expected %s %s and %s %s to have the same " \ + "number of elements, but got %d and %d elements respectively", \ + #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, TENSOR1##_n, TENSOR2##_n); \ + } \ while(!TH_TENSOR_APPLY_hasFinished) \ { \ /* Loop through the inner most region of the Tensor */ \ diff --git a/lib/TH/THTensorDimApply.h b/lib/TH/THTensorDimApply.h index df333fa..6727e1f 100644 --- a/lib/TH/THTensorDimApply.h +++ b/lib/TH/THTensorDimApply.h @@ -14,19 +14,38 @@ int TH_TENSOR_DIM_APPLY_i; \ \ if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->nDimension) ) \ - THError("invalid dimension"); \ - if( TENSOR1->nDimension != TENSOR2->nDimension ) \ - THError("inconsistent tensor sizes"); \ - if( TENSOR1->nDimension != TENSOR3->nDimension ) \ - THError("inconsistent tensor sizes"); \ + THError("invalid dimension %d (expected to be 0 <= dim < %d)", DIMENSION, TENSOR1->nDimension); \ + int same_dims = 1; \ + if( TENSOR1->nDimension != TENSOR2->nDimension ) { \ + same_dims = 0; \ + } \ + if( TENSOR1->nDimension != TENSOR3->nDimension ) { \ + same_dims = 0; \ + } \ + if (same_dims == 0) { \ + THDescBuff T1buff = _THSizeDesc(TENSOR1->size, TENSOR1->nDimension); \ + THDescBuff T2buff = _THSizeDesc(TENSOR2->size, TENSOR2->nDimension); \ + THDescBuff T3buff = _THSizeDesc(TENSOR3->size, TENSOR3->nDimension); \ + THError("inconsistent tensor size, expected %s %s, %s %s and %s %s to have the same " \ + "number of dimensions", #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, #TENSOR3, T3buff.str); \ + } \ + int shape_check_flag = 0; \ for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ { \ if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ continue; \ if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) \ - THError("inconsistent tensor sizes"); \ + shape_check_flag = 1; \ if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR3->size[TH_TENSOR_DIM_APPLY_i]) \ - THError("inconsistent tensor sizes"); \ + shape_check_flag = 1; \ + } \ + \ + if (shape_check_flag == 1) { \ + THDescBuff T1buff = _THSizeDesc(TENSOR1->size, TENSOR1->nDimension); \ + THDescBuff T2buff = _THSizeDesc(TENSOR2->size, TENSOR2->nDimension); \ + THDescBuff T3buff = _THSizeDesc(TENSOR3->size, TENSOR3->nDimension); \ + THError("Expected %s %s, %s %s and %s %s to have the same size in dimension %d", \ + #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, #TENSOR3, T3buff.str, DIMENSION); \ } \ \ TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR1->nDimension)); \ @@ -119,15 +138,24 @@ int TH_TENSOR_DIM_APPLY_i; \ \ if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->nDimension) ) \ - THError("invalid dimension"); \ - if( TENSOR1->nDimension != TENSOR2->nDimension ) \ - THError("inconsistent tensor sizes"); \ + THError("invalid dimension %d (expected to be 0 <= dim < %d)", DIMENSION, TENSOR1->nDimension); \ + if( TENSOR1->nDimension != TENSOR2->nDimension ) { \ + THDescBuff T1buff = _THSizeDesc(TENSOR1->size, TENSOR1->nDimension); \ + THDescBuff T2buff = _THSizeDesc(TENSOR2->size, TENSOR2->nDimension); \ + THError("inconsistent tensor size, expected %s %s and %s %s to have the same " \ + "number of dimensions", #TENSOR1, T1buff.str, #TENSOR2, T2buff.str); \ + } \ + int shape_check_flag = 0; \ for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ { \ if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ continue; \ - if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) \ - THError("inconsistent tensor sizes"); \ + if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) { \ + THDescBuff T1buff = _THSizeDesc(TENSOR1->size, TENSOR1->nDimension); \ + THDescBuff T2buff = _THSizeDesc(TENSOR2->size, TENSOR2->nDimension); \ + THError("Expected %s %s and %s %s to have the same size in dimension %d", \ + #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, DIMENSION); \ + } \ } \ \ TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR1->nDimension)); \ |