Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2017-02-08 19:48:06 +0300
committerSam Gross <sgross@fb.com>2017-02-08 20:52:50 +0300
commitfc2017b01f372f7b29b3a023192f89998f2378fb (patch)
tree8d81201e74a266df4e24241b30233fb029418482
parent68f34f6d6e7085b5b75655d4ac0a0e7bb5ac2f43 (diff)
Add unsqueeze1d to TH
Unsqueeze inserts a singleton dimension. Unlike view, it doesn't require the tensor to be contiguous.
-rw-r--r--lib/TH/generic/THTensor.c27
-rw-r--r--lib/TH/generic/THTensor.h1
2 files changed, 28 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c
index 13de6d9..52e838d 100644
--- a/lib/TH/generic/THTensor.c
+++ b/lib/TH/generic/THTensor.c
@@ -510,6 +510,33 @@ void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension)
}
}
+void THTensor_(unsqueeze1d)(THTensor *self, THTensor *src, int dimension)
+{
+ int d;
+
+ if(!src)
+ src = self;
+
+ THArgCheck((dimension >= 0) && (dimension <= src->nDimension), 2, "dimension out of range");
+ THArgCheck(src->nDimension > 0, 2, "cannot unsqueeze empty tensor");
+
+ THTensor_(set)(self, src);
+
+ self->size = (long*)THRealloc(self->size, sizeof(long)*(self->nDimension+1));
+ self->stride = (long*)THRealloc(self->stride, sizeof(long)*(self->nDimension+1));
+ self->nDimension++;
+ for (d = self->nDimension-1; d > dimension; d--) {
+ self->size[d] = self->size[d-1];
+ self->stride[d] = self->stride[d-1];
+ }
+ if (dimension+1 < self->nDimension) {
+ self->stride[dimension] = self->size[dimension+1] * self->stride[dimension+1];
+ } else {
+ self->stride[dimension] = 1;
+ }
+ self->size[dimension] = 1;
+}
+
int THTensor_(isContiguous)(const THTensor *self)
{
long z = 1;
diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h
index 81e3cb0..bfda60e 100644
--- a/lib/TH/generic/THTensor.h
+++ b/lib/TH/generic/THTensor.h
@@ -101,6 +101,7 @@ TH_API void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension_, lon
TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src);
TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_);
+TH_API void THTensor_(unsqueeze1d)(THTensor *self, THTensor *src, int dimension_);
TH_API int THTensor_(isContiguous)(const THTensor *self);
TH_API int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor *src);