diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2017-05-23 22:42:29 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-05-25 23:49:22 +0300 |
commit | 4150c664c980e03a29eaad685e1d860f92888c00 (patch) | |
tree | d42ddb5c6fdd91ac8a94fef4017bca8e37d5ddc7 /lib | |
parent | ed5aa2dceedf2f75c90cde637befe2e0a60e367d (diff) |
Add scatterAdd
Diffstat (limited to 'lib')
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 25 | ||||
-rw-r--r-- | lib/TH/generic/THTensorMath.h | 1 |
2 files changed, 26 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 1dc1bc7..706b1d0 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -466,6 +466,31 @@ void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor }) } +void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src) +{ + long elems_per_row, i, idx; + + THArgCheck(dim < THTensor_(nDimension)(tensor), 2, "Index dimension is out of bounds"); + THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(tensor), 3, + "Index tensor must have same dimensions as output tensor"); + THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 4, + "Input tensor must have same dimensions as output tensor"); + + elems_per_row = THLongTensor_size(index, dim); + + TH_TENSOR_DIM_APPLY3(real, tensor, real, src, long, index, dim, + for (i = 0; i < elems_per_row; ++i) + { + idx = *(index_data + i*index_stride); + if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE) + { + THFree(TH_TENSOR_DIM_APPLY_counter); + THError("Invalid index in scatterAdd"); + } + tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] += *(src_data + i*src_stride); + }) +} + void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, real val) { long elems_per_row, i, idx; diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h index a3cf410..bacc9df 100644 --- a/lib/TH/generic/THTensorMath.h +++ b/lib/TH/generic/THTensorMath.h @@ -18,6 +18,7 @@ TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, TH_API void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index); TH_API void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src); +TH_API void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src); TH_API void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, real val); TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src); |