Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlykhan Tejani <alykhan.tejani@gmail.com>2017-08-19 14:34:51 +0300
committerSoumith Chintala <soumith@gmail.com>2017-08-25 21:10:42 +0300
commit9d8f9087695e435e0829f466a69343d1995c5c3b (patch)
tree56ffefe6a4864c33a449995dbe929634306d80a0
parenta4cc7a3492621aeb0579a3b4358b41e4b613a973 (diff)
add ones_like and zeros_like
-rw-r--r--lib/TH/generic/THTensorMath.c12
-rw-r--r--lib/TH/generic/THTensorMath.h2
2 files changed, 14 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index 9d2a7b4..43cbf83 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -1927,6 +1927,18 @@ void THTensor_(zeros)(THTensor *r_, THLongStorage *size)
THTensor_(zero)(r_);
}
+void THTensor_(zerosLike)(THTensor *r_, THTensor *input)
+{
+ THTensor_(resizeAs)(r_, input);
+ THTensor_(zero)(r_);
+}
+
+void THTensor_(onesLike)(THTensor *r_, THTensor *input)
+{
+ THTensor_(resizeAs)(r_, input);
+ THTensor_(fill)(r_, 1);
+}
+
void THTensor_(ones)(THTensor *r_, THLongStorage *size)
{
THTensor_(resize)(r_, size, NULL);
diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h
index d0963b1..5f38701 100644
--- a/lib/TH/generic/THTensorMath.h
+++ b/lib/TH/generic/THTensorMath.h
@@ -90,7 +90,9 @@ TH_API void THTensor_(cmaxValue)(THTensor *r, THTensor *t, real value);
TH_API void THTensor_(cminValue)(THTensor *r, THTensor *t, real value);
TH_API void THTensor_(zeros)(THTensor *r_, THLongStorage *size);
+TH_API void THTensor_(zerosLike)(THTensor *r_, THTensor *input);
TH_API void THTensor_(ones)(THTensor *r_, THLongStorage *size);
+TH_API void THTensor_(onesLike)(THTensor *r_, THTensor *input);
TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k);
TH_API void THTensor_(eye)(THTensor *r_, long n, long m);
TH_API void THTensor_(arange)(THTensor *r_, accreal xmin, accreal xmax, accreal step);