diff options
-rw-r--r-- | lib/THCUNN/BCECriterion.cu | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/lib/THCUNN/BCECriterion.cu b/lib/THCUNN/BCECriterion.cu index 3653fc8..04218dc 100644 --- a/lib/THCUNN/BCECriterion.cu +++ b/lib/THCUNN/BCECriterion.cu @@ -25,9 +25,10 @@ struct bce_functor __host__ __device__ Acctype operator()(Tuple x) { - Dtype o = thrust::get<0>(x); + Dtype input = thrust::get<0>(x); Dtype t = thrust::get<1>(x); - return - (t * THCNumerics<Acctype>::log(o + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - o + eps<Acctype>())); + assert(input >= 0. && input <= 1.); + return - (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>())); } }; @@ -38,10 +39,11 @@ struct bce_functor_weights __host__ __device__ Acctype operator()(Tuple x) { - Dtype o = thrust::get<0>(x); + Dtype input = thrust::get<0>(x); Dtype t = thrust::get<1>(x); Dtype w = thrust::get<2>(x); - return - w * (t * THCNumerics<Acctype>::log(o + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - o + eps<Acctype>())); + assert(input >= 0. && input <= 1.); + return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>())); } }; |