Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SpatialFullConvolution.lua - github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 9093ce41288c356597e78ce21c0353e8c2f44486 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
-- cuDNN-backed variant of nn.SpatialFullConvolution: the class is registered
-- as 'cudnn.SpatialFullConvolution' with 'nn.SpatialFullConvolution' as parent.
local SpatialFullConvolution, parent =
    torch.class('cudnn.SpatialFullConvolution', 'nn.SpatialFullConvolution')
local ffi = require 'ffi'
local find = require 'cudnn.find'
-- Short aliases for the two helpers used below to invoke cuDNN routines.
local errcheck = cudnn.errcheck
local checkedCall = find.checkedCall

-- Several methods below delegate to cudnn.SpatialConvolution, passing this
-- module explicitly as the receiver.
local Convolution = cudnn.SpatialConvolution

-- Constructor.
--   nInputPlane, nOutputPlane : plane counts; both must be divisible by `groups`
--   kW, kH, dW, dH, padW, padH, adjW, adjH : forwarded untouched to the parent
--   groups : optional group count, defaults to 1
-- The weight and gradWeight tensors allocated here have size
-- nInputPlane x (nOutputPlane / groups) x kH x kW.
function SpatialFullConvolution:__init(nInputPlane, nOutputPlane,
                            kW, kH, dW, dH, padW, padH, adjW, adjH, groups)
    -- Shadow reset with a no-op while the parent constructor runs, then put
    -- the real one back (presumably the parent constructor calls self:reset()
    -- before the grouped weight tensors below exist -- TODO confirm).
    local realReset = self.reset
    self.reset = function() end
    parent.__init(self, nInputPlane, nOutputPlane,
                  kW, kH, dW, dH, padW, padH, adjW, adjH)
    self.reset = realReset

    local nGroups = groups or 1
    self.groups = nGroups
    assert(nInputPlane % nGroups == 0,
           'nInputPlane should be divisible by nGroups')
    assert(nOutputPlane % nGroups == 0,
           'nOutputPlane should be divisible by nGroups')

    -- Replace the parent's parameter tensors with group-aware ones and
    -- initialise them now that they have their final size.
    local outPlanesPerGroup = nOutputPlane / nGroups
    self.weight = torch.Tensor(nInputPlane, outPlanesPerGroup, kH, kW)
    self.gradWeight = torch.Tensor(nInputPlane, outPlanesPerGroup, kH, kW)
    self:reset()

    -- Drop the instance-level field so it is not serialized; self:reset()
    -- keeps working afterwards.
    self.reset = nil
end
                                                            
-- Rebuilds the weight descriptors by delegating to
-- cudnn.SpatialConvolution.resetWeightDescriptors with this layer's filter
-- size {nInputPlane, nOutputPlane / groups, kH, kW}; returns its result.
function SpatialFullConvolution:resetWeightDescriptors()
   local filterSize = {
      self.nInputPlane,
      self.nOutputPlane/self.groups,
      self.kH,
      self.kW,
   }
   return Convolution.resetWeightDescriptors(self, filterSize)
end

-- Forwards `mode` to cudnn.SpatialConvolution.fastest with this module as the
-- receiver and returns whatever that call returns.
function SpatialFullConvolution:fastest(mode)
   return Convolution.fastest(self, mode)
end

-- Forwards the three mode arguments to cudnn.SpatialConvolution.setMode with
-- this module as the receiver and returns whatever that call returns.
-- NOTE(review): fmode/bdmode/bwmode presumably select the forward,
-- backward-data and backward-filter algorithms -- confirm against
-- cudnn.SpatialConvolution.
function SpatialFullConvolution:setMode(fmode, bdmode, bwmode)
   return Convolution.setMode(self, fmode, bdmode, bwmode)
end

-- Delegates to cudnn.SpatialConvolution.resetMode with this module as the
-- receiver and returns whatever that call returns.
function SpatialFullConvolution:resetMode()
   return Convolution.resetMode(self)
end

-- Delegates to cudnn.SpatialConvolution.noBias with this module as the
-- receiver and returns whatever that call returns. The forward and
-- accGradParameters passes in this file skip their bias steps when self.bias
-- is falsy.
function SpatialFullConvolution:noBias()
   return Convolution.noBias(self)
end

