diff options
author | Sam Gross <sgross@fb.com> | 2017-02-08 19:48:06 +0300 |
---|---|---|
committer | Sam Gross <sgross@fb.com> | 2017-02-08 20:52:50 +0300 |
commit | fc2017b01f372f7b29b3a023192f89998f2378fb (patch) | |
tree | 8d81201e74a266df4e24241b30233fb029418482 | |
parent | 68f34f6d6e7085b5b75655d4ac0a0e7bb5ac2f43 (diff) |
Add unsqueeze1d to TH
Unsqueeze inserts a singleton dimension. Unlike view, it doesn't require
the tensor to be contiguous.
-rw-r--r-- | lib/TH/generic/THTensor.c | 27 | ||||
-rw-r--r-- | lib/TH/generic/THTensor.h | 1 |
2 files changed, 28 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c index 13de6d9..52e838d 100644 --- a/lib/TH/generic/THTensor.c +++ b/lib/TH/generic/THTensor.c @@ -510,6 +510,33 @@ void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension) } } +void THTensor_(unsqueeze1d)(THTensor *self, THTensor *src, int dimension) +{ + int d; + + if(!src) + src = self; + + THArgCheck((dimension >= 0) && (dimension <= src->nDimension), 2, "dimension out of range"); + THArgCheck(src->nDimension > 0, 2, "cannot unsqueeze empty tensor"); + + THTensor_(set)(self, src); + + self->size = (long*)THRealloc(self->size, sizeof(long)*(self->nDimension+1)); + self->stride = (long*)THRealloc(self->stride, sizeof(long)*(self->nDimension+1)); + self->nDimension++; + for (d = self->nDimension-1; d > dimension; d--) { + self->size[d] = self->size[d-1]; + self->stride[d] = self->stride[d-1]; + } + if (dimension+1 < self->nDimension) { + self->stride[dimension] = self->size[dimension+1] * self->stride[dimension+1]; + } else { + self->stride[dimension] = 1; + } + self->size[dimension] = 1; +} + int THTensor_(isContiguous)(const THTensor *self) { long z = 1; diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h index 81e3cb0..bfda60e 100644 --- a/lib/TH/generic/THTensor.h +++ b/lib/TH/generic/THTensor.h @@ -101,6 +101,7 @@ TH_API void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension_, lon TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src); TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_); +TH_API void THTensor_(unsqueeze1d)(THTensor *self, THTensor *src, int dimension_); TH_API int THTensor_(isContiguous)(const THTensor *self); TH_API int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor *src); |