Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/sentencepiece.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Taku Kudo <taku@google.com> 2018-07-12 19:24:46 +0300
committer: Taku Kudo <taku@google.com> 2018-07-12 19:24:46 +0300
commit: 256c6f5bb731c567c897999e4dca35e171f3b212 (patch)
tree: 7684b61ae6a175a3e5fc6d9f3423261bae7fb6b1 /tensorflow
parent: 983c0f5aeb26d6963c3adef94b12e2ea1595dac9 (diff)
Added new API to get bos/eos/unk/pad ids
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/sentencepiece_processor_ops.cc 273
-rwxr-xr-x tensorflow/setup.py 2
-rwxr-xr-x tensorflow/test/tf_sentencepiece_test.py 26
-rwxr-xr-x tensorflow/tf_sentencepiece/_sentencepiece_processor_ops.so bin 3508648 -> 3523600 bytes
-rw-r--r-- tensorflow/tf_sentencepiece/sentencepiece_processor_ops.py 63
5 files changed, 273 insertions, 91 deletions
diff --git a/tensorflow/sentencepiece_processor_ops.cc b/tensorflow/sentencepiece_processor_ops.cc
index 5d2df57..4226e6f 100644
--- a/tensorflow/sentencepiece_processor_ops.cc
+++ b/tensorflow/sentencepiece_processor_ops.cc
@@ -1,4 +1,4 @@
-// Copyright 2016 Google Inc.
+// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
+#include <mutex>
#include <string>
+#include <unordered_map>
#include <vector>
#include "sentencepiece_processor.h"
@@ -20,12 +22,15 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/hash/hash.h"
typedef int int32;
typedef long long int int64;
+typedef unsigned long long int uint64;
namespace sentencepiece {
using ::tensorflow::DEVICE_CPU;
+using ::tensorflow::Hash64;
using ::tensorflow::OpKernel;
using ::tensorflow::OpKernelConstruction;
using ::tensorflow::OpKernelContext;
@@ -45,90 +50,124 @@ namespace {
::tensorflow::string(s.error_message()));
}
-// A factory function to initialize SentencePieceProcessor with
-// OpKernelConstruction `context`.
-enum InitType { GENERAL, ENCODE, DECODE }; // purpose of processor.
+// Global cache to reuse SentencePieceProcessor with the same
+// model file or model proto. The instance is managed with shared_ptr so
+// the instance is deleted when no client is using it (refcount is zero).
+class SentencePieceProcessorCache {
+ public:
+ std::shared_ptr<SentencePieceProcessor> get(
+ const std::string key, bool is_proto,
+ sentencepiece::util::Status* status) {
+ std::lock_guard<std::mutex> l(mutex_);
-void InitializeModel(OpKernelConstruction* context,
- SentencePieceProcessor* sentencepiece_processor,
- InitType type) {
- std::string model_file_attr, model_proto_attr;
- OP_REQUIRES_OK(context, context->GetAttr("model_file", &model_file_attr));
- OP_REQUIRES_OK(context, context->GetAttr("model_proto", &model_proto_attr));
+ const uint64 fp = Hash64(key.data(), key.size());
+ auto sp = data_[fp].lock();
- if (!model_file_attr.empty()) {
- OP_REQUIRES(
- context, model_proto_attr.empty(),
- ::tensorflow::errors::InvalidArgument(
- "`model_proto` must be empty when `model_file` is specified."));
- OP_REQUIRES_OK(context,
- ToTFStatus(sentencepiece_processor->Load(model_file_attr)));
- } else {
- // Loads serialized sentencepiece model proto to enable embedding the
- // relatively small sentencepiece model proto into the tensorflow graph
- // such that the tensorflow graph is self-contained.
- OP_REQUIRES_OK(context,
- ToTFStatus(sentencepiece_processor->LoadFromSerializedProto(
- model_proto_attr)));
+ if (sp) {
+ *status = sp->status();
+ return sp;
+ }
+
+ sp = std::make_shared<SentencePieceProcessor>();
+ *status = is_proto ? sp->LoadFromSerializedProto(key) : sp->Load(key);
+ if (!status->ok()) return nullptr;
+
+ data_[fp] = sp;
+ return sp;
}
- // Sets extra options to add <s>, </s>.
- std::string options;
- auto add_options = [&options, &context](const std::string& name,
- const std::string& v) {
- bool flag = false;
- OP_REQUIRES_OK(context, context->GetAttr(name, &flag));
- if (flag) {
- if (!options.empty()) options += ':';
- options += v;
+ private:
+ std::mutex mutex_;
+ std::unordered_map<uint64, std::weak_ptr<SentencePieceProcessor>> data_;
+};
+
+class SentencePieceBaseOp : public OpKernel {
+ public:
+ explicit SentencePieceBaseOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ std::string model_file_attr, model_proto_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("model_file", &model_file_attr));
+ OP_REQUIRES_OK(context, context->GetAttr("model_proto", &model_proto_attr));
+
+ // Initializes global cache.
+ static SentencePieceProcessorCache* cache = new SentencePieceProcessorCache;
+ sentencepiece::util::Status status;
+
+ OP_REQUIRES(context,
+ ((model_proto_attr.empty() && !model_file_attr.empty()) ||
+ (!model_proto_attr.empty() && model_file_attr.empty())),
+ ::tensorflow::errors::InvalidArgument(
+ "Either `model_proto` or `model_file` must be set."));
+
+ if (!model_file_attr.empty()) {
+ sentencepiece_processor_ = cache->get(model_file_attr, false, &status);
+ } else {
+ // Loads serialized sentencepiece model proto to enable embedding the
+ // relatively small sentencepiece model proto into the tensorflow graph
+ // such that the tensorflow graph is self-contained.
+ sentencepiece_processor_ = cache->get(model_proto_attr, true, &status);
}
- };
- if (type == ENCODE || type == DECODE) {
- add_options("reverse", "reverse");
- }
+ OP_REQUIRES_OK(context, ToTFStatus(status));
+ OP_REQUIRES(context, sentencepiece_processor_,
+ ::tensorflow::errors::InvalidArgument(
+ "Failed to initialize SentencePieceProcessor"));
+
+ // Sets extra options to add <s>, </s>.
+ auto has_attribute = [&context](const std::string& name) {
+ bool flag = false;
+ context->GetAttr(name, &flag);
+ return flag;
+ };
+
+ if (has_attribute("add_bos")) {
+ bos_id_ = sentencepiece_processor_->bos_id();
+ OP_REQUIRES(context, bos_id_ >= 0,
+ ::tensorflow::errors::InvalidArgument(
+ "`bos_id` is not defined in model"));
+ }
- if (type == ENCODE) {
- add_options("add_bos", "bos");
- add_options("add_eos", "eos");
- OP_REQUIRES_OK(
- context,
- ToTFStatus(sentencepiece_processor->SetEncodeExtraOptions(options)));
- } else if (type == DECODE) {
- OP_REQUIRES_OK(
- context,
- ToTFStatus(sentencepiece_processor->SetDecodeExtraOptions(options)));
+ if (has_attribute("add_eos")) {
+ eos_id_ = sentencepiece_processor_->eos_id();
+ OP_REQUIRES(context, eos_id_ >= 0,
+ ::tensorflow::errors::InvalidArgument(
+ "`eos_id` is not defined in model"));
+ }
+
+ reverse_ = has_attribute("reverse");
+
+ pad_id_ = sentencepiece_processor_->pad_id();
+ if (pad_id_ == -1) pad_id_ = sentencepiece_processor_->unk_id();
}
-}
+
+ protected:
+ std::shared_ptr<SentencePieceProcessor> sentencepiece_processor_;
+ int bos_id_ = -1;
+ int eos_id_ = -1;
+ int pad_id_ = -1;
+ bool reverse_ = false;
+};
} // namespace
-class SentencePieceGetPieceSizeOp : public OpKernel {
+class SentencePieceGetPieceSizeOp : public SentencePieceBaseOp {
public:
explicit SentencePieceGetPieceSizeOp(OpKernelConstruction* context)
- : OpKernel(context) {
- SentencePieceProcessor sp;
- InitializeModel(context, &sp, GENERAL);
- vocab_size_ = sp.GetPieceSize();
- }
+ : SentencePieceBaseOp(context) {}
void Compute(OpKernelContext* context) override {
Tensor* vocab_size_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, {}, &vocab_size_tensor));
- vocab_size_tensor->scalar<int32>()() = vocab_size_;
+ vocab_size_tensor->scalar<int32>()() =
+ sentencepiece_processor_->GetPieceSize();
}
-
- private:
- int32 vocab_size_ = 0;
};
template <typename S, typename T>
-class SentencePieceConvertPieceOp : public OpKernel {
+class SentencePieceConvertPieceOp : public SentencePieceBaseOp {
public:
explicit SentencePieceConvertPieceOp(OpKernelConstruction* context)
- : OpKernel(context) {
- InitializeModel(context, &sentencepiece_processor_, GENERAL);
- }
+ : SentencePieceBaseOp(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* input_tensor = nullptr;
@@ -142,27 +181,62 @@ class SentencePieceConvertPieceOp : public OpKernel {
}
int32 Convert(const std::string& piece) const {
- return sentencepiece_processor_.PieceToId(piece);
+ return sentencepiece_processor_->PieceToId(piece);
}
std::string Convert(int32 id) const {
- if (id >= 0 && id < sentencepiece_processor_.GetPieceSize()) {
- return sentencepiece_processor_.IdToPiece(id);
+ if (id >= 0 && id < sentencepiece_processor_->GetPieceSize()) {
+ return sentencepiece_processor_->IdToPiece(id);
}
return "";
}
+};
+
+class SentencePieceGetPieceTypeOp : public SentencePieceBaseOp {
+ public:
+ explicit SentencePieceGetPieceTypeOp(OpKernelConstruction* context)
+ : SentencePieceBaseOp(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("piece_type", &piece_type_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input_tensor = nullptr;
+ OP_REQUIRES_OK(context, context->input("input", &input_tensor));
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(),
+ &output_tensor));
+
+ for (int i = 0; i < input_tensor->NumElements(); ++i) {
+ const int id = input_tensor->flat<int32>()(i);
+ switch (piece_type_) {
+ case 0:
+ output_tensor->flat<bool>()(i) =
+ sentencepiece_processor_->IsUnknown(id);
+ break;
+ case 1:
+ output_tensor->flat<bool>()(i) =
+ sentencepiece_processor_->IsControl(id);
+ break;
+ case 2:
+ output_tensor->flat<bool>()(i) =
+ sentencepiece_processor_->IsUnused(id);
+ break;
+ default:
+ break;
+ }
+ }
+ }
private:
- SentencePieceProcessor sentencepiece_processor_;
+ int piece_type_;
};
template <typename T>
-class SentencePieceEncodeOpBase : public OpKernel {
+class SentencePieceEncodeOpBase : public SentencePieceBaseOp {
public:
explicit SentencePieceEncodeOpBase(OpKernelConstruction* context)
- : OpKernel(context) {
- InitializeModel(context, &sentencepiece_processor_, ENCODE);
- }
+ : SentencePieceBaseOp(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* input_tensor = nullptr;
@@ -207,26 +281,40 @@ class SentencePieceEncodeOpBase : public OpKernel {
? nbest_size_tensor->vec<int32>()(i)
: nbest_size_tensor->scalar<int32>()();
if (nbest_size == 0 || nbest_size == 1) {
- OP_REQUIRES_OK(context, ToTFStatus(sentencepiece_processor_.Encode(
+ OP_REQUIRES_OK(context, ToTFStatus(sentencepiece_processor_->Encode(
input_sentences(i), &pieces[i])));
} else {
const float alpha = alpha_tensor->dims() == 1
? alpha_tensor->vec<float>()(i)
: alpha_tensor->scalar<float>()();
OP_REQUIRES_OK(context,
- ToTFStatus(sentencepiece_processor_.SampleEncode(
+ ToTFStatus(sentencepiece_processor_->SampleEncode(
input_sentences(i), nbest_size, alpha, &pieces[i])));
}
+ RewritePieces(&pieces[i]);
}
MakeOutputTensor(context, pieces);
}
- private:
+ protected:
+ void RewritePieces(std::vector<std::string>* pieces) const {
+ if (reverse_) std::reverse(pieces->begin(), pieces->end());
+ if (bos_id_ > 0)
+ pieces->insert(pieces->begin(),
+ sentencepiece_processor_->IdToPiece(bos_id_));
+ if (eos_id_ > 0)
+ pieces->push_back(sentencepiece_processor_->IdToPiece(eos_id_));
+ }
+
+ void RewritePieces(std::vector<int32>* pieces) const {
+ if (reverse_) std::reverse(pieces->begin(), pieces->end());
+ if (bos_id_ > 0) pieces->insert(pieces->begin(), bos_id_);
+ if (eos_id_ > 0) pieces->push_back(eos_id_);
+ }
+
virtual void MakeOutputTensor(OpKernelContext* context,
const std::vector<std::vector<T>>& pieces) = 0;
-
- SentencePieceProcessor sentencepiece_processor_;
};
template <typename T>
@@ -235,7 +323,7 @@ class SentencePieceEncodeSparseOp : public SentencePieceEncodeOpBase<T> {
explicit SentencePieceEncodeSparseOp(OpKernelConstruction* context)
: SentencePieceEncodeOpBase<T>(context) {}
- private:
+ protected:
void MakeOutputTensor(OpKernelContext* context,
const std::vector<std::vector<T>>& pieces) override {
const int64 batch_size = pieces.size();
@@ -292,7 +380,7 @@ class SentencePieceEncodeDenseOp : public SentencePieceEncodeOpBase<T> {
explicit SentencePieceEncodeDenseOp(OpKernelConstruction* context)
: SentencePieceEncodeOpBase<T>(context) {}
- private:
+ // protected:
void MakeOutputTensor(OpKernelContext* context,
const std::vector<std::vector<T>>& pieces) override {
const int64 batch_size = pieces.size();
@@ -326,12 +414,10 @@ class SentencePieceEncodeDenseOp : public SentencePieceEncodeOpBase<T> {
};
template <typename T>
-class SentencePieceDecodeOp : public OpKernel {
+class SentencePieceDecodeOp : public SentencePieceBaseOp {
public:
explicit SentencePieceDecodeOp(OpKernelConstruction* context)
- : OpKernel(context) {
- InitializeModel(context, &sentencepiece_processor_, DECODE);
- }
+ : SentencePieceBaseOp(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* input_tensor = nullptr;
@@ -368,15 +454,13 @@ class SentencePieceDecodeOp : public OpKernel {
sequence_length(i) <= max_sequence_length),
::tensorflow::errors::InvalidArgument(
"`sequence_length` is out-of-range."));
- const std::vector<T> pieces(&input_sentences(i, 0),
- &input_sentences(i, 0) + sequence_length(i));
- OP_REQUIRES_OK(context, ToTFStatus(sentencepiece_processor_.Decode(
+ std::vector<T> pieces(&input_sentences(i, 0),
+ &input_sentences(i, 0) + sequence_length(i));
+ if (reverse_) std::reverse(pieces.begin(), pieces.end());
+ OP_REQUIRES_OK(context, ToTFStatus(sentencepiece_processor_->Decode(
pieces, &values_tensor_output(i))));
}
}
-
- private:
- SentencePieceProcessor sentencepiece_processor_;
};
namespace {
@@ -384,6 +468,7 @@ namespace {
constexpr char kGetPieceSizeOpName[] = "SentencepieceGetPieceSize";
constexpr char kPieceToIdOpName[] = "SentencepiecePieceToId";
constexpr char kIdToPieceOpName[] = "SentencepieceIdToPiece";
+constexpr char kGetPieceTypeOpName[] = "SentencepieceGetPieceType";
constexpr char kEncodeDenseOpName[] = "SentencepieceEncodeDense";
constexpr char kEncodeSparseOpName[] = "SentencepieceEncodeSparse";
constexpr char kDecodeOpName[] = "SentencepieceDecode";
@@ -427,6 +512,20 @@ REGISTER_OP(kIdToPieceOpName)
REGISTER_KERNEL_BUILDER(Name(kIdToPieceOpName).Device(DEVICE_CPU),
SentencePieceConvertPieceOp<int32, std::string>);
+REGISTER_OP(kGetPieceTypeOpName)
+ .Input("input: int32")
+ .Output("values: bool")
+ .Attr("model_file: string = ''")
+ .Attr("model_proto: string = ''")
+ .Attr("piece_type: int = 0")
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return ::tensorflow::Status::OK();
+ });
+
+REGISTER_KERNEL_BUILDER(Name(kGetPieceTypeOpName).Device(DEVICE_CPU),
+ SentencePieceGetPieceTypeOp);
+
REGISTER_OP(kEncodeDenseOpName)
.Attr("out_type: {int32, string} = DT_INT32")
.Input("input: string")
diff --git a/tensorflow/setup.py b/tensorflow/setup.py
index aca3285..38fc8c0 100755
--- a/tensorflow/setup.py
+++ b/tensorflow/setup.py
@@ -25,7 +25,7 @@ setup(name = 'tf_sentencepiece',
author = 'Taku Kudo',
author_email='taku@google.com',
description = 'SentencePiece Encode/Decode ops for TensorFlow',
- version='0.1.1',
+ version='0.1.2',
url = 'https://github.com/google/sentencepiece',
license = 'Apache',
platforms = 'Unix',
diff --git a/tensorflow/test/tf_sentencepiece_test.py b/tensorflow/test/tf_sentencepiece_test.py
index e1a7b52..ad7cf9c 100755
--- a/tensorflow/test/tf_sentencepiece_test.py
+++ b/tensorflow/test/tf_sentencepiece_test.py
@@ -189,6 +189,32 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
self.assertEqual(ids.eval().tolist(), expected_ids)
self.assertEqual(pieces.eval().tolist(), expected_pieces)
+ def testGetPieceType(self):
+ sentencepiece_model_file = self._getSentencePieceModelFile()
+ processor = spm.SentencePieceProcessor()
+ processor.Load(sentencepiece_model_file)
+ expected_is_unknown = []
+ expected_is_control = []
+ expected_is_unused = []
+ ids = []
+
+ for i in range(processor.GetPieceSize()):
+ ids.append(i)
+ expected_is_unknown.append(processor.IsUnknown(i))
+ expected_is_control.append(processor.IsControl(i))
+ expected_is_unused.append(processor.IsUnused(i))
+
+ with tf.Session():
+ s = tf.constant(ids)
+ is_unknown = tfspm.is_unknown(s, model_file=sentencepiece_model_file)
+ is_control = tfspm.is_control(s, model_file=sentencepiece_model_file)
+ is_unused = tfspm.is_unused(s, model_file=sentencepiece_model_file)
+
+ self.assertEqual(is_unknown.eval().tolist(), expected_is_unknown)
+ self.assertEqual(is_control.eval().tolist(), expected_is_control)
+ self.assertEqual(is_unused.eval().tolist(), expected_is_unused)
+
+
def testLoadModelProto(self):
# Makes a serialized model proto.
model_proto = open(self._getSentencePieceModelFile(), 'rb').read()
diff --git a/tensorflow/tf_sentencepiece/_sentencepiece_processor_ops.so b/tensorflow/tf_sentencepiece/_sentencepiece_processor_ops.so
index db7fe23..1e54b80 100755
--- a/tensorflow/tf_sentencepiece/_sentencepiece_processor_ops.so
+++ b/tensorflow/tf_sentencepiece/_sentencepiece_processor_ops.so
Binary files differ
diff --git a/tensorflow/tf_sentencepiece/sentencepiece_processor_ops.py b/tensorflow/tf_sentencepiece/sentencepiece_processor_ops.py
index baafd6c..4fe263d 100644
--- a/tensorflow/tf_sentencepiece/sentencepiece_processor_ops.py
+++ b/tensorflow/tf_sentencepiece/sentencepiece_processor_ops.py
@@ -77,6 +77,60 @@ def id_to_piece(input, model_file=None, model_proto=None, name=None):
input, model_file=model_file, model_proto=model_proto, name=name)
+def is_unknown(input, model_file=None, model_proto=None, name=None):
+ """Returns true if input id is unknown piece.
+
+ Args:
+ input: An arbitrary tensor of int32.
+ model_file: The sentencepiece model file path.
+ model_proto: The sentencepiece model serialized proto.
+ Either `model_file` or `model_proto` must be set.
+ name: The name argument that is passed to the op function.
+ Returns:
+ A tensor of bool with the same shape as input.
+ """
+
+ return _gen_sentencepiece_processor_op.sentencepiece_get_piece_type(
+ input, model_file=model_file, model_proto=model_proto, name=name,
+ piece_type=0)
+
+
+def is_control(input, model_file=None, model_proto=None, name=None):
+ """Returns true if input id is control piece.
+
+ Args:
+ input: An arbitrary tensor of int32.
+ model_file: The sentencepiece model file path.
+ model_proto: The sentencepiece model serialized proto.
+ Either `model_file` or `model_proto` must be set.
+ name: The name argument that is passed to the op function.
+ Returns:
+ A tensor of bool with the same shape as input.
+ """
+
+ return _gen_sentencepiece_processor_op.sentencepiece_get_piece_type(
+ input, model_file=model_file, model_proto=model_proto, name=name,
+ piece_type=1)
+
+
+def is_unused(input, model_file=None, model_proto=None, name=None):
+ """Returns true if input id is unused piece.
+
+ Args:
+ input: An arbitrary tensor of int32.
+ model_file: The sentencepiece model file path.
+ model_proto: The sentencepiece model serialized proto.
+ Either `model_file` or `model_proto` must be set.
+ name: The name argument that is passed to the op function.
+ Returns:
+ A tensor of bool with the same shape as input.
+ """
+
+ return _gen_sentencepiece_processor_op.sentencepiece_get_piece_type(
+ input, model_file=model_file, model_proto=model_proto, name=name,
+ piece_type=2)
+
+
def encode_dense(input_sentences, nbest_size=0, alpha=1.0,
model_file=None, model_proto=None,
reverse=False, add_bos=False, add_eos=False,
@@ -115,9 +169,6 @@ def encode_dense(input_sentences, nbest_size=0, alpha=1.0,
reverse=reverse, add_bos=add_bos, add_eos=add_eos,
out_type=out_type, name=name)
-# Adds an alias for encode_dense. Accepts the `encode` function.
-encode = encode_dense
-
def encode_sparse(input_sentences, nbest_size=0, alpha=1.0,
model_file=None, model_proto=None,
@@ -183,10 +234,16 @@ def decode(pieces, sequence_length, model_file=None, model_proto=None,
pieces, sequence_length, model_file=model_file,
model_proto=model_proto, reverse=reverse, name=name)
+# Adds an alias for encode_dense. Accepts the `encode` function.
+encode = encode_dense
+sparse_encode = encode_sparse
+dense_encode = encode_dense
+
tf.NotDifferentiable('SentencepieceGetPieceSize')
tf.NotDifferentiable('SentencepieceIdToPiece')
tf.NotDifferentiable('SentencepiecePieceToId')
+tf.NotDifferentiable('SentencepieceGetPieceType')
tf.NotDifferentiable('SentencepieceEncodeDense')
tf.NotDifferentiable('SentencepieceEncodeSparse')
tf.NotDifferentiable('SentencepieceDecode')