-- (Re)builds the cuDNN descriptors for `input` and resizes self.output.
--   input : contiguous 3D or 4D tensor; a 3D input is viewed as a batch of one
-- Nothing is rebuilt unless Convolution.checkInputChanged(self, input)
-- returns a truthy value. Raises an error when the (possibly viewed) input is
-- not 4D or is not contiguous.
function SpatialFullConvolution:createIODescriptors(input)
    local batch = true
    if input:dim() == 3 then
        input = input:view(1, input:size(1), input:size(2), input:size(3))
        batch = false
    end
    -- Separate checks with explicit messages, in the same wording as the
    -- gradOutput checks elsewhere in this file.
    assert(input:dim() == 4, 'input has to be 3D or 4D')
    assert(input:isContiguous(), 'input has to be contiguous')
    -- presumably the cached size that checkInputChanged compares against
    self.iSize = self.iSize or torch.LongStorage(4):fill(0)

    if Convolution.checkInputChanged(self, input) then
        -- create input descriptor (over the first nInputPlane planes of dim 2)
        local input_slice = input[{{},{1,self.nInputPlane},{},{}}]
        self.iDesc = cudnn.toDescriptor(input_slice)

        -- create conv descriptor
        self.pad = {self.padH, self.padW}
        self.stride = {self.dH, self.dW}

        self.convDesc = cudnn.setConvolutionDescriptor({
                              padA = self.pad,
                              filterStrideA = self.stride,
                              dataType = cudnn.configmap(torch.type(self.weight)),
                              groupCount = self.groups
        })

        -- get output shape, resize output:
        --   out = (in - 1) * stride - 2 * pad + kernel + adj   (per spatial dim)
        local iwidth = input:size(4)
        local iheight = input:size(3)
        local owidth = (iwidth - 1) * self.dW - 2*self.padW + self.kW + self.adjW
        local oheight = (iheight - 1) * self.dH - 2*self.padH + self.kH + self.adjH
        local oSize = torch.IntTensor({input:size(1), self.nOutputPlane, oheight, owidth})
        self.output:resize(oSize:long():storage())

        -- create descriptor for output
        local output_slice = self.output[{{},{1,self.nOutputPlane},{},{}}]
        self.oDesc = cudnn.toDescriptor(output_slice)
        -- the bias add / bias gradient use a descriptor of the whole output
        self.oDescForBias = cudnn.toDescriptor(self.output)

        self.input_offset = 0
        self.output_offset = 0
        self.weight_offset = 0

        -- NOTE(review): find:prepare presumably records the slices on `self`
        -- (self.input_slice / self.output_slice are read by the passes below
        -- but never assigned in this file) -- confirm in cudnn.find.
        find:prepare(self, input_slice, output_slice)

        -- a non-batch input gets a 3D output back
        if not batch then
            self.output = self.output:view(self.output:size(2),
                                           self.output:size(3),
                                           self.output:size(4))
        end
    end
end

-- Forward pass: fills and returns self.output for `input`.
-- The layer output is produced with cudnnConvolutionBackwardData, and the
-- bias (when self.bias is set) is added afterwards with cudnnAddTensor.
function SpatialFullConvolution:updateOutput(input)
    self:backCompatibility()
    if not self.weightDesc then
        self:resetWeightDescriptors()
    end
    self:createIODescriptors(input)

    -- Pick the backward-data algorithm for the current descriptors.
    local finder = find.get()
    local algoParams = {
        self.weightDesc[0], self.weight,
        self.iDesc[0], self.input_slice,
        self.convDesc[0],
        self.oDesc[0], self.output_slice,
    }
    local algo = finder:backwardDataAlgorithm(self, algoParams)
    local workspace, workspaceSize = cudnn.getSharedWorkspace()

    -- This module is the adjoint of the forward convolution operator, so the
    -- roles of the forward and backward cuDNN passes are exchanged: the
    -- module input is passed where cuDNN takes the output gradient.
    checkedCall(self, 'cudnnConvolutionBackwardData', cudnn.getHandle(),
                cudnn.scalar(input, 1),
                self.weightDesc[0], self.weight:data(),
                self.iDesc[0], input:data(),
                self.convDesc[0], algo,
                workspace, workspaceSize,
                cudnn.scalar(input, 0),
                self.oDesc[0], self.output:data())

    -- Accumulate the bias into the output (both scaling factors are 1).
    if self.bias then
        errcheck('cudnnAddTensor', cudnn.getHandle(),
                 cudnn.scalar(input, 1), self.biasDesc[0], self.bias:data(),
                 cudnn.scalar(input, 1), self.oDescForBias[0], self.output:data())
    end

    return self.output
end

