diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-10-05 18:47:17 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:28 +0300 |
commit | af459755c0d2477342aead1a645cb4969a7dd215 (patch) | |
tree | 32e7e56a040a10bac9240dd5d95ad591d020ddc6 /lib/THC | |
parent | 12076f677505257ac945ea0b092cb54e42ccecaa (diff) |
[cutorch refactor] make var(...) generic
Diffstat (limited to 'lib/THC')
-rw-r--r-- | lib/THC/THCTensorMath.h | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 21 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 22 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.h | 1 |
4 files changed, 23 insertions, 22 deletions
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 292b5b1..21482b7 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -72,7 +72,6 @@ THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *r THC_API void THCudaTensor_cat(THCState *state, THCudaTensor *result, THCudaTensor *ta, THCudaTensor *tb, int dimension); THC_API void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCudaTensor **inputs, int numInputs, int dimension); -THC_API void THCudaTensor_var(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag); THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value); THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size); diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index b28f1c3..84a5a1c 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -125,27 +125,6 @@ void THCudaTensor_lerp(THCState *state, THCudaTensor *result, THCudaTensor *a, T THCudaCheck(cudaGetLastError()); } -void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src)); - THLongStorage *dim = THCudaTensor_newSizeOf(state, src); - THLongStorage_set(dim, dimension, 1); - THCudaTensor_resize(state, self_, dim, NULL); - THLongStorage_free(dim); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - - if (dimension == THCudaTensor_nDimension(state, src) - 1) { - THCTensor_varInnermostDim<THCudaTensor, float, false>(state, self, src, flag); - } else { - THCTensor_varOuterDim<THCudaTensor, float, false>(state, self, src, dimension, flag); - } - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - struct dist_functor { const float exponent; diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index 5e1ea00..502fa75 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -91,6 +91,28 @@ THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, long dimensio THCTensor_(freeCopyTo)(state, self, self_); } +THC_API void +THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self_, src)); + THLongStorage *dim = THCTensor_(newSizeOf)(state, src); + THLongStorage_set(dim, dimension, 1); + THCTensor_(resize)(state, self_, dim, NULL); + THLongStorage_free(dim); + + THCTensor *self = THCTensor_(newContiguous)(state, self_); + src = THCTensor_(newContiguous)(state, src); + + if (dimension == THCTensor_(nDimension)(state, src) - 1) { + THCTensor_varInnermostDim<THCTensor, real, false>(state, self, src, flag); + } else { + THCTensor_varOuterDim<THCTensor, real, false>(state, self, src, dimension, flag); + } + + THCTensor_(free)(state, src); + THCTensor_(freeCopyTo)(state, self, self_); +} + THC_API accreal THCTensor_(stdall)(THCState *state, THCTensor *self) { diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h index 77bcb20..09a26fc 100644 --- a/lib/THC/generic/THCTensorMathReduce.h +++ b/lib/THC/generic/THCTensorMathReduce.h @@ -7,6 +7,7 @@ THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real max_norm); THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag); THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension); +THC_API void THCTensor_(var)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag); THC_API accreal THCTensor_(stdall)(THCState *state, THCTensor *self); THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value); |