diff options
-rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 16 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.h | 2 |
2 files changed, 18 insertions, 0 deletions
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index b9d1412..628240a 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -44,6 +44,14 @@ THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size) } THC_API void +THCTensor_(zerosLike)(THCState *state, THCTensor *r_, THCTensor *input) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, r_, input)); + THCTensor_(resizeAs)(state, r_, input); + THCTensor_(zero)(state, r_); +} + +THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_)); @@ -52,6 +60,14 @@ THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size) } THC_API void +THCTensor_(onesLike)(THCState *state, THCTensor *r_, THCTensor *input) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, r_, input)); + THCTensor_(resizeAs)(state, r_, input); + THCTensor_(fill)(state, r_, ScalarConvert<int, real>::to(1)); +} + +THC_API void THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, THLongStorage *size) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, r_, t)); diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index 7b83d02..26a7a49 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -6,7 +6,9 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, real value); THC_API void THCTensor_(zero)(THCState *state, THCTensor *self); THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size); +THC_API void THCTensor_(zerosLike)(THCState *state, THCTensor *r_, THCTensor* input); THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size); +THC_API void THCTensor_(onesLike)(THCState *state, THCTensor *r_, THCTensor* input); THC_API void THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, THLongStorage *size); THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t); THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension); |