diff options
author | cph <stegben.benjamin@gmail.com> | 2017-06-21 13:34:26 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-22 20:20:25 +0300 |
commit | fd4780703227d2a00d708fdd1bc19b6c92e8e4d7 (patch) | |
tree | 32a6bf455370bb2d60171be4e86a376369995a5d | |
parent | 9cffa0ef9c7775093896606e7f86c206e8099ce8 (diff) |
add asserts to BCECriterion
-rw-r--r-- | lib/THCUNN/BCECriterion.cu | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/lib/THCUNN/BCECriterion.cu b/lib/THCUNN/BCECriterion.cu index 3653fc8..04218dc 100644 --- a/lib/THCUNN/BCECriterion.cu +++ b/lib/THCUNN/BCECriterion.cu @@ -25,9 +25,10 @@ struct bce_functor __host__ __device__ Acctype operator()(Tuple x) { - Dtype o = thrust::get<0>(x); + Dtype input = thrust::get<0>(x); Dtype t = thrust::get<1>(x); - return - (t * THCNumerics<Acctype>::log(o + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - o + eps<Acctype>())); + assert(input >= 0. && input <= 1.); + return - (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>())); } }; @@ -38,10 +39,11 @@ struct bce_functor_weights __host__ __device__ Acctype operator()(Tuple x) { - Dtype o = thrust::get<0>(x); + Dtype input = thrust::get<0>(x); Dtype t = thrust::get<1>(x); Dtype w = thrust::get<2>(x); - return - w * (t * THCNumerics<Acctype>::log(o + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - o + eps<Acctype>())); + assert(input >= 0. && input <= 1.); + return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>())); } }; |