diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-09-30 22:47:07 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-30 22:47:07 +0300 |
commit | 81c6c24717f80bc6ac9a251cac5bba61cc49d536 (patch) | |
tree | eafdb3f616e45e8a22f01f908ac622cbdb3b81b6 | |
parent | b1ce165d049ec54b38c728d227352f8d2f0d526d (diff) | |
parent | 9bf9ffa528034609d70fd48f7e939816873357cd (diff) |
Merge pull request #782 from apaszke/error_handlers
Allow changing the default error handler for all threads
-rw-r--r-- | lib/TH/THGeneral.c | 60 | ||||
-rw-r--r-- | lib/TH/THGeneral.h.in | 10 |
2 files changed, 49 insertions, 21 deletions
diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c index 4bd4c67..fb9abe2 100644 --- a/lib/TH/THGeneral.c +++ b/lib/TH/THGeneral.c @@ -16,14 +16,16 @@ #endif /* Torch Error Handling */ -static void defaultTorchErrorHandlerFunction(const char *msg, void *data) +static void defaultErrorHandlerFunction(const char *msg, void *data) { printf("$ Error: %s\n", msg); exit(-1); } -static __thread void (*torchErrorHandlerFunction)(const char *msg, void *data) = defaultTorchErrorHandlerFunction; -static __thread void *torchErrorHandlerData; +static THErrorHandlerFunction defaultErrorHandler = defaultErrorHandlerFunction; +static void *defaultErrorHandlerData; +static __thread THErrorHandlerFunction threadErrorHandler = NULL; +static __thread void *threadErrorHandlerData; void _THError(const char *file, const int line, const char *fmt, ...) { @@ -40,7 +42,10 @@ void _THError(const char *file, const int line, const char *fmt, ...) snprintf(msg + n, 2048 - n, " at %s:%d", file, line); } - (*torchErrorHandlerFunction)(msg, torchErrorHandlerData); + if (threadErrorHandler) + (*threadErrorHandler)(msg, threadErrorHandlerData); + else + (*defaultErrorHandler)(msg, defaultErrorHandlerData); } void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) { @@ -52,17 +57,23 @@ void _THAssertionFailed(const char *file, const int line, const char *exp, const _THError(file, line, "Assertion `%s' failed. %s", exp, msg); } -void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg, void *data), void *data ) +void THSetErrorHandler(THErrorHandlerFunction new_handler, void *data) +{ + threadErrorHandler = new_handler; + threadErrorHandlerData = data; +} + +void THSetDefaultErrorHandler(THErrorHandlerFunction new_handler, void *data) { - if(torchErrorHandlerFunction_) - torchErrorHandlerFunction = torchErrorHandlerFunction_; + if (new_handler) + defaultErrorHandler = new_handler; else - torchErrorHandlerFunction = defaultTorchErrorHandlerFunction; - torchErrorHandlerData = data; + defaultErrorHandler = defaultErrorHandlerFunction; + defaultErrorHandlerData = data; } /* Torch Arg Checking Handling */ -static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data) +static void defaultArgErrorHandlerFunction(int argNumber, const char *msg, void *data) { if(msg) printf("$ Invalid argument %d: %s\n", argNumber, msg); @@ -71,8 +82,10 @@ static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg, exit(-1); } -static __thread void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data) = defaultTorchArgErrorHandlerFunction; -static __thread void *torchArgErrorHandlerData; +static THArgErrorHandlerFunction defaultArgErrorHandler = defaultArgErrorHandlerFunction; +static void *defaultArgErrorHandlerData; +static __thread THArgErrorHandlerFunction threadArgErrorHandler = NULL; +static __thread void *threadArgErrorHandlerData; void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...) { @@ -90,17 +103,26 @@ void _THArgCheck(const char *file, int line, int condition, int argNumber, const snprintf(msg + n, 2048 - n, " at %s:%d", file, line); } - (*torchArgErrorHandlerFunction)(argNumber, msg, torchArgErrorHandlerData); + if (threadArgErrorHandlerData) + (*threadArgErrorHandler)(argNumber, msg, threadArgErrorHandlerData); + else + (*defaultArgErrorHandler)(argNumber, msg, defaultArgErrorHandlerData); } } -void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber, const char *msg, void *data), void *data ) +void THSetArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data) { - if(torchArgErrorHandlerFunction_) - torchArgErrorHandlerFunction = torchArgErrorHandlerFunction_; + threadArgErrorHandler = new_handler; + threadArgErrorHandlerData = data; +} + +void THSetDefaultArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data) +{ + if (new_handler) + defaultArgErrorHandler = new_handler; else - torchArgErrorHandlerFunction = defaultTorchArgErrorHandlerFunction; - torchArgErrorHandlerData = data; + defaultArgErrorHandler = defaultArgErrorHandlerFunction; + defaultArgErrorHandlerData = data; } static __thread void (*torchGCFunction)(void *data) = NULL; @@ -232,7 +254,7 @@ void* THRealloc(void *ptr, long size) { if(!ptr) return(THAlloc(size)); - + if(size == 0) { THFree(ptr); diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in index 1b68f5e..e52ba34 100644 --- a/lib/TH/THGeneral.h.in +++ b/lib/TH/THGeneral.h.in @@ -45,12 +45,18 @@ #define TH_INDEX_BASE 1 #endif +typedef void (*THErrorHandlerFunction)(const char *msg, void *data); +typedef void (*THArgErrorHandlerFunction)(int argNumber, const char *msg, void *data); + + TH_API double THLog1p(const double x); TH_API void _THError(const char *file, const int line, const char *fmt, ...); TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...); -TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg, void *data), void *data ); +TH_API void THSetErrorHandler(THErrorHandlerFunction new_handler, void *data); +TH_API void THSetDefaultErrorHandler(THErrorHandlerFunction new_handler, void *data); TH_API void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...); -TH_API void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data), void *data ); +TH_API void THSetArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data); +TH_API void THSetDefaultArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data); TH_API void* THAlloc(long size); TH_API void* THRealloc(void *ptr, long size); TH_API void THFree(void *ptr); |