diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-04-09 21:36:37 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-04-09 21:36:37 +0400 |
commit | 0fe868ddaf7b4642a6f6ed87bd0bc6e62f18631d (patch) | |
tree | 893010c6f2cbdda3827ccd019431ffc2d3ddbb95 /LookupTable.lua | |
parent | fef7ef9ebb5a03024984376b240e960ac3ee6d4c (diff) |
works with batches. still needs localized scales
Diffstat (limited to 'LookupTable.lua')
-rw-r--r-- | LookupTable.lua | 34 |
1 files changed, 28 insertions, 6 deletions
diff --git a/LookupTable.lua b/LookupTable.lua index 9ccbf56..f762230 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -79,16 +79,38 @@ function LookupTable:zeroGradParameters() end function LookupTable:accGradParameters(input, gradOutput, scale) - for i=1,input:size(1) do - local k = input[i] - self.inputs[k] = true - self.gradWeight:select(1, k):add(scale, gradOutput:select(1, i)) + if input:dim() == 1 then + for i=1,input:size(1) do + local k = input[i] + self.inputs[k] = true + self.gradWeight:select(1, k):add(scale, gradOutput:select(1, i)) + end + elseif input:dim() == 2 then + for i=1,input:size(1) do + local input = input:select(1, i) + local gradOutput = gradOutput:select(1, i) + for j=1,input:size(1) do + local k = input[j] + self.input[k] = true + self.gradWeight:select(1, k):add(scale, gradOutput:select(1, j)) + end + end end end function LookupTable:accUpdateGradParameters(input, gradOutput, lr) - for i=1,input:size(1) do - self.weight:select(1, input[i]):add(-lr, gradOutput:select(1, i)) + if input:dim() == 1 then + for i=1,input:size(1) do + self.weight:select(1, input[i]):add(-lr, gradOutput:select(1, i)) + end + elseif input:dim() == 2 then + for i=1,input:size(1) do + local input = input:select(1, i) + local gradOutput = gradOutput:select(1, i) + for j=1,input:size(2) do + self.weight:select(1, input[j]):add(-lr, gradOutput:select(1, j)) + end + end end end |