Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensors/gpu/prod.cpp')
-rwxr-xr-xsrc/tensors/gpu/prod.cpp6
1 files changed, 5 insertions, 1 deletions
diff --git a/src/tensors/gpu/prod.cpp b/src/tensors/gpu/prod.cpp
index bf0d2395..c72af4db 100755
--- a/src/tensors/gpu/prod.cpp
+++ b/src/tensors/gpu/prod.cpp
@@ -562,7 +562,11 @@ void ProdBatchedLegacy(marian::Tensor C,
ProdBatchedTypedLegacy<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
#if COMPILE_FP16
} else if(C->type() == Type::float16) { // not a *.cu file
- ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+ // we use computeType=float here for fp16 training as this seems more stable and roughly as fast
+ ProdBatchedTypedLegacy<half, float>(C, allocator, A, B, transA, transB, beta, scalar);
+
+ // original for reference:
+ // ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
#endif
} else {
ABORT("ProdBatchedLegacy not implemented for element type {}", C->type());