diff options
Diffstat (limited to 'src/tensors/cpu/tensor_operators.cpp')
-rwxr-xr-x | src/tensors/cpu/tensor_operators.cpp | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp index 1afb8f64..1e1adc38 100755 --- a/src/tensors/cpu/tensor_operators.cpp +++ b/src/tensors/cpu/tensor_operators.cpp @@ -24,6 +24,10 @@ void IsNaN(const Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool& /*isNaN*/, b ABORT("Not implemented"); } +bool SanitizeGradient(marian::Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool /*pruneNaN*/, bool /*clipInf*/) { + ABORT("Not implemented"); +} + template <bool add, typename To, typename From> void CopyCastTo(To* out, const From* in, int length) { for(int i = 0; i < length; ++i) @@ -735,6 +739,7 @@ void Select(Tensor out, } } +template <bool add> void Insert(Tensor out, const Tensor in, const Tensor indices, @@ -756,10 +761,16 @@ void Insert(Tensor out, int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex]; int outIndex = outShape.index(dims); - out->data()[outIndex] += in->data()[index]; + if(add) + out->data()[outIndex] += in->data()[index]; + else + out->data()[outIndex] = in->data()[index]; } } +template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis); +template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis); + void GRUFastForward(Tensor out_, std::vector<Tensor> inputs, bool final) { int rows = out_->shape().elements() / out_->shape().back(); int cols = out_->shape().back(); |