diff options
Diffstat (limited to 'src/tensors/tensor_operators.h')
-rw-r--r-- | src/tensors/tensor_operators.h | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h index 6e587953..dc29bf35 100644 --- a/src/tensors/tensor_operators.h +++ b/src/tensors/tensor_operators.h @@ -41,6 +41,25 @@ DISPATCH2(CopyCast, marian::Tensor, const marian::Tensor); DISPATCH2(AddCast, marian::Tensor, const marian::Tensor); DISPATCH4(IsNaN, const Tensor, Ptr<Allocator>, bool&, bool&); +#ifdef CUDA_FOUND +namespace gpu { +bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf); +} +#endif + +namespace cpu { +bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf); +} + +static inline bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) { +#ifdef CUDA_FOUND + if(in->getBackend()->getDeviceId().type == DeviceType::gpu) + return gpu::SanitizeGradient(in, allocator, pruneNaN, clipInf); + else +#endif + return cpu::SanitizeGradient(in, allocator, pruneNaN, clipInf); +} + template <class Functor, class... Tensors> void Element(Functor functor, marian::Tensor out, Tensors... tensors) { #ifdef CUDA_FOUND |