Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlykhan Tejani <alykhan.tejani@gmail.com>2017-08-19 14:34:51 +0300
committerSoumith Chintala <soumith@gmail.com>2017-08-25 21:10:52 +0300
commite42a37e92fb94be6d38b5f9f587613b95167927c (patch)
treed02507328614bb1c1066c913a1862f8c2c3bad50
parentd0bb7e12cbfbae560b02b4226d7eb861bd7f48af (diff)
add ones_like and zeros_like
-rw-r--r--lib/THC/generic/THCTensorMath.cu16
-rw-r--r--lib/THC/generic/THCTensorMath.h2
2 files changed, 18 insertions, 0 deletions
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu
index b9d1412..628240a 100644
--- a/lib/THC/generic/THCTensorMath.cu
+++ b/lib/THC/generic/THCTensorMath.cu
@@ -44,6 +44,14 @@ THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size)
}
THC_API void
+THCTensor_(zerosLike)(THCState *state, THCTensor *r_, THCTensor *input)
+{
+ THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, r_, input));
+ THCTensor_(resizeAs)(state, r_, input);
+ THCTensor_(zero)(state, r_);
+}
+
+THC_API void
THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_));
@@ -52,6 +60,14 @@ THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size)
}
THC_API void
+THCTensor_(onesLike)(THCState *state, THCTensor *r_, THCTensor *input)
+{
+ THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, r_, input));
+ THCTensor_(resizeAs)(state, r_, input);
+ THCTensor_(fill)(state, r_, ScalarConvert<int, real>::to(1));
+}
+
+THC_API void
THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, THLongStorage *size)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, r_, t));
diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h
index 7b83d02..26a7a49 100644
--- a/lib/THC/generic/THCTensorMath.h
+++ b/lib/THC/generic/THCTensorMath.h
@@ -6,7 +6,9 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, real value);
THC_API void THCTensor_(zero)(THCState *state, THCTensor *self);
THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size);
+THC_API void THCTensor_(zerosLike)(THCState *state, THCTensor *r_, THCTensor* input);
THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size);
+THC_API void THCTensor_(onesLike)(THCState *state, THCTensor *r_, THCTensor* input);
THC_API void THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, THLongStorage *size);
THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);