Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THC/THCGeneral.c')
-rw-r--r--lib/THC/THCGeneral.c25
1 files changed, 13 insertions, 12 deletions
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index 07dbf2c..5bcce19 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -60,8 +60,8 @@ struct THCState {
void (*cutorchGCFunction)(void *data);
void *cutorchGCData;
- long heapSoftmax;
- long heapDelta;
+ ptrdiff_t heapSoftmax;
+ ptrdiff_t heapDelta;
};
THCCudaResourcesPerDevice* THCState_getDeviceResourcePtr(
@@ -640,8 +640,9 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
}
}
-static long heapSize = 0; // not thread-local
-static const long heapMaxDelta = 1e6;
+static ptrdiff_t heapSize = 0; // not thread-local
+static const ptrdiff_t heapMaxDelta = (ptrdiff_t)1e6;
+static const ptrdiff_t heapMinDelta = (ptrdiff_t)-1e6;
static const double heapSoftmaxGrowthThresh = 0.8; // grow softmax if >80% max after GC
static const double heapSoftmaxGrowthFactor = 1.4; // grow softmax by 40%
@@ -691,8 +692,8 @@ cudaError_t THCudaFree(THCState *state, void *ptr)
return allocator->free(allocator->state, ptr);
}
-static long applyHeapDelta(THCState *state) {
- long newHeapSize = THAtomicAddLong(&heapSize, state->heapDelta) + state->heapDelta;
+static ptrdiff_t applyHeapDelta(THCState *state) {
+ ptrdiff_t newHeapSize = THAtomicAddPtrdiff(&heapSize, state->heapDelta) + state->heapDelta;
state->heapDelta = 0;
return newHeapSize;
}
@@ -701,27 +702,27 @@ static long applyHeapDelta(THCState *state) {
// When THC heap size goes above this softmax, the GC hook is triggered.
// If heap size is above 80% of the softmax after GC, then the softmax is
// increased.
-static void maybeTriggerGC(THCState *state, long curHeapSize) {
+static void maybeTriggerGC(THCState *state, ptrdiff_t curHeapSize) {
if (state->cutorchGCFunction != NULL && curHeapSize > state->heapSoftmax) {
(state->cutorchGCFunction)(state->cutorchGCData);
// ensure heapSize is accurate before updating heapSoftmax
- long newHeapSize = applyHeapDelta(state);
+ ptrdiff_t newHeapSize = applyHeapDelta(state);
if (newHeapSize > state->heapSoftmax * heapSoftmaxGrowthThresh) {
- state->heapSoftmax = state->heapSoftmax * heapSoftmaxGrowthFactor;
+ state->heapSoftmax = (ptrdiff_t)state->heapSoftmax * heapSoftmaxGrowthFactor;
}
}
}
-void THCHeapUpdate(THCState *state, long size) {
+void THCHeapUpdate(THCState *state, ptrdiff_t size) {
state->heapDelta += size;
// batch updates to global heapSize to minimize thread contention
- if (labs(state->heapDelta) < heapMaxDelta) {
+ if (state->heapDelta < heapMaxDelta && state->heapDelta > heapMinDelta) {
return;
}
- long newHeapSize = applyHeapDelta(state);
+ ptrdiff_t newHeapSize = applyHeapDelta(state);
if (size > 0) {
maybeTriggerGC(state, newHeapSize);
}