diff options
author | Andreas Köpf <andreas.koepf@xamla.com> | 2015-12-13 00:52:35 +0300 |
---|---|---|
committer | soumith <soumith@gmail.com> | 2015-12-30 00:38:06 +0300 |
commit | 1e3abcb83ee71478a5f772b757cfbaf5c66be3c4 (patch) | |
tree | 63993b7f944fd500b32112af5e5324cdccc81d9a /THNN.lua | |
parent | b887940ea683415d1d94614fd15ddddf124af68b (diff) |
Add functional version of AbsCriterion using metatable call
THNN state is now passed implicitely.
Diffstat (limited to 'THNN.lua')
-rw-r--r-- | THNN.lua | 24 |
1 files changed, 20 insertions, 4 deletions
@@ -12,6 +12,19 @@ TH_API void THNN_(Abs_updateGradInput)( THTensor *input, THTensor *gradOutput, THTensor *gradInput); + +TH_API void THNN_(AbsCriterion_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *target, + real *output, + bool sizeAverage); +TH_API void THNN_(AbsCriterion_updateGradInput)( + THNNState *state, + THTensor *input, + THTensor *target, + THTensor *gradInput, + bool sizeAverage); ]] -- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h @@ -82,14 +95,14 @@ local function extract_function_names(s) return t end -function THNN.bind(lib, base_names, type_name) +function THNN.bind(lib, base_names, type_name, state_getter) local ftable = {} local prefix = 'THNN_' .. type_name for i,n in ipairs(base_names) do -- use pcall since some libs might not support all functions (e.g. cunn) local ok,v = pcall(function() return lib[prefix .. n] end) if ok then - ftable[n] = v + ftable[n] = function(...) v(state_getter(), ...) end -- implicitely add state end end return ftable @@ -99,8 +112,11 @@ end local function_names = extract_function_names(generic_THNN_h) THNN.kernels = {} -THNN.kernels['torch.FloatTensor'] = THNN.bind(THNN.C, function_names, 'Float') -THNN.kernels['torch.DoubleTensor'] = THNN.bind(THNN.C, function_names, 'Double') +THNN.kernels['torch.FloatTensor'] = THNN.bind(THNN.C, function_names, 'Float', THNN.getState) +THNN.kernels['torch.DoubleTensor'] = THNN.bind(THNN.C, function_names, 'Double', THNN.getState) + +torch.getmetatable('torch.FloatTensor').THNN = THNN.kernels['torch.FloatTensor'] +torch.getmetatable('torch.DoubleTensor').THNN = THNN.kernels['torch.DoubleTensor'] function THNN.runKernel(f, type, ...) local ftable = THNN.kernels[type] |