Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcph <stegben.benjamin@gmail.com>2017-06-21 13:34:26 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-22 20:20:25 +0300
commitfd4780703227d2a00d708fdd1bc19b6c92e8e4d7 (patch)
tree32a6bf455370bb2d60171be4e86a376369995a5d
parent9cffa0ef9c7775093896606e7f86c206e8099ce8 (diff)
add asserts to BCECriterion
-rw-r--r--lib/THCUNN/BCECriterion.cu10
1 files changed, 6 insertions, 4 deletions
diff --git a/lib/THCUNN/BCECriterion.cu b/lib/THCUNN/BCECriterion.cu
index 3653fc8..04218dc 100644
--- a/lib/THCUNN/BCECriterion.cu
+++ b/lib/THCUNN/BCECriterion.cu
@@ -25,9 +25,10 @@ struct bce_functor
__host__ __device__
Acctype operator()(Tuple x)
{
- Dtype o = thrust::get<0>(x);
+ Dtype input = thrust::get<0>(x);
Dtype t = thrust::get<1>(x);
- return - (t * THCNumerics<Acctype>::log(o + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - o + eps<Acctype>()));
+ assert(input >= 0. && input <= 1.);
+ return - (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1)- t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
}
};
@@ -38,10 +39,11 @@ struct bce_functor_weights
__host__ __device__
Acctype operator()(Tuple x)
{
- Dtype o = thrust::get<0>(x);
+ Dtype input = thrust::get<0>(x);
Dtype t = thrust::get<1>(x);
Dtype w = thrust::get<2>(x);
- return - w * (t * THCNumerics<Acctype>::log(o + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - o + eps<Acctype>()));
+ assert(input >= 0. && input <= 1.);
+ return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
}
};