-- Backward pass w.r.t. the input: fills and returns self.gradInput.
--   input      : the tensor given to the forward pass
--   gradOutput : contiguous 3D or 4D gradient w.r.t. the module output
-- Returns nothing when self.gradInput is falsy. The gradient is computed with
-- cudnnConvolutionForward, mirroring the forward pass, which uses
-- cudnnConvolutionBackwardData.
function SpatialFullConvolution:updateGradInput(input, gradOutput)
    self:backCompatibility()
    if not self.gradInput then
        return
    end
    self.gradInput:resizeAs(input)

    assert(gradOutput:dim() == 3 or gradOutput:dim() == 4, 'gradOutput has to be 3D or 4D')
    assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous')
    self:createIODescriptors(input)

    -- Pick the forward algorithm; the output descriptor plays the part of
    -- the cuDNN "input" and the input descriptor that of the cuDNN "output".
    local finder = find.get()
    local algoParams = {
        self.oDesc[0], self.output_slice,
        self.weightDesc[0], self.weight,
        self.convDesc[0],
        self.iDesc[0], self.input_slice,
    }
    local algo = finder:forwardAlgorithm(self, algoParams)
    local workspace, workspaceSize = cudnn.getSharedWorkspace()

    checkedCall(self, 'cudnnConvolutionForward', cudnn.getHandle(),
                cudnn.scalar(input, 1),
                self.oDesc[0], gradOutput:data(),
                self.weightDesc[0], self.weight:data(),
                self.convDesc[0],
                algo,
                workspace, workspaceSize,
                cudnn.scalar(input, 0),
                self.iDesc[0], self.gradInput:data())

    return self.gradInput
end

-- Accumulates parameter gradients into self.gradWeight and, when self.bias is
-- set, into self.gradBias.
--   input      : the tensor given to the forward pass
--   gradOutput : contiguous 3D or 4D gradient w.r.t. the module output
--   scale      : optional factor applied to the new gradients, defaults to 1.0
function SpatialFullConvolution:accGradParameters(input, gradOutput, scale)
    self:backCompatibility()

    -- One-element holder for the scale factor. The conversion below forces
    -- this member to always be on CPU (needed for cudnn): double for
    -- CudaDoubleTensor weights, float for every other weight type.
    self.scaleT = self.scaleT or self.weight.new(1)
    if torch.type(self.weight) == 'torch.CudaDoubleTensor' then
        self.scaleT = self.scaleT:double()
    else
        self.scaleT = self.scaleT:float()
    end
    self.scaleT[1] = scale or 1.0

    assert(gradOutput:dim() == 3 or gradOutput:dim() == 4,
           'gradOutput has to be 3D or 4D')
    assert(gradOutput:isContiguous(), 'gradOutput has to be contiguous')
    self:createIODescriptors(input)

    -- Pick the backward-filter algorithm for the current descriptors.
    local finder = find.get()
    local algoParams = {
        self.oDesc[0], self.output_slice,
        self.iDesc[0], self.input_slice,
        self.convDesc[0],
        self.weightDesc[0], self.weight,
    }
    local algo = finder:backwardFilterAlgorithm(self, algoParams)

    -- gradBias: scaled contribution added onto the existing gradBias
    if self.bias then
        errcheck('cudnnConvolutionBackwardBias', cudnn.getHandle(),
                 self.scaleT:data(),
                 self.oDescForBias[0], gradOutput:data(),
                 cudnn.scalar(input, 1),
                 self.biasDesc[0], self.gradBias:data())
    end

    local workspace, workspaceSize = cudnn.getSharedWorkspace()

    -- gradWeight: gradOutput is passed under the output descriptor and the
    -- module input under the input descriptor; the result is added onto the
    -- existing gradWeight
    checkedCall(self, 'cudnnConvolutionBackwardFilter', cudnn.getHandle(),
                self.scaleT:data(),
                self.oDesc[0], gradOutput:data(),
                self.iDesc[0], input:data(),
                self.convDesc[0],
                algo,
                workspace, workspaceSize,
                cudnn.scalar(input, 1),
                self.weightDesc[0], self.gradWeight:data())
end

-- Delegates to cudnn.SpatialConvolution.clearDesc with this module as the
-- receiver and returns whatever that call returns. Called by write() before
-- the module's fields are serialized.
function SpatialFullConvolution:clearDesc()
   return Convolution.clearDesc(self)
end

-- Serialization hook: clears the descriptors, then writes a shallow copy of
-- this module's own fields to `f` as a single object.
--   f : file-like object providing writeObject
function SpatialFullConvolution:write(f)
    self:clearDesc()

    -- Copy the fields into a plain table, which is what gets written.
    local fields = {}
    for key, value in pairs(self) do
        fields[key] = value
    end

    f:writeObject(fields)
end

-- Delegates to cudnn.SpatialConvolution.clearState with this module as the
-- receiver and returns whatever that call returns.
function SpatialFullConvolution:clearState()
   return Convolution.clearState(self)
end

-- Deserialization hook: lets the parent class read the module from `file`,
-- then defaults adjW and adjH to 0 when they are missing afterwards.
--   file    : handed to parent.read
--   version : accepted but not used here
function SpatialFullConvolution:read(file, version)
   parent.read(self, file)
   for _, field in ipairs({'adjW', 'adjH'}) do
      self[field] = self[field] or 0
   end
end