Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/lib/THC
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-10-03 19:48:08 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-10-07 21:50:27 +0300
commit61f8e132a92e0f935cfa4f1eb4a7575f77792702 (patch)
tree7dd39eac7bdbc70dc2f964cbca57e440e50132c0 /lib/THC
parentaa39a6cd8aa0f2078f85618727443e5456815900 (diff)
[cutorch refactor] move stdall into generic, wrap test for std
Diffstat (limited to 'lib/THC')
-rw-r--r--lib/THC/THCTensorMath.h1
-rw-r--r--lib/THC/THCTensorMath2.cu6
-rw-r--r--lib/THC/generic/THCTensorMathReduce.cu6
-rw-r--r--lib/THC/generic/THCTensorMathReduce.h2
4 files changed, 7 insertions, 8 deletions
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 8ee93d9..292b5b1 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -73,7 +73,6 @@ THC_API void THCudaTensor_cat(THCState *state, THCudaTensor *result, THCudaTenso
THC_API void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCudaTensor **inputs, int numInputs, int dimension);
THC_API void THCudaTensor_var(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag);
-THC_API float THCudaTensor_stdall(THCState *state, THCudaTensor *self);
THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);
THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index a9dc277..b28f1c3 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -125,12 +125,6 @@ void THCudaTensor_lerp(THCState *state, THCudaTensor *result, THCudaTensor *a, T
THCudaCheck(cudaGetLastError());
}
-float THCudaTensor_stdall(THCState *state, THCudaTensor *self)
-{
- THAssert(THCudaTensor_checkGPU(state, 1, self));
- return sqrt(THCudaTensor_varall(state, self));
-}
-
void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag)
{
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
index 4103d48..a9e58a9 100644
--- a/lib/THC/generic/THCTensorMathReduce.cu
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -89,6 +89,12 @@ void THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dim
THCTensor_(freeCopyTo)(state, self, self_);
}
+accreal THCTensor_(stdall)(THCState *state, THCTensor *self)
+{
+ THAssert(THCTensor_(checkGPU)(state, 1, self));
+ return THCNumerics<accreal>::sqrt((THCTensor_(varall)(state, self)));
+}
+
THC_API accreal
THCTensor_(varall)(THCState *state, THCTensor *self)
{
diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h
index 3aefb17..77bcb20 100644
--- a/lib/THC/generic/THCTensorMathReduce.h
+++ b/lib/THC/generic/THCTensorMathReduce.h
@@ -8,8 +8,8 @@ THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src
THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag);
THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension);
+THC_API accreal THCTensor_(stdall)(THCState *state, THCTensor *self);
THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value);
-
THC_API accreal THCTensor_(varall)(THCState *state, THCTensor *self);
#endif