Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominik Grewe <dominikg@google.com>2015-01-12 13:59:33 +0300
committerDominik Grewe <dominikg@google.com>2015-01-14 12:54:25 +0300
commit00dd58ce5ba784bff6d713b48198af116288f010 (patch)
tree0917a344fb1aee9d90a27748b8e4fe98d7d5de3e /Storage.c
parentedc92de73d760ce36dbd04f53bc1c3727a50e3ad (diff)
Pass a state to every THC function.
Every THC function gets a THCState pointer as the first argument. Some generic files that were previously included have been instantiated because TH functions currently don't get a state parameter.
Diffstat (limited to 'Storage.c')
-rw-r--r--Storage.c34
1 files changed, 32 insertions, 2 deletions
diff --git a/Storage.c b/Storage.c
index f378ef2..11ea696 100644
--- a/Storage.c
+++ b/Storage.c
@@ -1,3 +1,4 @@
+#include "torch/utils.h"
#include "THC.h"
#include "THFile.h"
#include "luaT.h"
@@ -36,6 +37,36 @@
/* now we overwrite some methods specific to CudaStorage */
+static int cutorch_CudaStorage_copy(lua_State *L)
+{
+ THCState *state = cutorch_getstate(L);
+ THCudaStorage *storage = luaT_checkudata(L, 1, "torch.CudaStorage");
+ void *src;
+ if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) )
+ THCudaStorage_copy(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) )
+ THCudaStorage_copyByte(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CharStorage")) )
+ THCudaStorage_copyChar(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) )
+ THCudaStorage_copyShort(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) )
+ THCudaStorage_copyInt(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) )
+ THCudaStorage_copyLong(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) )
+ THCudaStorage_copyFloat(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) )
+ THCudaStorage_copyDouble(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) )
+ THCudaStorage_copyCuda(state, storage, src);
+ else
+ luaL_typerror(L, 2, "torch.*Storage");
+
+ lua_settop(L, 1);
+ return 1;
+}
+
#define CUDA_IMPLEMENT_STORAGE_COPY(TYPEC) \
static int cutorch_##TYPEC##Storage_copy(lua_State *L) \
{ \
@@ -58,7 +89,7 @@
else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) ) \
TH##TYPEC##Storage_copyDouble(storage, src); \
else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) ) \
- TH##TYPEC##Storage_copyCuda(storage, src); \
+ TH##TYPEC##Storage_copyCuda(cutorch_getstate(L), storage, src); \
else \
luaL_typerror(L, 2, "torch.*Storage"); \
\
@@ -73,7 +104,6 @@ CUDA_IMPLEMENT_STORAGE_COPY(Int)
CUDA_IMPLEMENT_STORAGE_COPY(Long)
CUDA_IMPLEMENT_STORAGE_COPY(Float)
CUDA_IMPLEMENT_STORAGE_COPY(Double)
-CUDA_IMPLEMENT_STORAGE_COPY(Cuda)
void cutorch_CudaStorage_init(lua_State* L)
{