Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensors/tensor_operators.h')
-rw-r--r--src/tensors/tensor_operators.h19
1 files changed, 19 insertions, 0 deletions
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 6e587953..dc29bf35 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -41,6 +41,25 @@ DISPATCH2(CopyCast, marian::Tensor, const marian::Tensor);
DISPATCH2(AddCast, marian::Tensor, const marian::Tensor);
DISPATCH4(IsNaN, const Tensor, Ptr<Allocator>, bool&, bool&);
+#ifdef CUDA_FOUND
+namespace gpu {
+bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf);
+}
+#endif
+
+namespace cpu {
+bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf);
+}
+
+static inline bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) {
+#ifdef CUDA_FOUND
+ if(in->getBackend()->getDeviceId().type == DeviceType::gpu)
+ return gpu::SanitizeGradient(in, allocator, pruneNaN, clipInf);
+ else
+#endif
+ return cpu::SanitizeGradient(in, allocator, pruneNaN, clipInf);
+}
+
template <class Functor, class... Tensors>
void Element(Functor functor, marian::Tensor out, Tensors... tensors) {
#ifdef CUDA_FOUND