diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-12-17 07:08:30 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-12-17 07:08:30 +0300 |
commit | e71e246e172b757f81b5da4a58871104bc81d3d9 (patch) | |
tree | 83c062506541f1628547ba055e6418c522e6b07b | |
parent | d1584c714469a46c70856c041734f781fcb454b9 (diff) | |
parent | d285340cba0d55be08b23a768a1c693dd6e5f4fe (diff) |
Merge pull request #120 from germank/cmul_batch
CMul batch processing
-rw-r--r-- | CMul.lua | 43 | ||||
-rw-r--r-- | test.lua | 21 |
2 files changed, 58 insertions, 6 deletions
@@ -13,25 +13,56 @@ function CMul:__init(inputSize) self:reset() end -function CMul:reset() - self.weight:fill(1) +function CMul:reset(stdv) + if stdv then + stdv = stdv * math.sqrt(3) + else + stdv = 1./math.sqrt(self.weight:size(1)) + end + self.weight:uniform(-stdv,stdv) end function CMul:updateOutput(input) - self.output:copy(input); - self.output:cmul(self.weight); + self.output:resizeAs(input):copy(input) + if input:nElement() == self.weight:nElement() then + self.output:view(-1):cmul(self.weight:view(-1)); + else + if input:isSameSizeAs(self.weight) then + self.output:cmul(self.weight) + else + local batchSize = input:size(1) + self.output:view(batchSize, -1):cmul(self.weight:view(1,-1):expandAs(input:view(batchSize, -1))) + end + end return self.output end function CMul:updateGradInput(input, gradOutput) if self.gradInput then + local nElement = self.gradInput:nElement() self.gradInput:resizeAs(input) self.gradInput:zero() - self.gradInput:addcmul(1, self.weight, gradOutput) + if self.weight:nElement() == gradOutput:nElement() then + self.gradInput:addcmul(1, self.weight, gradOutput) + else + local gradOutput = gradOutput:view(input:size(1), -1) + local gradInput = self.gradInput:view(input:size(1), -1) + gradInput:addcmul(1, self.weight:view(1,-1):expandAs(gradOutput), gradOutput) + end return self.gradInput end end function CMul:accGradParameters(input, gradOutput, scale) - self.gradWeight:addcmul(scale or 1, input, gradOutput) + if self.weight:nElement() == gradOutput:nElement() then + self.gradWeight:addcmul(scale or 1, input, gradOutput) + else + local batchSize = input:size(1) + local input = input:view(batchSize, -1) + local gradOutput = gradOutput:view(batchSize, -1) + local gradWeight = self.gradWeight:view(1, -1) + for i=1,batchSize do + gradWeight:addcmul(scale or 1, input[i], gradOutput[i]) + end + end end @@ -76,6 +76,7 @@ function nntest.CMul() local input = torch.Tensor(ini,inj,ink):zero() local module = nn.CMul(ini*inj*ink) + -- 1D local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') @@ -90,6 +91,26 @@ function nntest.CMul() 'error on weight [%s]', t)) end + -- 2D + local nframe = math.random(50,70) + local nframe = 5 + local input = torch.Tensor(nframe, ini,inj,ink):zero() + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianUpdateParameters(module, input, module.weight) + mytester:assertlt(err,precision, 'error on weight [direct update] ') + + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do + mytester:assertlt(err, precision, string.format('error on weight [%s]', t)) + end + + + -- IO local ferr,berr = jac.testIO(module,input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') |