Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2017-05-23 22:42:29 +0300
committerSoumith Chintala <soumith@gmail.com>2017-05-25 23:49:22 +0300
commit4150c664c980e03a29eaad685e1d860f92888c00 (patch)
treed42ddb5c6fdd91ac8a94fef4017bca8e37d5ddc7 /lib
parented5aa2dceedf2f75c90cde637befe2e0a60e367d (diff)
Add scatterAdd
Diffstat (limited to 'lib')
-rw-r--r--lib/TH/generic/THTensorMath.c25
-rw-r--r--lib/TH/generic/THTensorMath.h1
2 files changed, 26 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index 1dc1bc7..706b1d0 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -466,6 +466,31 @@ void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor
})
}
+void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
+{
+ long elems_per_row, i, idx;
+
+ THArgCheck(dim < THTensor_(nDimension)(tensor), 2, "Index dimension is out of bounds");
+ THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(tensor), 3,
+ "Index tensor must have same dimensions as output tensor");
+ THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 4,
+ "Input tensor must have same dimensions as output tensor");
+
+ elems_per_row = THLongTensor_size(index, dim);
+
+ TH_TENSOR_DIM_APPLY3(real, tensor, real, src, long, index, dim,
+ for (i = 0; i < elems_per_row; ++i)
+ {
+ idx = *(index_data + i*index_stride);
+ if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE)
+ {
+ THFree(TH_TENSOR_DIM_APPLY_counter);
+ THError("Invalid index in scatterAdd");
+ }
+ tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] += *(src_data + i*src_stride);
+ })
+}
+
void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, real val)
{
long elems_per_row, i, idx;
diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h
index a3cf410..bacc9df 100644
--- a/lib/TH/generic/THTensorMath.h
+++ b/lib/TH/generic/THTensorMath.h
@@ -18,6 +18,7 @@ TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index,
TH_API void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
TH_API void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
+TH_API void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
TH_API void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, real val);
TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);