diff options
author | Rui Guo <guorui.xt@gmail.com> | 2016-10-08 16:39:16 +0300 |
---|---|---|
committer | Rui Guo <guorui.xt@gmail.com> | 2016-10-08 16:39:16 +0300 |
commit | c1c5e58b474e2580f13376dc49cbeec841a5c1f8 (patch) | |
tree | e908c7a4b0ed3d45fce1ecf2c7016700473c9861 /lib/THC/generic/THCTensor.c | |
parent | 820becf4fb36954592821a75d7c6e8885be2f724 (diff) |
replace long with ptrdiff_t for memory size/offset, element count
Diffstat (limited to 'lib/THC/generic/THCTensor.c')
-rw-r--r-- | lib/THC/generic/THCTensor.c | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/lib/THC/generic/THCTensor.c b/lib/THC/generic/THCTensor.c index e18044d..f6c82b5 100644 --- a/lib/THC/generic/THCTensor.c +++ b/lib/THC/generic/THCTensor.c @@ -8,7 +8,7 @@ THCStorage *THCTensor_(storage)(THCState *state, const THCTensor *self) return self->storage; } -long THCTensor_(storageOffset)(THCState *state, const THCTensor *self) +ptrdiff_t THCTensor_(storageOffset)(THCState *state, const THCTensor *self) { return self->storageOffset; } @@ -65,7 +65,7 @@ void THCTensor_(clearFlag)(THCState *state, THCTensor *self, const char flag) /**** creation methods ****/ static void THCTensor_(rawInit)(THCState *state, THCTensor *self); -static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *storage, long storageOffset, int nDimension, long *size, long *stride); +static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride); /* Empty init */ @@ -92,7 +92,7 @@ THCTensor *THCTensor_(newWithTensor)(THCState *state, THCTensor *tensor) } /* Storage init */ -THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, long storageOffset, THLongStorage *size, THLongStorage *stride) +THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, THLongStorage *size, THLongStorage *stride) { THCTensor *self = (THCTensor*)THAlloc(sizeof(THCTensor)); if(size && stride) @@ -109,20 +109,20 @@ THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, long return self; } -THCTensor *THCTensor_(newWithStorage1d)(THCState *state, THCStorage *storage, long storageOffset, +THCTensor *THCTensor_(newWithStorage1d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, long size0, long stride0) { return THCTensor_(newWithStorage4d)(state, storage, storageOffset, size0, stride0, -1, -1, -1, -1, -1, -1); } -THCTensor *THCTensor_(newWithStorage2d)(THCState *state, THCStorage *storage, long storageOffset, +THCTensor *THCTensor_(newWithStorage2d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, long size0, long stride0, long size1, long stride1) { return THCTensor_(newWithStorage4d)(state, storage, storageOffset, size0, stride0, size1, stride1, -1, -1, -1, -1); } -THCTensor *THCTensor_(newWithStorage3d)(THCState *state, THCStorage *storage, long storageOffset, +THCTensor *THCTensor_(newWithStorage3d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, long size0, long stride0, long size1, long stride1, long size2, long stride2) @@ -130,7 +130,7 @@ THCTensor *THCTensor_(newWithStorage3d)(THCState *state, THCStorage *storage, lo return THCTensor_(newWithStorage4d)(state, storage, storageOffset, size0, stride0, size1, stride1, size2, stride2, -1, -1); } -THCTensor *THCTensor_(newWithStorage4d)(THCState *state, THCStorage *storage, long storageOffset, +THCTensor *THCTensor_(newWithStorage4d)(THCState *state, THCStorage *storage, ptrdiff_t storageOffset, long size0, long stride0, long size1, long stride1, long size2, long stride2, @@ -296,7 +296,7 @@ void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src) src->stride); } -void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_) +void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_) { if(size_ && stride_) THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes"); @@ -310,7 +310,7 @@ void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storag (stride_ ? stride_->data : NULL)); } -void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *storage_, long storageOffset_, +void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, long size0_, long stride0_) { THCTensor_(setStorage4d)(state, self, storage_, storageOffset_, @@ -320,7 +320,7 @@ void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *stor -1, -1); } -void THCTensor_(setStorage2d)(THCState *state, THCTensor *self, THCStorage *storage_, long storageOffset_, +void THCTensor_(setStorage2d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, long size0_, long stride0_, long size1_, long stride1_) { @@ -331,7 +331,7 @@ void THCTensor_(setStorage2d)(THCState *state, THCTensor *self, THCStorage *stor -1, -1); } -void THCTensor_(setStorage3d)(THCState *state, THCTensor *self, THCStorage *storage_, long storageOffset_, +void THCTensor_(setStorage3d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, long size0_, long stride0_, long size1_, long stride1_, long size2_, long stride2_) @@ -343,7 +343,7 @@ void THCTensor_(setStorage3d)(THCState *state, THCTensor *self, THCStorage *stor -1, -1); } -void THCTensor_(setStorage4d)(THCState *state, THCTensor *self, THCStorage *storage_, long storageOffset_, +void THCTensor_(setStorage4d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, long size0_, long stride0_, long size1_, long stride1_, long size2_, long stride2_, @@ -578,13 +578,13 @@ int THCTensor_(isSameSizeAs)(THCState *state, const THCTensor *self, const THCTe return 1; } -long THCTensor_(nElement)(THCState *state, const THCTensor *self) +ptrdiff_t THCTensor_(nElement)(THCState *state, const THCTensor *self) { if(self->nDimension == 0) return 0; else { - long nElement = 1; + ptrdiff_t nElement = 1; int d; for(d = 0; d < self->nDimension; d++) nElement *= self->size[d]; @@ -637,7 +637,7 @@ static void THCTensor_(rawInit)(THCState *state, THCTensor *self) self->flag = TH_TENSOR_REFCOUNTED; } -static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *storage, long storageOffset, int nDimension, long *size, long *stride) +static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride) { /* storage */ if(self->storage != storage) @@ -667,7 +667,7 @@ void THCTensor_(rawResize)(THCState *state, THCTensor *self, int nDimension, lon { int d; int nDimension_; - long totalSize; + ptrdiff_t totalSize; int hascorrectsize = 1; nDimension_ = 0; |