Welcome to mirror list, hosted at ThFree Co, Russian Federation.

init.lua - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: b397abbbfac626e8fd7205b75c5c241d71ac592d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
require "torch"
paths.require("libcutorch")

-- Each CUDA storage/tensor class reuses the __tostring__ of a CPU class.
-- The unsuffixed Cuda classes and, when cutorch.hasHalf is set, the CudaHalf
-- classes both take theirs from the Float classes.
do
   -- Copies __tostring__ from torch.<cpuPrefix>Storage/Tensor onto
   -- torch.<cudaPrefix>Storage/Tensor (storage first, then tensor).
   local function borrowToString(cudaPrefix, cpuPrefix)
      torch[cudaPrefix .. 'Storage'].__tostring__ =
         torch[cpuPrefix .. 'Storage'].__tostring__
      torch[cudaPrefix .. 'Tensor'].__tostring__ =
         torch[cpuPrefix .. 'Tensor'].__tostring__
   end

   borrowToString('CudaByte', 'Byte')
   borrowToString('CudaChar', 'Char')
   borrowToString('CudaShort', 'Short')
   borrowToString('CudaInt', 'Int')
   borrowToString('CudaLong', 'Long')
   borrowToString('Cuda', 'Float')
   borrowToString('CudaDouble', 'Double')
   if cutorch.hasHalf then
      borrowToString('CudaHalf', 'Float')
   end
end

require('cutorch.Tensor')
require('cutorch.FFI')
require('cutorch.test')

-- Lua 5.1 / LuaJIT provide unpack as a global; Lua 5.2+ provide table.unpack.
-- Bind whichever one exists to a local used by the functions below.
local unpack = unpack or table.unpack

-- Runs `closure` with `newDeviceID` as the current device, then switches back
-- to the device that was current on entry, whether or not `closure` raised.
--
-- Parameters:
--   newDeviceID: device id handed to cutorch.setDevice before the call.
--   closure:     function invoked with no arguments.
-- Returns: every value returned by `closure`, nils included.
-- Raises: the error value raised by `closure`, re-raised unchanged once the
-- previous device has been restored.
function cutorch.withDevice(newDeviceID, closure)
    local curDeviceID = cutorch.getDevice()
    cutorch.setDevice(newDeviceID)

    -- Record the result count with select('#', ...): the length of a table
    -- that contains nils is unspecified, so relying on `#` could drop any
    -- results that follow a nil.
    local function capture(...)
        return select('#', ...), {...}
    end
    local count, vals = capture(pcall(closure))

    cutorch.setDevice(curDeviceID)
    if vals[1] then
        return unpack(vals, 2, count)
    end
    -- Level 0 stops error() from prefixing this file's position onto the
    -- value that pcall captured, so callers see the original error as raised.
    error(vals[2], 0)
end

-- Builds a FloatTensor whose FloatStorage is allocated through
-- cutorch.CudaHostAllocator.
-- The size may be given as a storage (checked with torch.isStorage) or as a
-- list of numbers. When the first argument is nil or false, the size is the
-- single dimension 0.
function cutorch.createCudaHostTensor(...)
   local first = ...
   local dims
   if not first then
      dims = torch.LongTensor{0}
   elseif torch.isStorage(...) then
      dims = torch.LongTensor(...)
   else
      dims = torch.LongTensor{...}
   end

   -- The storage holds one element per entry of the requested shape.
   local numElements = dims:prod()
   local hostStorage = torch.FloatStorage(cutorch.CudaHostAllocator, numElements)
   return torch.FloatTensor(hostStorage, 1, dims:storage())
end

-- Heap tracking is switched on as soon as this module is loaded.
-- Remove this call to disable automatic cutorch heap-tracking
-- for garbage collection.
cutorch.setHeapTracking(true)

-- Module result: the cutorch table.
return cutorch