Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorFrancisco Massa <fvsmassa@gmail.com>2017-08-12 16:11:22 +0300
committerSoumith Chintala <soumith@gmail.com>2017-08-17 00:23:28 +0300
commit32a9708dc2edb94ed73f6f5fa72c3b600ee598a8 (patch)
tree221dc7874f1b32243428b59ecec32ad5eac5767c /lib
parentcec53c34e71dbbbb53574f6371c6d7e4a9f6757b (diff)
Add CUDA version of eye
Diffstat (limited to 'lib')
-rw-r--r--lib/THC/generic/THCTensorMath.cu22
-rw-r--r--lib/THC/generic/THCTensorMath.h1
2 files changed, 23 insertions, 0 deletions
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu
index d47102c..b9d1412 100644
--- a/lib/THC/generic/THCTensorMath.cu
+++ b/lib/THC/generic/THCTensorMath.cu
@@ -376,6 +376,28 @@ void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, long k
THCudaCheck(cudaGetLastError());
}
+void THCTensor_(eye)(THCState *state, THCTensor *self_, long n, long m)
+{
+ THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
+ THArgCheck(n > 0, 1, "invalid argument");
+
+ if(m <= 0)
+ m = n;
+
+ THCTensor_(resize2d)(state, self_, n, m);
+ THCTensor_(zero)(state, self_);
+
+ long sz = THMin(n, m);
+ long stride = THCTensor_(stride)(state, self_, 0) +
+ THCTensor_(stride)(state, self_, 1);
+
+ THCTensor *diag = THCTensor_(newWithStorage1d)(state, self_->storage,
+ self_->storageOffset, sz, stride);
+
+ THCTensor_(fill)(state, diag, ScalarConvert<int, real>::to(1));
+ THCTensor_(free)(state, diag);
+}
+
accreal THCTensor_(trace)(THCState *state, THCTensor *src_) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, src_));
THArgCheck((src_->nDimension == 2), 1, "expected a matrix");
diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h
index c8fb35b..7b83d02 100644
--- a/lib/THC/generic/THCTensorMath.h
+++ b/lib/THC/generic/THCTensorMath.h
@@ -16,6 +16,7 @@ THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCT
THC_API void THCTensor_(tril)(THCState *state, THCTensor *self, THCTensor *src, long k);
THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, long k);
THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, long k);
+THC_API void THCTensor_(eye)(THCState *state, THCTensor *self, long n, long m);
THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self);
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)