diff options
-rw-r--r-- | python/sentencepiece.i | 41 | ||||
-rw-r--r-- | python/sentencepiece.py | 30 | ||||
-rw-r--r-- | python/sentencepiece_wrap.cxx | 593 | ||||
-rwxr-xr-x | python/test/sentencepiece_test.py | 13 | ||||
-rw-r--r-- | src/sentencepiece_processor.cc | 35 | ||||
-rw-r--r-- | src/sentencepiece_processor.h | 24 | ||||
-rw-r--r-- | src/sentencepiece_processor_test.cc | 18 |
7 files changed, 750 insertions, 4 deletions
diff --git a/python/sentencepiece.i b/python/sentencepiece.i index 0cad549..3861240 100644 --- a/python/sentencepiece.i +++ b/python/sentencepiece.i @@ -76,6 +76,14 @@ PyObject* MakePyOutputString(const std::string& output, #endif } +PyObject* MakePyOutputBytes(const std::string& output) { +#if PY_VERSION_HEX >= 0x03000000 + return PyBytes_FromStringAndSize(output.data(), output.size()); +#else + return PyString_FromStringAndSize(output.data(), output.size()); +#endif +} + int ToSwigError(sentencepiece::util::error::Code code) { switch (code) { case sentencepiece::util::error::NOT_FOUND: @@ -207,6 +215,29 @@ int ToSwigError(sentencepiece::util::error::Code code) { return $self->DecodeIds(input); } + util::bytes encode_as_serialized_proto(util::min_string_view input) const { + return $self->EncodeAsSerializedProto(input); + } + + util::bytes sample_encode_as_serialized_proto(util::min_string_view input, + int nbest_size, float alpha) const { + return $self->SampleEncodeAsSerializedProto(input, nbest_size, alpha); + } + + util::bytes nbest_encode_as_serialized_proto(util::min_string_view input, + int nbest_size) const { + return $self->NBestEncodeAsSerializedProto(input, nbest_size); + } + + util::bytes decode_pieces_as_serialized_proto( + const std::vector<std::string> &pieces) const { + return $self->DecodePiecesAsSerializedProto(pieces); + } + + util::bytes decode_ids_as_serialized_proto(const std::vector<int> &ids) const { + return $self->DecodeIdsAsSerializedProto(ids); + } + int get_piece_size() const { return $self->GetPieceSize(); } @@ -287,6 +318,15 @@ int ToSwigError(sentencepiece::util::error::Code code) { $result = MakePyOutputString($1, input_type); } +%typemap(out) const std::string& { + PyObject *input_type = resultobj; + $result = MakePyOutputString(*$1, input_type); +} + +%typemap(out) sentencepiece::util::bytes { + $result = MakePyOutputBytes($1); +} + %typemap(out) sentencepiece::util::Status { if (!$1.ok()) { SWIG_exception(ToSwigError($1.code()), 
$1.ToString().c_str()); @@ -316,7 +356,6 @@ int ToSwigError(sentencepiece::util::error::Code code) { $1 = sentencepiece::util::min_string_view(ustring.data(), ustring.size()); } - %typemap(in) const std::vector<std::string>& { std::vector<std::string> *out = nullptr; if (PyList_Check($input)) { diff --git a/python/sentencepiece.py b/python/sentencepiece.py index 3320b97..0b28e5a 100644 --- a/python/sentencepiece.py +++ b/python/sentencepiece.py @@ -162,6 +162,21 @@ class SentencePieceProcessor(_object): def DecodeIds(self, ids): return _sentencepiece.SentencePieceProcessor_DecodeIds(self, ids) + def EncodeAsSerializedProto(self, input): + return _sentencepiece.SentencePieceProcessor_EncodeAsSerializedProto(self, input) + + def SampleEncodeAsSerializedProto(self, input, nbest_size, alpha): + return _sentencepiece.SentencePieceProcessor_SampleEncodeAsSerializedProto(self, input, nbest_size, alpha) + + def NBestEncodeAsSerializedProto(self, input, nbest_size): + return _sentencepiece.SentencePieceProcessor_NBestEncodeAsSerializedProto(self, input, nbest_size) + + def DecodePiecesAsSerializedProto(self, pieces): + return _sentencepiece.SentencePieceProcessor_DecodePiecesAsSerializedProto(self, pieces) + + def DecodeIdsAsSerializedProto(self, ids): + return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProto(self, ids) + def GetPieceSize(self): return _sentencepiece.SentencePieceProcessor_GetPieceSize(self) @@ -240,6 +255,21 @@ class SentencePieceProcessor(_object): def decode_ids(self, input): return _sentencepiece.SentencePieceProcessor_decode_ids(self, input) + def encode_as_serialized_proto(self, input): + return _sentencepiece.SentencePieceProcessor_encode_as_serialized_proto(self, input) + + def sample_encode_as_serialized_proto(self, input, nbest_size, alpha): + return _sentencepiece.SentencePieceProcessor_sample_encode_as_serialized_proto(self, input, nbest_size, alpha) + + def nbest_encode_as_serialized_proto(self, input, nbest_size): + return 
_sentencepiece.SentencePieceProcessor_nbest_encode_as_serialized_proto(self, input, nbest_size) + + def decode_pieces_as_serialized_proto(self, pieces): + return _sentencepiece.SentencePieceProcessor_decode_pieces_as_serialized_proto(self, pieces) + + def decode_ids_as_serialized_proto(self, ids): + return _sentencepiece.SentencePieceProcessor_decode_ids_as_serialized_proto(self, ids) + def get_piece_size(self): return _sentencepiece.SentencePieceProcessor_get_piece_size(self) diff --git a/python/sentencepiece_wrap.cxx b/python/sentencepiece_wrap.cxx index bb258c8..422ecef 100644 --- a/python/sentencepiece_wrap.cxx +++ b/python/sentencepiece_wrap.cxx @@ -3192,6 +3192,14 @@ PyObject* MakePyOutputString(const std::string& output, #endif } +PyObject* MakePyOutputBytes(const std::string& output) { +#if PY_VERSION_HEX >= 0x03000000 + return PyBytes_FromStringAndSize(output.data(), output.size()); +#else + return PyString_FromStringAndSize(output.data(), output.size()); +#endif +} + int ToSwigError(sentencepiece::util::error::Code code) { switch (code) { case sentencepiece::util::error::NOT_FOUND: @@ -3596,6 +3604,21 @@ SWIGINTERN std::string sentencepiece_SentencePieceProcessor_decode_pieces(senten SWIGINTERN std::string sentencepiece_SentencePieceProcessor_decode_ids(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &input){ return self->DecodeIds(input); } +SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_encode_as_serialized_proto(sentencepiece::SentencePieceProcessor const *self,sentencepiece::util::min_string_view input){ + return self->EncodeAsSerializedProto(input); + } +SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_sample_encode_as_serialized_proto(sentencepiece::SentencePieceProcessor const *self,sentencepiece::util::min_string_view input,int nbest_size,float alpha){ + return self->SampleEncodeAsSerializedProto(input, nbest_size, alpha); + } +SWIGINTERN sentencepiece::util::bytes 
sentencepiece_SentencePieceProcessor_nbest_encode_as_serialized_proto(sentencepiece::SentencePieceProcessor const *self,sentencepiece::util::min_string_view input,int nbest_size){ + return self->NBestEncodeAsSerializedProto(input, nbest_size); + } +SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_decode_pieces_as_serialized_proto(sentencepiece::SentencePieceProcessor const *self,std::vector< std::string > const &pieces){ + return self->DecodePiecesAsSerializedProto(pieces); + } +SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_decode_ids_as_serialized_proto(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){ + return self->DecodeIdsAsSerializedProto(ids); + } SWIGINTERN int sentencepiece_SentencePieceProcessor_get_piece_size(sentencepiece::SentencePieceProcessor const *self){ return self->GetPieceSize(); } @@ -4521,6 +4544,283 @@ fail: } +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_EncodeAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + sentencepiece::util::min_string_view arg2 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_EncodeAsSerializedProto",&obj0,&obj1)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_EncodeAsSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(obj1); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + 
resultobj = ustring.input_type(); + arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size()); + } + { + try { + result = ((sentencepiece::SentencePieceProcessor const *)arg1)->EncodeAsSerializedProto(arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + return resultobj; +fail: + return NULL; +} + + +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_SampleEncodeAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + sentencepiece::util::min_string_view arg2 ; + int arg3 ; + float arg4 ; + void *argp1 = 0 ; + int res1 = 0 ; + int val3 ; + int ecode3 = 0 ; + float val4 ; + int ecode4 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + PyObject * obj2 = 0 ; + PyObject * obj3 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OOOO:SentencePieceProcessor_SampleEncodeAsSerializedProto",&obj0,&obj1,&obj2,&obj3)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_SampleEncodeAsSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(obj1); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size()); + } + ecode3 = SWIG_AsVal_int(obj2, &val3); + if (!SWIG_IsOK(ecode3)) { + SWIG_exception_fail(SWIG_ArgError(ecode3), "in method '" 
"SentencePieceProcessor_SampleEncodeAsSerializedProto" "', argument " "3"" of type '" "int""'"); + } + arg3 = static_cast< int >(val3); + ecode4 = SWIG_AsVal_float(obj3, &val4); + if (!SWIG_IsOK(ecode4)) { + SWIG_exception_fail(SWIG_ArgError(ecode4), "in method '" "SentencePieceProcessor_SampleEncodeAsSerializedProto" "', argument " "4"" of type '" "float""'"); + } + arg4 = static_cast< float >(val4); + { + try { + result = ((sentencepiece::SentencePieceProcessor const *)arg1)->SampleEncodeAsSerializedProto(arg2,arg3,arg4); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + return resultobj; +fail: + return NULL; +} + + +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_NBestEncodeAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + sentencepiece::util::min_string_view arg2 ; + int arg3 ; + void *argp1 = 0 ; + int res1 = 0 ; + int val3 ; + int ecode3 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + PyObject * obj2 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OOO:SentencePieceProcessor_NBestEncodeAsSerializedProto",&obj0,&obj1,&obj2)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_NBestEncodeAsSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(obj1); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 
= sentencepiece::util::min_string_view(ustring.data(), ustring.size()); + } + ecode3 = SWIG_AsVal_int(obj2, &val3); + if (!SWIG_IsOK(ecode3)) { + SWIG_exception_fail(SWIG_ArgError(ecode3), "in method '" "SentencePieceProcessor_NBestEncodeAsSerializedProto" "', argument " "3"" of type '" "int""'"); + } + arg3 = static_cast< int >(val3); + { + try { + result = ((sentencepiece::SentencePieceProcessor const *)arg1)->NBestEncodeAsSerializedProto(arg2,arg3); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + return resultobj; +fail: + return NULL; +} + + +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodePiecesAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + std::vector< std::string > *arg2 = 0 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_DecodePiecesAsSerializedProto",&obj0,&obj1)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodePiecesAsSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + std::vector<std::string> *out = nullptr; + if (PyList_Check(obj1)) { + const size_t size = PyList_Size(obj1); + out = new std::vector<std::string>(size); + for (size_t i = 0; i < size; ++i) { + const PyInputString ustring(PyList_GetItem(obj1, i)); + if (ustring.IsAvalable()) { + (*out)[i] = std::string(ustring.data(), 
ustring.size()); + } else { + PyErr_SetString(PyExc_TypeError, "list must contain strings"); + SWIG_fail; + } + resultobj = ustring.input_type(); + } + } else { + PyErr_SetString(PyExc_TypeError, "not a list"); + SWIG_fail; + } + arg2 = out; + } + { + try { + result = ((sentencepiece::SentencePieceProcessor const *)arg1)->DecodePiecesAsSerializedProto((std::vector< std::string > const &)*arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + { + delete arg2; + } + return resultobj; +fail: + { + delete arg2; + } + return NULL; +} + + +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + std::vector< int > *arg2 = 0 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_DecodeIdsAsSerializedProto",&obj0,&obj1)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsAsSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + std::vector<int> *out = nullptr; + if (PyList_Check(obj1)) { + const size_t size = PyList_Size(obj1); + out = new std::vector<int>(size); + for (size_t i = 0; i < size; ++i) { + PyObject *o = PyList_GetItem(obj1, i); + if (PyInt_Check(o)) { + (*out)[i] = static_cast<int>(PyInt_AsLong(o)); + } else { + PyErr_SetString(PyExc_TypeError,"list must contain 
integers"); + SWIG_fail; + } + } + } else { + PyErr_SetString(PyExc_TypeError,"not a list"); + SWIG_fail; + } + arg2 = out; + } + { + try { + result = ((sentencepiece::SentencePieceProcessor const *)arg1)->DecodeIdsAsSerializedProto((std::vector< int > const &)*arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + { + delete arg2; + } + return resultobj; +fail: + { + delete arg2; + } + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceProcessor_GetPieceSize(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -4602,7 +4902,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor_IdToPiece(PyObject *SWIGUNUSED int ecode2 = 0 ; PyObject * obj0 = 0 ; PyObject * obj1 = 0 ; - std::string result; + std::string *result = 0 ; if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_IdToPiece",&obj0,&obj1)) SWIG_fail; res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); @@ -4617,7 +4917,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor_IdToPiece(PyObject *SWIGUNUSED arg2 = static_cast< int >(val2); { try { - result = ((sentencepiece::SentencePieceProcessor const *)arg1)->IdToPiece(arg2); + result = (std::string *) &((sentencepiece::SentencePieceProcessor const *)arg1)->IdToPiece(arg2); ReleaseResultObject(resultobj); } catch (const sentencepiece::util::Status &status) { @@ -4626,7 +4926,7 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor_IdToPiece(PyObject *SWIGUNUSED } { PyObject *input_type = resultobj; - resultobj = MakePyOutputString(result, input_type); + resultobj = MakePyOutputString(*result, input_type); } return resultobj; fail: @@ -5712,6 +6012,283 @@ fail: } +SWIGINTERN PyObject 
*_wrap_SentencePieceProcessor_encode_as_serialized_proto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + sentencepiece::util::min_string_view arg2 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_encode_as_serialized_proto",&obj0,&obj1)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_encode_as_serialized_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(obj1); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size()); + } + { + try { + result = sentencepiece_SentencePieceProcessor_encode_as_serialized_proto((sentencepiece::SentencePieceProcessor const *)arg1,arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + return resultobj; +fail: + return NULL; +} + + +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_sample_encode_as_serialized_proto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + sentencepiece::util::min_string_view arg2 ; + int arg3 ; + float arg4 ; + void *argp1 = 0 ; + int res1 = 0 ; + int val3 ; + int ecode3 = 0 ; + float 
val4 ; + int ecode4 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + PyObject * obj2 = 0 ; + PyObject * obj3 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OOOO:SentencePieceProcessor_sample_encode_as_serialized_proto",&obj0,&obj1,&obj2,&obj3)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_sample_encode_as_serialized_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(obj1); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size()); + } + ecode3 = SWIG_AsVal_int(obj2, &val3); + if (!SWIG_IsOK(ecode3)) { + SWIG_exception_fail(SWIG_ArgError(ecode3), "in method '" "SentencePieceProcessor_sample_encode_as_serialized_proto" "', argument " "3"" of type '" "int""'"); + } + arg3 = static_cast< int >(val3); + ecode4 = SWIG_AsVal_float(obj3, &val4); + if (!SWIG_IsOK(ecode4)) { + SWIG_exception_fail(SWIG_ArgError(ecode4), "in method '" "SentencePieceProcessor_sample_encode_as_serialized_proto" "', argument " "4"" of type '" "float""'"); + } + arg4 = static_cast< float >(val4); + { + try { + result = sentencepiece_SentencePieceProcessor_sample_encode_as_serialized_proto((sentencepiece::SentencePieceProcessor const *)arg1,arg2,arg3,arg4); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + return resultobj; +fail: + return NULL; +} + + +SWIGINTERN PyObject 
*_wrap_SentencePieceProcessor_nbest_encode_as_serialized_proto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + sentencepiece::util::min_string_view arg2 ; + int arg3 ; + void *argp1 = 0 ; + int res1 = 0 ; + int val3 ; + int ecode3 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + PyObject * obj2 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OOO:SentencePieceProcessor_nbest_encode_as_serialized_proto",&obj0,&obj1,&obj2)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_nbest_encode_as_serialized_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(obj1); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 = sentencepiece::util::min_string_view(ustring.data(), ustring.size()); + } + ecode3 = SWIG_AsVal_int(obj2, &val3); + if (!SWIG_IsOK(ecode3)) { + SWIG_exception_fail(SWIG_ArgError(ecode3), "in method '" "SentencePieceProcessor_nbest_encode_as_serialized_proto" "', argument " "3"" of type '" "int""'"); + } + arg3 = static_cast< int >(val3); + { + try { + result = sentencepiece_SentencePieceProcessor_nbest_encode_as_serialized_proto((sentencepiece::SentencePieceProcessor const *)arg1,arg2,arg3); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + return resultobj; +fail: + return NULL; +} + + +SWIGINTERN PyObject 
*_wrap_SentencePieceProcessor_decode_pieces_as_serialized_proto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + std::vector< std::string > *arg2 = 0 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_decode_pieces_as_serialized_proto",&obj0,&obj1)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_decode_pieces_as_serialized_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + std::vector<std::string> *out = nullptr; + if (PyList_Check(obj1)) { + const size_t size = PyList_Size(obj1); + out = new std::vector<std::string>(size); + for (size_t i = 0; i < size; ++i) { + const PyInputString ustring(PyList_GetItem(obj1, i)); + if (ustring.IsAvalable()) { + (*out)[i] = std::string(ustring.data(), ustring.size()); + } else { + PyErr_SetString(PyExc_TypeError, "list must contain strings"); + SWIG_fail; + } + resultobj = ustring.input_type(); + } + } else { + PyErr_SetString(PyExc_TypeError, "not a list"); + SWIG_fail; + } + arg2 = out; + } + { + try { + result = sentencepiece_SentencePieceProcessor_decode_pieces_as_serialized_proto((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< std::string > const &)*arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + { + delete arg2; + } + return resultobj; +fail: + { + delete arg2; + } + return NULL; 
+} + + +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_decode_ids_as_serialized_proto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + std::vector< int > *arg2 = 0 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject * obj0 = 0 ; + PyObject * obj1 = 0 ; + sentencepiece::util::bytes result; + + if (!PyArg_ParseTuple(args,(char *)"OO:SentencePieceProcessor_decode_ids_as_serialized_proto",&obj0,&obj1)) SWIG_fail; + res1 = SWIG_ConvertPtr(obj0, &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_decode_ids_as_serialized_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + std::vector<int> *out = nullptr; + if (PyList_Check(obj1)) { + const size_t size = PyList_Size(obj1); + out = new std::vector<int>(size); + for (size_t i = 0; i < size; ++i) { + PyObject *o = PyList_GetItem(obj1, i); + if (PyInt_Check(o)) { + (*out)[i] = static_cast<int>(PyInt_AsLong(o)); + } else { + PyErr_SetString(PyExc_TypeError,"list must contain integers"); + SWIG_fail; + } + } + } else { + PyErr_SetString(PyExc_TypeError,"not a list"); + SWIG_fail; + } + arg2 = out; + } + { + try { + result = sentencepiece_SentencePieceProcessor_decode_ids_as_serialized_proto((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< int > const &)*arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + { + delete arg2; + } + return resultobj; +fail: + { + delete arg2; + } + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceProcessor_get_piece_size(PyObject 
*SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -6160,6 +6737,11 @@ static PyMethodDef SwigMethods[] = { { (char *)"SentencePieceProcessor_SampleEncodeAsIds", _wrap_SentencePieceProcessor_SampleEncodeAsIds, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_DecodePieces", _wrap_SentencePieceProcessor_DecodePieces, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_DecodeIds", _wrap_SentencePieceProcessor_DecodeIds, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_EncodeAsSerializedProto", _wrap_SentencePieceProcessor_EncodeAsSerializedProto, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_SampleEncodeAsSerializedProto", _wrap_SentencePieceProcessor_SampleEncodeAsSerializedProto, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_NBestEncodeAsSerializedProto", _wrap_SentencePieceProcessor_NBestEncodeAsSerializedProto, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_DecodePiecesAsSerializedProto", _wrap_SentencePieceProcessor_DecodePiecesAsSerializedProto, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_DecodeIdsAsSerializedProto", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProto, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_GetPieceSize", _wrap_SentencePieceProcessor_GetPieceSize, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_PieceToId", _wrap_SentencePieceProcessor_PieceToId, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_IdToPiece", _wrap_SentencePieceProcessor_IdToPiece, METH_VARARGS, NULL}, @@ -6186,6 +6768,11 @@ static PyMethodDef SwigMethods[] = { { (char *)"SentencePieceProcessor_sample_encode_as_ids", _wrap_SentencePieceProcessor_sample_encode_as_ids, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_decode_pieces", _wrap_SentencePieceProcessor_decode_pieces, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_decode_ids", 
_wrap_SentencePieceProcessor_decode_ids, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_encode_as_serialized_proto", _wrap_SentencePieceProcessor_encode_as_serialized_proto, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_sample_encode_as_serialized_proto", _wrap_SentencePieceProcessor_sample_encode_as_serialized_proto, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_nbest_encode_as_serialized_proto", _wrap_SentencePieceProcessor_nbest_encode_as_serialized_proto, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_decode_pieces_as_serialized_proto", _wrap_SentencePieceProcessor_decode_pieces_as_serialized_proto, METH_VARARGS, NULL}, + { (char *)"SentencePieceProcessor_decode_ids_as_serialized_proto", _wrap_SentencePieceProcessor_decode_ids_as_serialized_proto, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_get_piece_size", _wrap_SentencePieceProcessor_get_piece_size, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_piece_to_id", _wrap_SentencePieceProcessor_piece_to_id, METH_VARARGS, NULL}, { (char *)"SentencePieceProcessor_id_to_piece", _wrap_SentencePieceProcessor_id_to_piece, METH_VARARGS, NULL}, diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 39d8505..7c9420d 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -208,6 +208,19 @@ class TestSentencepieceProcessor(unittest.TestCase): sp.decode_pieces(sp.encode_as_pieces(line)) sp.decode_ids(sp.encode_as_ids(line)) + def test_serialized_proto(self): + text = u'I saw a girl with a telescope.' 
+ self.assertNotEqual('', self.sp_.EncodeAsSerializedProto(text)) + self.assertNotEqual('', self.sp_.SampleEncodeAsSerializedProto(text, 10, 0.2)) + self.assertNotEqual('', self.sp_.NBestEncodeAsSerializedProto(text, 10)) + self.assertNotEqual('', self.sp_.DecodePiecesAsSerializedProto(['foo', 'bar'])) + self.assertNotEqual('', self.sp_.DecodeIdsAsSerializedProto([20, 30])) + self.assertNotEqual('', self.sp_.encode_as_serialized_proto(text)) + self.assertNotEqual('', self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)) + self.assertNotEqual('', self.sp_.nbest_encode_as_serialized_proto(text, 10)) + self.assertNotEqual('', self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])) + self.assertNotEqual('', self.sp_.decode_ids_as_serialized_proto([20, 30])) + def suite(): suite = unittest.TestSuite() diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 8c9c208..43dd11c 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -512,6 +512,41 @@ util::Status SentencePieceProcessor::Decode(const std::vector<int> &ids, return Decode(pieces, spt); } +util::bytes SentencePieceProcessor::EncodeAsSerializedProto( + util::min_string_view input) const { + SentencePieceText spt; + if (!Encode(input, &spt).ok()) return ""; + return spt.SerializeAsString(); +} + +util::bytes SentencePieceProcessor::SampleEncodeAsSerializedProto( + util::min_string_view input, int nbest_size, float alpha) const { + SentencePieceText spt; + if (!SampleEncode(input, nbest_size, alpha, &spt).ok()) return ""; + return spt.SerializeAsString(); +} + +util::bytes SentencePieceProcessor::NBestEncodeAsSerializedProto( + util::min_string_view input, int nbest_size) const { + NBestSentencePieceText spt; + if (!NBestEncode(input, nbest_size, &spt).ok()) return ""; + return spt.SerializeAsString(); +} + +util::bytes SentencePieceProcessor::DecodePiecesAsSerializedProto( + const std::vector<std::string> &pieces) const { + SentencePieceText spt; + if 
(!Decode(pieces, &spt).ok()) return ""; + return spt.SerializeAsString(); +} + +util::bytes SentencePieceProcessor::DecodeIdsAsSerializedProto( + const std::vector<int> &ids) const { + SentencePieceText spt; + if (!Decode(ids, &spt).ok()) return ""; + return spt.SerializeAsString(); +} + #define CHECK_STATUS_OR_RETURN_DEFAULT(value) \ if (!status().ok()) { \ LOG(ERROR) << status().error_message() << "\nReturns default value " \ diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 61da691..ee5cd17 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -158,6 +158,11 @@ class min_string_view { const char *ptr_ = nullptr; size_t length_ = 0; }; + +// Redefine std::string for serialized_proto interface as Python's string is +// a Unicode string. We can enforce the return value to be raw byte sequence +// with SWIG's typemap. +using bytes = std::string; } // namespace util class SentencePieceProcessor { @@ -357,6 +362,25 @@ class SentencePieceProcessor { #undef DEFINE_SPP_DIRECT_FUNC_IMPL + // They are used in Python interface. Returns serialized proto. + // In python module, we can get access to the full Proto after + // deserializing the returned byte sequence. + virtual util::bytes EncodeAsSerializedProto( + util::min_string_view input) const; + + virtual util::bytes SampleEncodeAsSerializedProto(util::min_string_view input, + int nbest_size, + float alpha) const; + + virtual util::bytes NBestEncodeAsSerializedProto(util::min_string_view input, + int nbest_size) const; + + virtual util::bytes DecodePiecesAsSerializedProto( + const std::vector<std::string> &pieces) const; + + virtual util::bytes DecodeIdsAsSerializedProto( + const std::vector<int> &ids) const; + ////////////////////////////////////////////////////////////// // Vocabulary management methods. 
// diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 4fbbab8..af1f6f1 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -149,6 +149,10 @@ TEST(SentencepieceProcessorTest, EncodeTest) { EXPECT_EQ(result[i].first, spt.pieces(i).piece()); } + SentencePieceText spt2; + EXPECT_TRUE(spt2.ParseFromString(sp.EncodeAsSerializedProto("ABC DEF"))); + EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString()); + EXPECT_EQ("ABC", spt.pieces(0).surface()); EXPECT_EQ(" DE", spt.pieces(1).surface()); EXPECT_EQ("F", spt.pieces(2).surface()); @@ -369,6 +373,11 @@ TEST(SentencepieceProcessorTest, NBestEncodeTest) { EXPECT_EQ(result[1].first[i].first, spt.nbests(1).pieces(i).piece()); } + NBestSentencePieceText spt2; + EXPECT_TRUE( + spt2.ParseFromString(sp.NBestEncodeAsSerializedProto("ABC DEF", 2))); + EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString()); + auto mock_empty = MakeUnique<MockModel>(); mock_empty->SetNBestEncodeResult(kInput, {}); sp.SetModel(std::move(mock_empty)); @@ -414,6 +423,11 @@ TEST(SentencepieceProcessorTest, SampleEncodeTest) { EXPECT_EQ(result[i].second, spt.pieces(i).id()); } + SentencePieceText spt2; + EXPECT_TRUE(spt2.ParseFromString( + sp.SampleEncodeAsSerializedProto("ABC DEF", -1, 0.5))); + EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString()); + EXPECT_NOT_OK(sp.SampleEncode("ABC DEF", 1024, 0.5, &output)); EXPECT_OK(sp.SampleEncode("ABC DEF", 0, 0.5, &output)); EXPECT_OK(sp.SampleEncode("ABC DEF", 1, 0.5, &output)); @@ -517,6 +531,10 @@ TEST(SentencepieceProcessorTest, DecodeTest) { EXPECT_EQ(16, spt.pieces(6).end()); EXPECT_EQ(16, spt.pieces(7).begin()); EXPECT_EQ(16, spt.pieces(7).end()); + + SentencePieceText spt2; + EXPECT_TRUE(spt2.ParseFromString(sp.DecodePiecesAsSerializedProto(input))); + EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString()); } // unk_surface is not defined. |