From 9d8f9087695e435e0829f466a69343d1995c5c3b Mon Sep 17 00:00:00 2001 From: Alykhan Tejani Date: Sat, 19 Aug 2017 12:34:51 +0100 Subject: add ones_like and zeros_like --- lib/TH/generic/THTensorMath.c | 12 ++++++++++++ lib/TH/generic/THTensorMath.h | 2 ++ 2 files changed, 14 insertions(+) diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 9d2a7b4..43cbf83 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -1927,6 +1927,18 @@ void THTensor_(zeros)(THTensor *r_, THLongStorage *size) THTensor_(zero)(r_); } +void THTensor_(zerosLike)(THTensor *r_, THTensor *input) +{ + THTensor_(resizeAs)(r_, input); + THTensor_(zero)(r_); +} + +void THTensor_(onesLike)(THTensor *r_, THTensor *input) +{ + THTensor_(resizeAs)(r_, input); + THTensor_(fill)(r_, 1); +} + void THTensor_(ones)(THTensor *r_, THLongStorage *size) { THTensor_(resize)(r_, size, NULL); diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h index d0963b1..5f38701 100644 --- a/lib/TH/generic/THTensorMath.h +++ b/lib/TH/generic/THTensorMath.h @@ -90,7 +90,9 @@ TH_API void THTensor_(cmaxValue)(THTensor *r, THTensor *t, real value); TH_API void THTensor_(cminValue)(THTensor *r, THTensor *t, real value); TH_API void THTensor_(zeros)(THTensor *r_, THLongStorage *size); +TH_API void THTensor_(zerosLike)(THTensor *r_, THTensor *input); TH_API void THTensor_(ones)(THTensor *r_, THLongStorage *size); +TH_API void THTensor_(onesLike)(THTensor *r_, THTensor *input); TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k); TH_API void THTensor_(eye)(THTensor *r_, long n, long m); TH_API void THTensor_(arange)(THTensor *r_, accreal xmin, accreal xmax, accreal step); -- cgit v1.2.3