diff options
Diffstat (limited to 'src/tensors/gpu/prod.cpp')
-rwxr-xr-x | src/tensors/gpu/prod.cpp | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/src/tensors/gpu/prod.cpp b/src/tensors/gpu/prod.cpp index bf0d2395..c72af4db 100755 --- a/src/tensors/gpu/prod.cpp +++ b/src/tensors/gpu/prod.cpp @@ -562,7 +562,11 @@ void ProdBatchedLegacy(marian::Tensor C, ProdBatchedTypedLegacy<float, float>(C, allocator, A, B, transA, transB, beta, scalar); #if COMPILE_FP16 } else if(C->type() == Type::float16) { // not a *.cu file - ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar)); + // we use computeType=float here for fp16 training as this seems more stable and roughly as fast + ProdBatchedTypedLegacy<half, float>(C, allocator, A, B, transA, transB, beta, scalar); + + // original for reference: + // ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar)); #endif } else { ABORT("ProdBatchedLegacy not implemented for element type {}", C->type()); |