Welcome to mirror list, hosted at ThFree Co, Russian Federation.

THCStorage.c « generic « THC « lib - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: e7c529c24fd71a52172719f515afc519a875731c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include "THCStorage.h"
#include "THCGeneral.h"
#include "THAtomic.h"

// Returns the raw device pointer held by `self`. It is NULL for an empty
// storage created by THCudaStorage_new. `state` is not used.
float* THCudaStorage_data(THCState *state, const THCudaStorage *self)
{
  float *ptr = self->data;
  return ptr;
}

// Returns the number of float elements recorded in `self`. `state` is not used.
long THCudaStorage_size(THCState *state, const THCudaStorage *self)
{
  long count = self->size;
  return count;
}

// Returns the size in bytes of one element of this storage type (a float).
int THCudaStorage_elementSize(THCState *state)
{
  const int bytesPerElement = sizeof(float);
  return bytesPerElement;
}

// Writes `value` to element `index` of the device buffer with a blocking
// host-to-device cudaMemcpy of one float.
// Raises an argument error (arg 2) when index is outside [0, size).
void THCudaStorage_set(THCState *state, THCudaStorage *self, long index, float value)
{
  THArgCheck(!(index < 0 || index >= self->size), 2, "index out of bounds");
  float *dst = &self->data[index];
  THCudaCheck(cudaMemcpy(dst, &value, sizeof(value), cudaMemcpyHostToDevice));
}

// Reads element `index` from the device buffer with a blocking
// device-to-host cudaMemcpy of one float, and returns it.
// Raises an argument error (arg 2) when index is outside [0, size).
float THCudaStorage_get(THCState *state, const THCudaStorage *self, long index)
{
  THArgCheck(!(index < 0 || index >= self->size), 2, "index out of bounds");
  float result;
  const float *src = &self->data[index];
  THCudaCheck(cudaMemcpy(&result, src, sizeof(result), cudaMemcpyDeviceToHost));
  return result;
}

// Allocates an empty storage: no device buffer (data == NULL), size 0,
// one reference. Flags are REFCOUNTED | RESIZABLE | FREEMEM.
THCudaStorage* THCudaStorage_new(THCState *state)
{
  THCudaStorage *fresh = (THCudaStorage*)THAlloc(sizeof(*fresh));
  fresh->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
  fresh->refcount = 1;
  fresh->size = 0;
  fresh->data = NULL;
  return fresh;
}

// Creates a refcounted, resizable storage holding `size` floats in a buffer
// obtained from THCudaMalloc.
// size == 0 returns an empty storage from THCudaStorage_new (data == NULL).
// size < 0 raises an argument error (arg 2).
// The heap tracker is charged size*sizeof(float) bytes before the allocation
// and refunded if the allocation fails. On failure the host-side struct is
// also released before the CUDA error is reported.
THCudaStorage* THCudaStorage_newWithSize(THCState *state, long size)
{
  THArgCheck(size >= 0, 2, "invalid size");

  if(size > 0)
  {
    THCudaStorage *storage = (THCudaStorage*)THAlloc(sizeof(THCudaStorage));
    // Byte count computed in signed arithmetic, so the refund below is a
    // plain negative long rather than a negated value multiplied by size_t.
    // NOTE(review): assumes THCHeapUpdate takes a signed long; confirm.
    long heapBytes = size * (long)sizeof(float);

    // update heap *before* attempting malloc, to free space for the malloc
    THCHeapUpdate(state, heapBytes);
    cudaError_t err =
      THCudaMalloc(state, (void**)&(storage->data), size * sizeof(float));
    if(err != cudaSuccess){
      // Undo the heap charge and release the struct before raising.
      // Otherwise the struct would leak when THCudaCheck reports the error.
      THCHeapUpdate(state, -heapBytes);
      THFree(storage);
      THCudaCheck(err);
      return NULL; // defensive: reached only if THCudaCheck returns on error
    }

    storage->size = size;
    storage->refcount = 1;
    storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
    return storage;
  }
  else
  {
    return THCudaStorage_new(state);
  }
}

// Creates a one-element storage initialised to data0. The value is written
// with THCudaStorage_set.
THCudaStorage* THCudaStorage_newWithSize1(THCState *state, float data0)
{
  THCudaStorage *result = THCudaStorage_newWithSize(state, 1);
  THCudaStorage_set(state, result, 0, data0);
  return result;
}

