Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Adam Lerer <alerer@fb.com>, 2015-08-22 01:16:57 +0300
committer: Adam Lerer <alerer@fb.com>, 2015-08-22 01:16:57 +0300
commit: ef2e3d290418c411add8eee8dd5ca7d486362ad8 (patch)
tree: 6516477c6dadf9b652a209ce9975e339e7498451
parent: 5659715a11a38582f88bebb7a2736075d876ab20 (diff)
Moved ClassNLLCriterion to C. Normalization done by total weight instead of size
-rw-r--r-- ClassNLLCriterion.lua 156
-rw-r--r-- generic/ClassNLLCriterion.c 163
-rw-r--r-- init.c 5
3 files changed, 223 insertions, 101 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index bc2a2e9..9b8d0aa 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -1,16 +1,28 @@
-local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criterion')
+-- Declare the class 'nn.ClassNLLCriterion' with 'nn.Criterion' as its parent.
+-- `parent` is kept so the constructor can call the base-class __init.
+local ClassNLLCriterion, parent = torch.class(
+ 'nn.ClassNLLCriterion',
+ 'nn.Criterion'
+)
-function ClassNLLCriterion:__init(weights)
- parent.__init(self)
- self.sizeAverage = true
- self.outputTensor = torch.Tensor(1)
- if weights then
- assert(weights:dim() == 1, "weights input should be 1-D Tensor")
- self.weights = weights
- end
+--- Constructor.
+-- @param weights      optional 1-D tensor of per-class weights; stored as
+--                     self.weights and handed to the C implementation.
+-- @param sizeAverage  optional boolean; defaults to true when nil. An explicit
+--                     false is honoured, which is why the check is `~= nil`
+--                     rather than `sizeAverage or true`.
+function ClassNLLCriterion:__init(weights, sizeAverage)
+ parent.__init(self)
+ if sizeAverage ~= nil then
+ self.sizeAverage = sizeAverage
+ else
+ self.sizeAverage = true
+ end
+ if weights then
+ assert(weights:dim() == 1, "weights input should be 1-D Tensor")
+ self.weights = weights
+ end
+
+ -- One-element buffers: the C functions write the loss and the summed target
+ -- weight into the first two, and a scalar target is stored in the third.
+ self.output_tensor = torch.zeros(1)
+ self.total_weight_tensor = torch.zeros(1)
+ self.target = torch.zeros(1):long()
end
+
+
function ClassNLLCriterion:__len()
if (self.weights) then
return #self.weights
@@ -21,101 +33,43 @@ end
+--- Forward pass: delegates the loss computation to the C/CUDA binding
+-- input.nn.ClassNLLCriterion_updateOutput.
+-- @param input   1-D (classes) or 2-D (batch x classes) tensor.
+-- @param target  a class index given as a number, or a tensor of class indices.
+-- @return self.output (the loss) and, as a second value, the total weight of
+--         the targets, which the binding writes into self.total_weight_tensor.
function ClassNLLCriterion:updateOutput(input, target)
- if input:type() == 'torch.CudaTensor' then
- if self.weights == nil then
- -- The CUDA implementation requires self.weights be non-nil
- self.weights = torch.CudaTensor()
- end
- assert(self.weights:dim() == 0 or self.weights:dim() == 1,
- 'weights must be 1D or empty')
- -- The cuda code wont check weight size, so we must do it here.
- if self.weights:dim() == 1 then
- if input:dim() == 1 then
- assert(self.weights:size(1) == input:size(1),
- 'Wrong number of weights')
- else
- assert(self.weights:size(1) == input:size(2),
- 'Wrong number of weights')
- end
- end
- if input:dim() == 1 then
- self._target = self._target or input.new(1)
- if type(target) == 'number' then
- self._target[1] = target
- else
- self._target:copy(target)
- end
- input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
- else
- input.nn.ClassNLLCriterion_updateOutput(self, input, target)
- end
- self.output = self.outputTensor[1]
- return self.output
- end
+   if type(target) == 'number' then
+      -- A scalar target needs a one-element tensor to hand to the binding.
+      -- self.target cannot be reused for that: it may alias a tensor the
+      -- caller passed on an earlier call (target:long() returns the very same
+      -- tensor when it already is a LongTensor, and CUDA targets are stored
+      -- as-is), or it may have been converted to another type by :type().
+      -- Writing into it would corrupt caller data or feed the binding a
+      -- tensor of the wrong type, so keep a private buffer instead.
+      local cuda = input:type() == 'torch.CudaTensor'
+      local wanted = cuda and 'torch.CudaTensor' or 'torch.LongTensor'
+      local buf = self._target
+      if not buf or buf:type() ~= wanted or buf:nElement() ~= 1 then
+         buf = cuda and input.new(1) or torch.LongTensor(1)
+         self._target = buf
+      end
+      buf[1] = target
+      self.target = buf
+   elseif target:type() == 'torch.CudaTensor' then
+      self.target = target
+   else
+      self.target = target:long()
+   end
- if input:dim() == 1 then
- if torch.isTensor(target) then target = target[1] end
- self.output = -input[target]
- if self.weights then
- self.output = self.output*self.weights[target]
- end
- elseif input:dim() == 2 then
- local output = 0
- for i=1,target:size(1) do
- if self.weights then
- output = output - input[i][target[i]]*self.weights[target[i]]
- else
- output = output - input[i][target[i]]
- end
- end
- if self.sizeAverage then
- output = output / target:size(1)
- end
- self.output = output
- else
- error('matrix or vector expected')
- end
- return self.output
+   input.nn.ClassNLLCriterion_updateOutput(
+      input,
+      self.target,
+      self.weights,
+      self.sizeAverage,
+      self.output_tensor,
+      self.total_weight_tensor
+   )
+   self.output = self.output_tensor[1]
+   return self.output, self.total_weight_tensor[1]
end
+--- Backward pass: zeroes self.gradInput (resized like input) and lets the
+-- C/CUDA binding input.nn.ClassNLLCriterion_updateGradInput fill in the
+-- entries selected by the targets.
+-- @param input   1-D (classes) or 2-D (batch x classes) tensor.
+-- @param target  a class index given as a number, or a tensor of class indices.
+-- @return self.gradInput
+-- The binding reads self.total_weight_tensor, which updateOutput fills in.
function ClassNLLCriterion:updateGradInput(input, target)
- self.gradInput:resizeAs(input)
- self.gradInput:zero()
+   if type(target) == 'number' then
+      -- A scalar target needs a one-element tensor to hand to the binding.
+      -- self.target cannot be reused for that: it may alias a tensor the
+      -- caller passed on an earlier call (target:long() returns the very same
+      -- tensor when it already is a LongTensor, and CUDA targets are stored
+      -- as-is), or it may have been converted to another type by :type().
+      -- Writing into it would corrupt caller data or feed the binding a
+      -- tensor of the wrong type, so keep a private buffer instead.
+      local cuda = input:type() == 'torch.CudaTensor'
+      local wanted = cuda and 'torch.CudaTensor' or 'torch.LongTensor'
+      local buf = self._target
+      if not buf or buf:type() ~= wanted or buf:nElement() ~= 1 then
+         buf = cuda and input.new(1) or torch.LongTensor(1)
+         self._target = buf
+      end
+      buf[1] = target
+      self.target = buf
+   elseif target:type() == 'torch.CudaTensor' then
+      self.target = target
+   else
+      self.target = target:long()
+   end
- if input:type() == 'torch.CudaTensor' then
- -- Note: we'll assume that updateOutput() has been called and self.weights
- -- is non-nil.
- if input:dim() == 1 then
- self._target = self._target or input.new(1)
- if type(target) == 'number' then
- self._target[1] = target
- else
- self._target:copy(target)
- end
- input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
- else
- input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
- end
- return self.gradInput
- end
-
- if input:dim() == 1 then
- if torch.isTensor(target) then target = target[1] end
- self.gradInput[target] = -1
- if self.weights then
- self.gradInput[target] = self.gradInput[target]*self.weights[target]
- end
- else
- local z = -1
- if self.sizeAverage then
- z = z / target:size(1)
- end
- for i=1,target:size(1) do
- self.gradInput[i][target[i]] = z
- if self.weights then
- self.gradInput[i][target[i]] = self.gradInput[i][target[i]]*self.weights[target[i]]
- end
- end
- end
- return self.gradInput
+   self.gradInput:resizeAs(input):zero()
+   input.nn.ClassNLLCriterion_updateGradInput(
+      input,
+      self.target,
+      self.weights,
+      self.sizeAverage,
+      self.total_weight_tensor,
+      self.gradInput
+   )
+   return self.gradInput
end
diff --git a/generic/ClassNLLCriterion.c b/generic/ClassNLLCriterion.c
new file mode 100644
index 0000000..d8efef7
--- /dev/null
+++ b/generic/ClassNLLCriterion.c
@@ -0,0 +1,163 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/ClassNLLCriterion.c"
+#else
+
+
+/*
+ * Lua binding:
+ *   ClassNLLCriterion_updateOutput(input, target, weights, sizeAverage,
+ *                                  output, total_weight)
+ *
+ *   input        1D tensor (one entry per class) or 2D tensor
+ *                (batch x classes).
+ *   target       torch.LongTensor of 1-based class indices, at most 1D, with
+ *                at least one index per input row.
+ *   weights      nil, or a tensor with at least one weight per class.
+ *   sizeAverage  when true, the loss is divided by the summed target weight
+ *                (skipped if that sum is zero).
+ *   output       tensor whose first element receives the loss.
+ *   total_weight tensor whose first element receives the summed weight of
+ *                the targets.
+ *
+ * Returns no values to Lua. Raises a Lua error on malformed arguments or on
+ * a target index outside [1, n_classes].
+ */
+static int nn_(ClassNLLCriterion_updateOutput)(lua_State *L)
+{
+  THTensor *input = luaT_checkudata(L, 1, torch_Tensor);
+  THLongTensor *target = luaT_checkudata(L, 2, "torch.LongTensor");
+  THTensor *weights = NULL;
+  if (!lua_isnil(L, 3)) {
+    weights = luaT_checkudata(L, 3, torch_Tensor);
+  }
+  int sizeAverage = lua_toboolean(L, 4);
+  THTensor *output = luaT_checkudata(L, 5, torch_Tensor);
+  THTensor *total_weight = luaT_checkudata(L, 6, torch_Tensor);
+
+  int n_dims = THTensor_(nDimension)(input);
+  if (THLongTensor_nDimension(target) > 1) {
+    THError("multi-target not supported");
+  }
+  if (n_dims < 1 || n_dims > 2) {
+    THError("input tensor should be 1D or 2D");
+  }
+
+  long n_classes = THTensor_(size)(input, n_dims - 1);
+  long n_samples = (n_dims == 2) ? THTensor_(size)(input, 0) : 1;
+
+  /* Everything below indexes raw buffers, so reject shapes that would make
+   * those reads or writes run past the end of a tensor. */
+  luaL_argcheck(L, THLongTensor_nElement(target) >= n_samples, 2,
+                "target must provide one class index per input row");
+  luaL_argcheck(L, !weights || THTensor_(nElement)(weights) >= n_classes, 3,
+                "weights must provide one entry per class");
+  luaL_argcheck(L, THTensor_(nElement)(output) >= 1, 5,
+                "output must hold at least one element");
+  luaL_argcheck(L, THTensor_(nElement)(total_weight) >= 1, 6,
+                "total_weight must hold at least one element");
+
+  input = THTensor_(newContiguous)(input);
+  target = THLongTensor_newContiguous(target);
+  weights = weights ? THTensor_(newContiguous)(weights) : NULL;
+
+  real *input_data = THTensor_(data)(input);
+  long *target_data = THLongTensor_data(target);
+  real *weights_data = weights ? THTensor_(data)(weights) : NULL;
+  real *output_data = THTensor_(data)(output);
+  real *total_weight_data = THTensor_(data)(total_weight);
+
+  /* Set instead of raising immediately, so the contiguous copies above are
+   * released before the error unwinds out of this function. */
+  int bad_target = 0;
+
+  output_data[0] = total_weight_data[0] = 0.0;
+
+  if (n_dims == 1) {
+    /* Keep the index as long: narrowing to int first could wrap a huge
+     * value back into the valid range. */
+    long cur_target = target_data[0] - 1;
+    if (cur_target < 0 || cur_target >= n_classes) {
+      bad_target = 1;
+    } else {
+      total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f;
+      output_data[0] = -input_data[cur_target] * total_weight_data[0];
+    }
+  } else {
+    long i;
+    for (i = 0; i < n_samples; i++) {
+      long cur_target = target_data[i] - 1;
+      if (cur_target < 0 || cur_target >= n_classes) {
+        bad_target = 1;
+        break;
+      }
+
+      real cur_weight = weights ? weights_data[cur_target] : 1.0f;
+      total_weight_data[0] += cur_weight;
+      output_data[0] -= input_data[i * n_classes + cur_target] * cur_weight;
+    }
+  }
+
+  if (!bad_target && sizeAverage && total_weight_data[0]) {
+    output_data[0] /= total_weight_data[0];
+  }
+
+  if (weights) {
+    THTensor_(free)(weights);
+  }
+  THTensor_(free)(input);
+  THLongTensor_free(target);
+
+  if (bad_target) {
+    THError("target class index out of range");
+  }
+
+  return 0;
+}
+
+/*
+ * Lua binding:
+ *   ClassNLLCriterion_updateGradInput(input, target, weights, sizeAverage,
+ *                                     total_weight, gradInput)
+ *
+ *   input        1D tensor (one entry per class) or 2D tensor
+ *                (batch x classes); only its shape is used.
+ *   target       torch.LongTensor of 1-based class indices, at most 1D, with
+ *                at least one index per input row.
+ *   weights      nil, or a tensor with at least one weight per class.
+ *   sizeAverage  when true, gradients are normalised by the total weight.
+ *   total_weight tensor whose first element is the summed target weight
+ *                (as written by ClassNLLCriterion_updateOutput).
+ *   gradInput    contiguous tensor with as many elements as input. Only the
+ *                entries selected by the targets are written, so the caller
+ *                is expected to have zeroed it.
+ *
+ * Returns no values to Lua. Does nothing when the total weight is not
+ * positive. Raises a Lua error on malformed arguments or on a target index
+ * outside [1, n_classes].
+ */
+static int nn_(ClassNLLCriterion_updateGradInput)(lua_State *L)
+{
+  THTensor *input = luaT_checkudata(L, 1, torch_Tensor);
+  THLongTensor *target = luaT_checkudata(L, 2, "torch.LongTensor");
+  THTensor *weights = NULL;
+  if (!lua_isnil(L, 3)) {
+    weights = luaT_checkudata(L, 3, torch_Tensor);
+  }
+
+  int n_dims = THTensor_(nDimension)(input);
+  if (n_dims < 1) {
+    THError("input tensor should be 1D or 2D");
+  }
+  long n_classes = THTensor_(size)(input, n_dims - 1);
+
+  int sizeAverage = lua_toboolean(L, 4);
+  THTensor *total_weight = luaT_checkudata(L, 5, torch_Tensor);
+  THTensor *gradInput = luaT_checkudata(L, 6, torch_Tensor);
+  luaL_argcheck(
+    L,
+    THTensor_(isContiguous)(gradInput),
+    6,
+    "gradInput must be contiguous"
+  );
+  luaL_argcheck(L, THTensor_(nElement)(total_weight) >= 1, 5,
+                "total_weight must hold at least one element");
+
+  real *total_weight_data = THTensor_(data)(total_weight);
+
+  /* Nothing contributed to the loss (this test is also true for NaN), so
+   * gradInput is left exactly as the caller prepared it. */
+  if (!(*total_weight_data > 0)) {
+    return 0;
+  }
+
+  if (THLongTensor_nDimension(target) > 1) {
+    THError("multi-target not supported");
+  }
+
+  if (n_dims > 2) {
+    THError("input tensor should be 1D or 2D");
+  }
+
+  long n_samples = (n_dims == 2) ? THTensor_(size)(input, 0) : 1;
+
+  /* Everything below indexes raw buffers, so reject shapes that would make
+   * those reads or writes run past the end of a tensor. */
+  luaL_argcheck(L, THLongTensor_nElement(target) >= n_samples, 2,
+                "target must provide one class index per input row");
+  luaL_argcheck(L, !weights || THTensor_(nElement)(weights) >= n_classes, 3,
+                "weights must provide one entry per class");
+  luaL_argcheck(L,
+                THTensor_(nElement)(gradInput) == THTensor_(nElement)(input),
+                6, "gradInput must have as many elements as input");
+
+  target = THLongTensor_newContiguous(target);
+  weights = weights ? THTensor_(newContiguous)(weights) : NULL;
+
+  long *target_data = THLongTensor_data(target);
+  real *weights_data = weights ? THTensor_(data)(weights) : NULL;
+  real *gradInput_data = THTensor_(data)(gradInput);
+
+  /* Set instead of raising immediately, so the contiguous copies above are
+   * released before the error unwinds out of this function. */
+  int bad_target = 0;
+
+  if (n_dims == 1) {
+    /* Keep the index as long: narrowing to int first could wrap a huge
+     * value back into the valid range. */
+    long cur_target = target_data[0] - 1;
+    if (cur_target < 0 || cur_target >= n_classes) {
+      bad_target = 1;
+    } else {
+      /* With averaging, the single sample's weight cancels against the
+       * total weight, leaving -1. */
+      gradInput_data[cur_target] =
+        (!sizeAverage && weights) ? -weights_data[cur_target] : -1;
+    }
+  } else {
+    /* Loop-invariant: whether each gradient is divided by the total weight. */
+    int normalize = sizeAverage && *total_weight_data;
+    long i;
+    for (i = 0; i < n_samples; i++) {
+      long cur_target = target_data[i] - 1;
+      if (cur_target < 0 || cur_target >= n_classes) {
+        bad_target = 1;
+        break;
+      }
+
+      real grad = -(weights ? weights_data[cur_target] : 1.0f);
+      if (normalize) {
+        grad /= *total_weight_data;
+      }
+      gradInput_data[i * n_classes + cur_target] = grad;
+    }
+  }
+
+  THLongTensor_free(target);
+  if (weights) {
+    THTensor_(free)(weights);
+  }
+
+  if (bad_target) {
+    THError("target class index out of range");
+  }
+
+  return 0;
+}
+
+/* Lua-visible name -> C function pairs for this criterion. The {NULL, NULL}
+ * entry marks the end of the list. */
+static const struct luaL_Reg nn_(ClassNLLCriterion__) [] = {
+ {"ClassNLLCriterion_updateOutput", nn_(ClassNLLCriterion_updateOutput)},
+ {"ClassNLLCriterion_updateGradInput", nn_(ClassNLLCriterion_updateGradInput)},
+ {NULL, NULL}
+};
+
+/* Registers the functions listed in nn_(ClassNLLCriterion__) for the tensor
+ * type this generic file is being instantiated for. init.c calls the Float
+ * and Double instantiations from luaopen_libnn. */
+static void nn_(ClassNLLCriterion_init)(lua_State *L)
+{
+ luaT_pushmetatable(L, torch_Tensor);
+ /* NOTE(review): presumably this installs the functions in the "nn" field of
+  * the metatable pushed above, since the Lua side calls them as
+  * input.nn.ClassNLLCriterion_* -- confirm against luaT_registeratname. */
+ luaT_registeratname(L, nn_(ClassNLLCriterion__), "nn");
+ /* Remove the one value still on the stack so it is left balanced. */
+ lua_pop(L,1);
+}
+
+#endif
diff --git a/init.c b/init.c
index 7cdae69..0fc7208 100644
--- a/init.c
+++ b/init.c
@@ -47,6 +47,9 @@
#include "generic/SoftMax.c"
#include "THGenerateFloatTypes.h"
+#include "generic/ClassNLLCriterion.c"
+#include "THGenerateFloatTypes.h"
+
#include "generic/MSECriterion.c"
#include "THGenerateFloatTypes.h"
@@ -134,6 +137,7 @@ int luaopen_libnn(lua_State *L)
nn_FloatSquare_init(L);
nn_FloatHardTanh_init(L);
nn_FloatLogSoftMax_init(L);
+ nn_FloatClassNLLCriterion_init(L);
nn_FloatMSECriterion_init(L);
nn_FloatMarginCriterion_init(L);
nn_FloatAbsCriterion_init(L);
@@ -174,6 +178,7 @@ int luaopen_libnn(lua_State *L)
nn_DoubleSquare_init(L);
nn_DoubleHardTanh_init(L);
nn_DoubleLogSoftMax_init(L);
+ nn_DoubleClassNLLCriterion_init(L);
nn_DoubleMSECriterion_init(L);
nn_DoubleMarginCriterion_init(L);
nn_DoubleAbsCriterion_init(L);