diff options
author | Francisco Massa <fvsmassa@gmail.com> | 2017-08-12 16:11:22 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-08-17 00:23:28 +0300 |
commit | 32a9708dc2edb94ed73f6f5fa72c3b600ee598a8 (patch) | |
tree | 221dc7874f1b32243428b59ecec32ad5eac5767c /lib | |
parent | cec53c34e71dbbbb53574f6371c6d7e4a9f6757b (diff) |
Add CUDA version of eye
Diffstat (limited to 'lib')
-rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 22 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.h | 1 |
2 files changed, 23 insertions, 0 deletions
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index d47102c..b9d1412 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -376,6 +376,28 @@ void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, long k THCudaCheck(cudaGetLastError()); } +void THCTensor_(eye)(THCState *state, THCTensor *self_, long n, long m) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_)); + THArgCheck(n > 0, 1, "invalid argument"); + + if(m <= 0) + m = n; + + THCTensor_(resize2d)(state, self_, n, m); + THCTensor_(zero)(state, self_); + + long sz = THMin(n, m); + long stride = THCTensor_(stride)(state, self_, 0) + + THCTensor_(stride)(state, self_, 1); + + THCTensor *diag = THCTensor_(newWithStorage1d)(state, self_->storage, + self_->storageOffset, sz, stride); + + THCTensor_(fill)(state, diag, ScalarConvert<int, real>::to(1)); + THCTensor_(free)(state, diag); +} + accreal THCTensor_(trace)(THCState *state, THCTensor *src_) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, src_)); THArgCheck((src_->nDimension == 2), 1, "expected a matrix"); diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index c8fb35b..7b83d02 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -16,6 +16,7 @@ THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCT THC_API void THCTensor_(tril)(THCState *state, THCTensor *self, THCTensor *src, long k); THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, long k); THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, long k); +THC_API void THCTensor_(eye)(THCState *state, THCTensor *self, long n, long m); THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self); #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) |