// Creates a two-element storage initialised to (data0, data1). Elements are
// written in index order, one THCudaStorage_set call each.
THCudaStorage* THCudaStorage_newWithSize2(THCState *state, float data0, float data1)
{
  const float init[2] = {data0, data1};
  THCudaStorage *result = THCudaStorage_newWithSize(state, 2);
  for(long i = 0; i < (long)(sizeof(init) / sizeof(init[0])); i++)
    THCudaStorage_set(state, result, i, init[i]);
  return result;
}

// Creates a three-element storage initialised to (data0, data1, data2).
// Elements are written in index order, one THCudaStorage_set call each.
THCudaStorage* THCudaStorage_newWithSize3(THCState *state, float data0, float data1, float data2)
{
  const float init[3] = {data0, data1, data2};
  THCudaStorage *result = THCudaStorage_newWithSize(state, 3);
  for(long i = 0; i < (long)(sizeof(init) / sizeof(init[0])); i++)
    THCudaStorage_set(state, result, i, init[i]);
  return result;
}

// Creates a four-element storage initialised to (data0, data1, data2, data3).
// Elements are written in index order, one THCudaStorage_set call each.
THCudaStorage* THCudaStorage_newWithSize4(THCState *state, float data0, float data1, float data2, float data3)
{
  const float init[4] = {data0, data1, data2, data3};
  THCudaStorage *result = THCudaStorage_newWithSize(state, 4);
  for(long i = 0; i < (long)(sizeof(init) / sizeof(init[0])); i++)
    THCudaStorage_set(state, result, i, init[i]);
  return result;
}

// File-mapped CUDA storages are not supported; this always calls THError.
// The NULL return is reached only if THError returns to the caller.
THCudaStorage* THCudaStorage_newWithMapping(THCState *state, const char *fileName, long size, int isShared)
{
  (void)state; (void)fileName; (void)size; (void)isShared; // intentionally unused
  THError("not available yet for THCudaStorage");
  return NULL;
}

// Wraps an existing buffer `data` of `size` floats in a new storage with one
// reference. Because FREEMEM is set, THCudaStorage_free will later release
// `data` with THCudaFree, so the storage takes ownership of the pointer.
// NOTE(review): no THCHeapUpdate charge is made here, but THCudaStorage_free
// refunds size*sizeof(float) for FREEMEM storages. Confirm callers account
// for the heap themselves.
THCudaStorage* THCudaStorage_newWithData(THCState *state, float *data, long size)
{
  THCudaStorage *wrapped = (THCudaStorage*)THAlloc(sizeof(*wrapped));
  wrapped->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
  wrapped->refcount = 1;
  wrapped->size = size;
  wrapped->data = data;
  return wrapped;
}

// Turns on every bit of `flag` in the storage's flag field; other bits are
// kept.
void THCudaStorage_setFlag(THCState *state, THCudaStorage *storage, const char flag)
{
  storage->flag = storage->flag | flag;
}

// Turns off every bit of `flag` in the storage's flag field; other bits are
// kept.
void THCudaStorage_clearFlag(THCState *state, THCudaStorage *storage, const char flag)
{
  storage->flag = storage->flag & ~flag;
}

// Adds one reference to `self` with THAtomicIncrementRef. A NULL pointer or a
// storage without TH_STORAGE_REFCOUNTED is left untouched.
void THCudaStorage_retain(THCState *state, THCudaStorage *self)
{
  if(self == NULL)
    return;
  if(!(self->flag & TH_STORAGE_REFCOUNTED))
    return;
  THAtomicIncrementRef(&self->refcount);
}

// Drops one reference to `self`. A NULL pointer or a storage without
// TH_STORAGE_REFCOUNTED is ignored.
// When THAtomicDecrementRef reports that the last reference was dropped:
//  - if TH_STORAGE_FREEMEM is set, the heap tracker is refunded and the
//    buffer is released with THCudaFree;
//  - the host-side struct is then freed;
//  - any THCudaFree error is reported only after that.
void THCudaStorage_free(THCState *state, THCudaStorage *self)
{
  // NULL is accepted for symmetry with THCudaStorage_retain.
  if(!self || !(self->flag & TH_STORAGE_REFCOUNTED))
    return;

  if (THAtomicDecrementRef(&self->refcount))
  {
    cudaError_t err = cudaSuccess;
    if(self->flag & TH_STORAGE_FREEMEM) {
      // Signed arithmetic so the refund is a plain negative long.
      THCHeapUpdate(state, -self->size * (long)sizeof(float));
      err = THCudaFree(state, self->data);
    }
    // Release the struct before checking the error, so it cannot leak if
    // THCudaCheck raises.
    THFree(self);
    THCudaCheck(err);
  }
}