Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensors/cpu/tensor_operators.cpp')
-rwxr-xr-xsrc/tensors/cpu/tensor_operators.cpp13
1 files changed, 12 insertions, 1 deletions
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index 1afb8f64..1e1adc38 100755
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -24,6 +24,10 @@ void IsNaN(const Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool& /*isNaN*/, b
ABORT("Not implemented");
}
+bool SanitizeGradient(marian::Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool /*pruneNaN*/, bool /*clipInf*/) {
+ ABORT("Not implemented");
+}
+
template <bool add, typename To, typename From>
void CopyCastTo(To* out, const From* in, int length) {
for(int i = 0; i < length; ++i)
@@ -735,6 +739,7 @@ void Select(Tensor out,
}
}
+template <bool add>
void Insert(Tensor out,
const Tensor in,
const Tensor indices,
@@ -756,10 +761,16 @@ void Insert(Tensor out,
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex];
int outIndex = outShape.index(dims);
- out->data()[outIndex] += in->data()[index];
+ if(add)
+ out->data()[outIndex] += in->data()[index];
+ else
+ out->data()[outIndex] = in->data()[index];
}
}
+template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis);
+template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis);
+
void GRUFastForward(Tensor out_, std::vector<Tensor> inputs, bool final) {
int rows = out_->shape().elements() / out_->shape().back();
int cols = out_->shape().back();