Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-09-30 22:47:07 +0300
committerGitHub <noreply@github.com>2016-09-30 22:47:07 +0300
commit81c6c24717f80bc6ac9a251cac5bba61cc49d536 (patch)
treeeafdb3f616e45e8a22f01f908ac622cbdb3b81b6
parentb1ce165d049ec54b38c728d227352f8d2f0d526d (diff)
parent9bf9ffa528034609d70fd48f7e939816873357cd (diff)
Merge pull request #782 from apaszke/error_handlers
Allow changing the default error handler for all threads
-rw-r--r--lib/TH/THGeneral.c60
-rw-r--r--lib/TH/THGeneral.h.in10
2 files changed, 49 insertions, 21 deletions
diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c
index 4bd4c67..fb9abe2 100644
--- a/lib/TH/THGeneral.c
+++ b/lib/TH/THGeneral.c
@@ -16,14 +16,16 @@
#endif
/* Torch Error Handling */
-static void defaultTorchErrorHandlerFunction(const char *msg, void *data)
+static void defaultErrorHandlerFunction(const char *msg, void *data)
{
printf("$ Error: %s\n", msg);
exit(-1);
}
-static __thread void (*torchErrorHandlerFunction)(const char *msg, void *data) = defaultTorchErrorHandlerFunction;
-static __thread void *torchErrorHandlerData;
+static THErrorHandlerFunction defaultErrorHandler = defaultErrorHandlerFunction;
+static void *defaultErrorHandlerData;
+static __thread THErrorHandlerFunction threadErrorHandler = NULL;
+static __thread void *threadErrorHandlerData;
void _THError(const char *file, const int line, const char *fmt, ...)
{
@@ -40,7 +42,10 @@ void _THError(const char *file, const int line, const char *fmt, ...)
snprintf(msg + n, 2048 - n, " at %s:%d", file, line);
}
- (*torchErrorHandlerFunction)(msg, torchErrorHandlerData);
+ if (threadErrorHandler)
+ (*threadErrorHandler)(msg, threadErrorHandlerData);
+ else
+ (*defaultErrorHandler)(msg, defaultErrorHandlerData);
}
void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) {
@@ -52,17 +57,23 @@ void _THAssertionFailed(const char *file, const int line, const char *exp, const
_THError(file, line, "Assertion `%s' failed. %s", exp, msg);
}
-void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg, void *data), void *data )
+void THSetErrorHandler(THErrorHandlerFunction new_handler, void *data)
+{
+ threadErrorHandler = new_handler;
+ threadErrorHandlerData = data;
+}
+
+void THSetDefaultErrorHandler(THErrorHandlerFunction new_handler, void *data)
{
- if(torchErrorHandlerFunction_)
- torchErrorHandlerFunction = torchErrorHandlerFunction_;
+ if (new_handler)
+ defaultErrorHandler = new_handler;
else
- torchErrorHandlerFunction = defaultTorchErrorHandlerFunction;
- torchErrorHandlerData = data;
+ defaultErrorHandler = defaultErrorHandlerFunction;
+ defaultErrorHandlerData = data;
}
/* Torch Arg Checking Handling */
-static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data)
+static void defaultArgErrorHandlerFunction(int argNumber, const char *msg, void *data)
{
if(msg)
printf("$ Invalid argument %d: %s\n", argNumber, msg);
@@ -71,8 +82,10 @@ static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg,
exit(-1);
}
-static __thread void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data) = defaultTorchArgErrorHandlerFunction;
-static __thread void *torchArgErrorHandlerData;
+static THArgErrorHandlerFunction defaultArgErrorHandler = defaultArgErrorHandlerFunction;
+static void *defaultArgErrorHandlerData;
+static __thread THArgErrorHandlerFunction threadArgErrorHandler = NULL;
+static __thread void *threadArgErrorHandlerData;
void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...)
{
@@ -90,17 +103,26 @@ void _THArgCheck(const char *file, int line, int condition, int argNumber, const
snprintf(msg + n, 2048 - n, " at %s:%d", file, line);
}
- (*torchArgErrorHandlerFunction)(argNumber, msg, torchArgErrorHandlerData);
+ if (threadArgErrorHandlerData)
+ (*threadArgErrorHandler)(argNumber, msg, threadArgErrorHandlerData);
+ else
+ (*defaultArgErrorHandler)(argNumber, msg, defaultArgErrorHandlerData);
}
}
-void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber, const char *msg, void *data), void *data )
+void THSetArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data)
{
- if(torchArgErrorHandlerFunction_)
- torchArgErrorHandlerFunction = torchArgErrorHandlerFunction_;
+ threadArgErrorHandler = new_handler;
+ threadArgErrorHandlerData = data;
+}
+
+void THSetDefaultArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data)
+{
+ if (new_handler)
+ defaultArgErrorHandler = new_handler;
else
- torchArgErrorHandlerFunction = defaultTorchArgErrorHandlerFunction;
- torchArgErrorHandlerData = data;
+ defaultArgErrorHandler = defaultArgErrorHandlerFunction;
+ defaultArgErrorHandlerData = data;
}
static __thread void (*torchGCFunction)(void *data) = NULL;
@@ -232,7 +254,7 @@ void* THRealloc(void *ptr, long size)
{
if(!ptr)
return(THAlloc(size));
-
+
if(size == 0)
{
THFree(ptr);
diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in
index 1b68f5e..e52ba34 100644
--- a/lib/TH/THGeneral.h.in
+++ b/lib/TH/THGeneral.h.in
@@ -45,12 +45,18 @@
#define TH_INDEX_BASE 1
#endif
+typedef void (*THErrorHandlerFunction)(const char *msg, void *data);
+typedef void (*THArgErrorHandlerFunction)(int argNumber, const char *msg, void *data);
+
+
TH_API double THLog1p(const double x);
TH_API void _THError(const char *file, const int line, const char *fmt, ...);
TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...);
-TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg, void *data), void *data );
+TH_API void THSetErrorHandler(THErrorHandlerFunction new_handler, void *data);
+TH_API void THSetDefaultErrorHandler(THErrorHandlerFunction new_handler, void *data);
TH_API void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...);
-TH_API void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data), void *data );
+TH_API void THSetArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data);
+TH_API void THSetDefaultArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data);
TH_API void* THAlloc(long size);
TH_API void* THRealloc(void *ptr, long size);
TH_API void THFree(void *ptr);