Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/THCUNN/BCECriterion.cu10
1 files changed, 6 insertions, 4 deletions
diff --git a/lib/THCUNN/BCECriterion.cu b/lib/THCUNN/BCECriterion.cu
index 3653fc8..04218dc 100644
--- a/lib/THCUNN/BCECriterion.cu
+++ b/lib/THCUNN/BCECriterion.cu
@@ -25,9 +25,10 @@ struct bce_functor
__host__ __device__
Acctype operator()(Tuple x)
{
- Dtype o = thrust::get<0>(x);
+ Dtype input = thrust::get<0>(x);
Dtype t = thrust::get<1>(x);
- return - (t * THCNumerics<Acctype>::log(o + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - o + eps<Acctype>()));
+ assert(input >= 0. && input <= 1.);
+ return - (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
}
};
@@ -38,10 +39,11 @@ struct bce_functor_weights
__host__ __device__
Acctype operator()(Tuple x)
{
- Dtype o = thrust::get<0>(x);
+ Dtype input = thrust::get<0>(x);
Dtype t = thrust::get<1>(x);
Dtype w = thrust::get<2>(x);
- return - w * (t * THCNumerics<Acctype>::log(o + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - o + eps<Acctype>()));
+ assert(input >= 0. && input <= 1.);
+ return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
}
};