Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-11-09 22:09:30 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-11-12 00:23:02 +0300
commit86e3d253f9262358bd11fab4b1d27ebe72b23003 (patch)
treea3847df7d56f0317dc5745fad13bd95832996888
parentb33d92cb7d28b0764f8278dde2e160097f858457 (diff)
[cutorch rand2gen] move randn to generic
-rw-r--r--TensorMath.lua5
-rw-r--r--lib/THC/THCTensorMath.h1
-rw-r--r--lib/THC/THCTensorMath2.cu6
-rw-r--r--lib/THC/generic/THCTensorRandom.cu7
-rw-r--r--lib/THC/generic/THCTensorRandom.h1
5 files changed, 13 insertions, 7 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index c6dada5..f94154d 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -939,6 +939,11 @@ for k, Tensor_ in pairs(handledTypenames) do
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name="LongArg"}})
+ wrap("randn",
+ cname("randn"),
+ {{name=Tensor, default=true, returned=true, method={default='nil'}},
+ {name="LongArg"}})
+
wrap("multinomial",
cname("multinomial"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 84b6f1b..7cbef32 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -53,7 +53,6 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor
THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b);
THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
-THC_API void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size);
THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index fb94cb6..7e6af9b 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -28,9 +28,3 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
THCudaCheck(cudaGetLastError());
}
-void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size)
-{
- THAssert(THCudaTensor_checkGPU(state, 1, r_));
- THCudaTensor_resize(state, r_, size, NULL);
- THCudaTensor_normal(state, r_, 0, 1);
-}
diff --git a/lib/THC/generic/THCTensorRandom.cu b/lib/THC/generic/THCTensorRandom.cu
index c410956..9af79c1 100644
--- a/lib/THC/generic/THCTensorRandom.cu
+++ b/lib/THC/generic/THCTensorRandom.cu
@@ -279,6 +279,13 @@ THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *siz
THCTensor_(uniform)(state, r_, 0, 1);
}
+void THCTensor_(randn)(THCState *state, THCTensor *r_, THLongStorage *size)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, r_));
+ THCTensor_(resize)(state, r_, size, NULL);
+ THCTensor_(normal)(state, r_, 0, 1);
+}
+
#endif
GENERATE_KERNEL1(generate_bernoulli, real, double p, float, curand_uniform, (ScalarConvert<bool, real>::to(x <= p)))
diff --git a/lib/THC/generic/THCTensorRandom.h b/lib/THC/generic/THCTensorRandom.h
index ebb3852..a2896c3 100644
--- a/lib/THC/generic/THCTensorRandom.h
+++ b/lib/THC/generic/THCTensorRandom.h
@@ -6,6 +6,7 @@
THC_API void THCTensor_(uniform)(struct THCState *state, THCTensor *self, double a, double b);
THC_API void THCTensor_(rand)(THCState *state, THCTensor *r_, THLongStorage *size);
+THC_API void THCTensor_(randn)(THCState *state, THCTensor *r_, THLongStorage *size);
THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv);
THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);