Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
path: root/torch
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-08-12 05:39:37 +0300
committersoumith <soumith@fb.com>2016-08-12 05:39:37 +0300
commitb163edb51cf369060fb9605517863f8c6ed3e54a (patch)
tree21a8d4e5f8c8477b8520fa2338b4a46d11a4aad4 /torch
parent9d5b1b1026a559b0c4f6e36432db53c3dd002c24 (diff)
fixing backward compatibility for __index__ and __newindex__
Diffstat (limited to 'torch')
-rw-r--r--torch/generic/Tensor.c51
1 file changed, 51 insertions, 0 deletions
diff --git a/torch/generic/Tensor.c b/torch/generic/Tensor.c
index 4fed963..29b90df 100644
--- a/torch/generic/Tensor.c
+++ b/torch/generic/Tensor.c
@@ -731,6 +731,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
THLongStorage *idx = NULL;
THByteTensor *mask;
THCudaByteTensor *maskCuda;
+ THCTensor *maskCudaReal;
if(lua_isnumber(L, 2))
{
@@ -964,6 +965,37 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
{
luaL_error(L,"number or tensor expected");
}
+
+ }
+ else if((maskCudaReal = luaT_toudata(L, 2, torch_Tensor)))
+ {
+ maskCuda = THCudaByteTensor_new(state);
+ THLongStorage *maskCudaSize = THCTensor_(newSizeOf)(state, maskCudaReal);
+ THCudaByteTensor_resize(state, maskCuda, maskCudaSize, NULL);
+ THLongStorage_free(maskCudaSize);
+ TH_CONCAT_2(THCudaByteTensor_copyCuda, Real)(state, maskCuda, maskCudaReal);
+
+ THCTensor *vals;
+ if (lua_isnumber(L, 3))
+ {
+#ifdef THC_REAL_IS_HALF
+ real value = THC_float2half((float) luaL_checknumber(L, 3));
+#else
+ real value = (real) luaL_checknumber(L, 3);
+#endif
+
+ THCTensor_(maskedFill)(state, tensor, maskCuda, value);
+ }
+ else if((vals = luaT_toudata(L, 3, torch_Tensor)))
+ {
+ THCTensor_(maskedCopy)(state, tensor, maskCuda, vals);
+ }
+ else
+ {
+ luaL_error(L,"number or tensor expected");
+ }
+
+ THCudaByteTensor_free(state, maskCuda);
}
else
{
@@ -980,6 +1012,7 @@ static int torch_Tensor_(__index__)(lua_State *L)
THLongStorage *idx = NULL;
THByteTensor *mask;
THCudaByteTensor *maskCuda;
+ THCTensor *maskCudaReal;
if(lua_isnumber(L, 2))
{
@@ -1129,6 +1162,24 @@ static int torch_Tensor_(__index__)(lua_State *L)
THCTensor_(maskedSelect)(state, vals, tensor, maskCuda);
luaT_pushudata(L, vals, torch_Tensor);
lua_pushboolean(L, 1);
+
+ return 2;
+ }
+ else if((maskCudaReal = luaT_toudata(L, 2, torch_Tensor)))
+ {
+ maskCuda = THCudaByteTensor_new(state);
+ THLongStorage *maskCudaSize = THCTensor_(newSizeOf)(state, maskCudaReal);
+ THCudaByteTensor_resize(state, maskCuda, maskCudaSize, NULL);
+ THLongStorage_free(maskCudaSize);
+ TH_CONCAT_2(THCudaByteTensor_copyCuda, Real)(state, maskCuda, maskCudaReal);
+
+ THCTensor *vals = THCTensor_(new)(state);
+ THCTensor_(maskedSelect)(state, vals, tensor, maskCuda);
+ luaT_pushudata(L, vals, torch_Tensor);
+ lua_pushboolean(L, 1);
+
+ THCudaByteTensor_free(state, maskCuda);
+
return 2;
}
else