diff options
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 20 |
1 files changed, 15 insertions, 5 deletions
@@ -1,8 +1,20 @@ require "torch" paths.require "libcutorch" -torch.CudaStorage.__tostring__ = torch.FloatStorage.__tostring__ -torch.CudaTensor.__tostring__ = torch.FloatTensor.__tostring__ +torch.CudaByteStorage.__tostring__ = torch.ByteStorage.__tostring__ +torch.CudaByteTensor.__tostring__ = torch.ByteTensor.__tostring__ +torch.CudaCharStorage.__tostring__ = torch.CharStorage.__tostring__ +torch.CudaCharTensor.__tostring__ = torch.CharTensor.__tostring__ +torch.CudaShortStorage.__tostring__ = torch.ShortStorage.__tostring__ +torch.CudaShortTensor.__tostring__ = torch.ShortTensor.__tostring__ +torch.CudaIntStorage.__tostring__ = torch.IntStorage.__tostring__ +torch.CudaIntTensor.__tostring__ = torch.IntTensor.__tostring__ +torch.CudaLongStorage.__tostring__ = torch.LongStorage.__tostring__ +torch.CudaLongTensor.__tostring__ = torch.LongTensor.__tostring__ +torch.CudaStorage.__tostring__ = torch.FloatStorage.__tostring__ +torch.CudaTensor.__tostring__ = torch.FloatTensor.__tostring__ +torch.CudaDoubleStorage.__tostring__ = torch.DoubleStorage.__tostring__ +torch.CudaDoubleTensor.__tostring__ = torch.DoubleTensor.__tostring__ include('Tensor.lua') include('FFI.lua') @@ -16,7 +28,7 @@ function cutorch.withDevice(newDeviceID, closure) local vals = {pcall(closure)} cutorch.setDevice(curDeviceID) if vals[1] then - return unpack(vals, 2) + return unpack(vals, 2) end error(unpack(vals, 2)) end @@ -37,6 +49,4 @@ function cutorch.createCudaHostTensor(...) return torch.FloatTensor(storage, 1, size:storage()) end -cutorch.setHeapTracking(true) - return cutorch |