Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua20
1 files changed, 15 insertions, 5 deletions
diff --git a/init.lua b/init.lua
index bd8a0cb..df47f4e 100644
--- a/init.lua
+++ b/init.lua
@@ -1,8 +1,20 @@
require "torch"
paths.require "libcutorch"
-torch.CudaStorage.__tostring__ = torch.FloatStorage.__tostring__
-torch.CudaTensor.__tostring__ = torch.FloatTensor.__tostring__
+torch.CudaByteStorage.__tostring__ = torch.ByteStorage.__tostring__
+torch.CudaByteTensor.__tostring__ = torch.ByteTensor.__tostring__
+torch.CudaCharStorage.__tostring__ = torch.CharStorage.__tostring__
+torch.CudaCharTensor.__tostring__ = torch.CharTensor.__tostring__
+torch.CudaShortStorage.__tostring__ = torch.ShortStorage.__tostring__
+torch.CudaShortTensor.__tostring__ = torch.ShortTensor.__tostring__
+torch.CudaIntStorage.__tostring__ = torch.IntStorage.__tostring__
+torch.CudaIntTensor.__tostring__ = torch.IntTensor.__tostring__
+torch.CudaLongStorage.__tostring__ = torch.LongStorage.__tostring__
+torch.CudaLongTensor.__tostring__ = torch.LongTensor.__tostring__
+torch.CudaStorage.__tostring__ = torch.FloatStorage.__tostring__
+torch.CudaTensor.__tostring__ = torch.FloatTensor.__tostring__
+torch.CudaDoubleStorage.__tostring__ = torch.DoubleStorage.__tostring__
+torch.CudaDoubleTensor.__tostring__ = torch.DoubleTensor.__tostring__
include('Tensor.lua')
include('FFI.lua')
@@ -16,7 +28,7 @@ function cutorch.withDevice(newDeviceID, closure)
local vals = {pcall(closure)}
cutorch.setDevice(curDeviceID)
if vals[1] then
- return unpack(vals, 2)
+ return unpack(vals, 2)
end
error(unpack(vals, 2))
end
@@ -37,6 +49,4 @@ function cutorch.createCudaHostTensor(...)
return torch.FloatTensor(storage, 1, size:storage())
end
-cutorch.setHeapTracking(true)
-
return cutorch