diff options
author | soumith <soumith@fb.com> | 2016-08-12 05:39:37 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-08-12 05:39:37 +0300 |
commit | b163edb51cf369060fb9605517863f8c6ed3e54a (patch) | |
tree | 21a8d4e5f8c8477b8520fa2338b4a46d11a4aad4 /torch | |
parent | 9d5b1b1026a559b0c4f6e36432db53c3dd002c24 (diff) |
fixing backward compatibility for __index__ and __new_index__
Diffstat (limited to 'torch')
-rw-r--r-- | torch/generic/Tensor.c | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/torch/generic/Tensor.c b/torch/generic/Tensor.c index 4fed963..29b90df 100644 --- a/torch/generic/Tensor.c +++ b/torch/generic/Tensor.c @@ -731,6 +731,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THLongStorage *idx = NULL; THByteTensor *mask; THCudaByteTensor *maskCuda; + THCTensor *maskCudaReal; if(lua_isnumber(L, 2)) { @@ -964,6 +965,37 @@ static int torch_Tensor_(__newindex__)(lua_State *L) { luaL_error(L,"number or tensor expected"); } + + } + else if((maskCudaReal = luaT_toudata(L, 2, torch_Tensor))) + { + maskCuda = THCudaByteTensor_new(state); + THLongStorage *maskCudaSize = THCTensor_(newSizeOf)(state, maskCudaReal); + THCudaByteTensor_resize(state, maskCuda, maskCudaSize, NULL); + THLongStorage_free(maskCudaSize); + TH_CONCAT_2(THCudaByteTensor_copyCuda, Real)(state, maskCuda, maskCudaReal); + + THCTensor *vals; + if (lua_isnumber(L, 3)) + { +#ifdef THC_REAL_IS_HALF + real value = THC_float2half((float) luaL_checknumber(L, 3)); +#else + real value = (real) luaL_checknumber(L, 3); +#endif + + THCTensor_(maskedFill)(state, tensor, maskCuda, value); + } + else if((vals = luaT_toudata(L, 3, torch_Tensor))) + { + THCTensor_(maskedCopy)(state, tensor, maskCuda, vals); + } + else + { + luaL_error(L,"number or tensor expected"); + } + + THCudaByteTensor_free(state, maskCuda); } else { @@ -980,6 +1012,7 @@ static int torch_Tensor_(__index__)(lua_State *L) THLongStorage *idx = NULL; THByteTensor *mask; THCudaByteTensor *maskCuda; + THCTensor *maskCudaReal; if(lua_isnumber(L, 2)) { @@ -1129,6 +1162,24 @@ static int torch_Tensor_(__index__)(lua_State *L) THCTensor_(maskedSelect)(state, vals, tensor, maskCuda); luaT_pushudata(L, vals, torch_Tensor); lua_pushboolean(L, 1); + + return 2; + } + else if((maskCudaReal = luaT_toudata(L, 2, torch_Tensor))) + { + maskCuda = THCudaByteTensor_new(state); + THLongStorage *maskCudaSize = THCTensor_(newSizeOf)(state, maskCudaReal); + THCudaByteTensor_resize(state, maskCuda, maskCudaSize, NULL); + THLongStorage_free(maskCudaSize); + TH_CONCAT_2(THCudaByteTensor_copyCuda, Real)(state, maskCuda, maskCudaReal); + + THCTensor *vals = THCTensor_(new)(state); + THCTensor_(maskedSelect)(state, vals, tensor, maskCuda); + luaT_pushudata(L, vals, torch_Tensor); + lua_pushboolean(L, 1); + + THCudaByteTensor_free(state, maskCuda); + return 2; } else |