diff options
author | Frank Seide <fseide@microsoft.com> | 2020-07-24 20:23:05 +0300 |
---|---|---|
committer | Frank Seide <fseide@microsoft.com> | 2020-07-24 20:23:05 +0300 |
commit | 435aa9505ef2d93e4a55fd4122df0dbbee2a6e60 (patch) | |
tree | a6b8010bdd1bbc0e3ce498bbcda7d7e299769154 /scripts | |
parent | c3fb60cbcd4f99ecd51adee835a071af41d77b9e (diff) |
Merged PR 14334: full ONNX conversion script
This PR adds a full ONNX conversion script that exports a Marian model and wraps it in a greedy-search implemented in ONNX.
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/onnx/example-greedy.py | 81 | ||||
-rw-r--r-- | scripts/onnx/marian_to_onnx.py | 215 | ||||
-rw-r--r-- | scripts/onnx/marian_to_onnx_example.py | 34 |
3 files changed, 249 insertions, 81 deletions
# scripts/onnx/example-greedy.py (deleted by this commit)
# Stand-alone greedy-decoding driver over Marian graphs exported to ONNX.
import onnxruntime as ort
import numpy as np
import onnx
import os, sys, time

# Force single-threaded, sequential ORT execution.
# NOTE(review): presumably for stable/reproducible timing measurements -- confirm;
# also OMP_NUM_THREADS is set *after* importing onnxruntime, which may be too late
# for OpenMP-based builds -- verify.
os.environ['OMP_NUM_THREADS'] = '1'
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 1
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-
def get_function(path, output_vars):
    """Load an ONNX graph from `path` and wrap it as a Python callable.

    The returned callable takes the graph inputs as keyword arguments and
    returns the outputs named in `output_vars`, in that order.
    """
    print("Reading ONNX function from", path)
    session = ort.InferenceSession(path, sess_options)
    out_defs = session.get_outputs()
    for in_def in session.get_inputs():
        print(" input: ", in_def.name, in_def.shape, in_def.type)
    for out_def in out_defs:
        print(" output: ", out_def.name, out_def.shape, out_def.type)

    def _as_input_array(value):
        # ORT feeds here are float32/int32; down-convert numpy's default 64-bit dtypes
        arr = np.array(value)
        if arr.dtype == np.double:
            return arr.astype(np.float32)
        if arr.dtype == np.int64:
            return arr.astype(np.int32)
        return arr

    def invoke_model(**kwargs):
        feeds = {name: _as_input_array(value) for name, value in kwargs.items()}
        results = session.run(None, feeds)
        by_name = dict(zip((out_def.name for out_def in out_defs), results))
        return [by_name[name] for name in output_vars]

    return invoke_model
-
# Load the vocabulary: id2word maps token id -> token string; word2id is the inverse.
# Fix: use a context manager so the vocab file handle is closed (the original leaked
# the handle from open(...).readlines()), and avoid shadowing the builtin `id`.
with open('c:/work/marian-dev/local/model/vocab_v1.wl', encoding='utf-8') as vocab_file:
    id2word = {token_id: word.rstrip() for token_id, word in enumerate(vocab_file)}
word2id = {word: token_id for token_id, word in id2word.items()}
unk_id = word2id["<unk>"]

# Wrap the three exported Marian sub-graphs (encoder, first decode step, subsequent
# decode steps) as Python callables returning the listed outputs, in order.
model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin-uniq-notrans-nounk"
encode_source = get_function(model_path_prefix + '.encode_source.onnx',
                             ['encoder_context_0'])
decode_first = get_function(model_path_prefix + '.decode_first.onnx',
                            ['first_logits', 'first_decoder_state_0', 'first_decoder_state_1', 'first_decoder_state_2', 'first_decoder_state_3', 'first_decoder_state_4', 'first_decoder_state_5'])
decode_next = get_function(model_path_prefix + '.decode_next.onnx',
                           ['next_logits', 'next_decoder_state_0', 'next_decoder_state_1', 'next_decoder_state_2', 'next_decoder_state_3', 'next_decoder_state_4', 'next_decoder_state_5'])
-
def greedy_decode(data_0):
    """Greedily decode one source sentence.

    Args:
        data_0: list of source-token ids (indices into the source vocabulary),
            already ending in the </s> token.
    Returns:
        List of output-token ids; ends with token id 0 unless max_len was reached.
    """
    if len(data_0) == 1: # special handling for the empty sentence, like Marian
        return data_0
    # Mask and position range as nested rank-3 lists, as the exported graphs expect.
    # NOTE(review): layout looks like (src_len, 1, 1) -- confirm against the exported graph.
    data_0_mask = [[[1.]]] * len(data_0)
    data_0_index_range = [[[float(t)]] for t in range(len(data_0))]
    #print(data_0, data_0_mask, data_0_index_range)

    max_len = len(data_0) * 3  # cap output length at 3x the input length
    Y = []
    # run the encoder once, then the first decoding step (which takes no previous state)
    encoder_context_0, *_ = encode_source(data_0=data_0, data_0_mask=data_0_mask, data_0_posrange=data_0_index_range)
    logp, *out_decoder_states = decode_first(data_1_posrange=[[[float(0)]]],
                                             encoder_context_0=encoder_context_0, data_0_mask=data_0_mask)
    logp[:,:,:,unk_id] = -1e8 # suppress <unk>, like Marian
    Y.append(np.argmax(logp[0][0]))
    # Keep decoding until the sentence-end token or max_len.
    # NOTE(review): assumes the EOS token has id 0 -- confirm against the target vocab.
    while Y[-1] != 0 and len(Y) < max_len:
        # each step feeds back the previous word plus the six decoder states from the last step
        logp, *out_decoder_states = decode_next(prev_word=[Y[-1]], data_1_posrange=[[[float(len(Y))]]],
                                                encoder_context_0=encoder_context_0, data_0_mask=data_0_mask,
                                                decoder_state_0=out_decoder_states[0], decoder_state_1=out_decoder_states[1],
                                                decoder_state_2=out_decoder_states[2], decoder_state_3=out_decoder_states[3],
                                                decoder_state_4=out_decoder_states[4], decoder_state_5=out_decoder_states[5])
        logp[:,:,:,unk_id] = -1e8
        Y.append(np.argmax(logp[0][0]))
    return Y
-
# Translate each input line, write the result to a file, and report total wall-clock time.
start_time = time.time()
out_path = "C:/work/marian-dev/local/model/predictions.out-onnx-debug-sin-notrans-first100-d.tok"
in_path = "C:/work/marian-dev/local/model/predictions.in-first100.tok"
# Fix: open the input file via a context manager as well (the original leaked its
# handle), and iterate the file directly instead of materializing readlines().
with open(out_path, 'wt', encoding='utf-8') as out_f, \
     open(in_path, encoding='utf-8') as in_f:
    for line in in_f:
        # map words to ids, appending the sentence-end token; unknown words map to <unk>
        data = [word2id.get(w, unk_id) for w in (line.rstrip() + " </s>").split(' ') if w]
        Y = greedy_decode(data)
        print("input: ", ' '.join(id2word[x] for x in data))
        print("output:", ' '.join(id2word[y] for y in Y))
        print(' '.join(id2word[y] for y in Y[:-1]), file=out_f, flush=True) # strip </s> for output to file
print("--- %s seconds ---" % (time.time() - start_time))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Library for converting certain types of Marian models to a standalone ONNX model.

Because Marian and ONNX use very different philosophies, a conversion is not possible
for all possible Marian models. Specifically, currently we don't support recurrent
networks in the encoder.

This works by running a Marian decode for 2 output steps, and capturing pieces of
this graph that correspond to the encoder, the first decoding step, and the second
decoding step. The graph of the second decoding step can be applied repeatedly in
order to decode a variable-length sequence.

Prerequisites:
```
pip install onnxruntime
git clone https://github.com/microsoft/onnxconverter-common.git
```
"""

import os, sys, inspect, subprocess
from typing import List, Dict, Optional, Callable

# get the Marian root path
_marian_root_path = os.path.dirname(inspect.getfile(inspect.currentframe())) + "/../.."

# we assume onnxconverter-common to be available next to the marian-dev repo; you must adjust this if needed
sys.path.append(_marian_root_path + "/../onnxconverter-common")
from onnxconverter_common.onnx_fx import Graph
from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty
import onnxruntime as _ort

def _ort_apply_model(model, inputs):  # ORT execution is a callback so that Graph itself does not need to depend on ORT
    sess = _ort.InferenceSession(model.SerializeToString())
    return sess.run(None, inputs)
Graph.inference_runtime = _ort_apply_model
Graph.opset = 11


def export_marian_model_components(marian_model_path: str, marian_vocab_paths: List[str],
                                   marian_executable_path: Optional[str]=None) -> Dict[str,Graph]:
    """
    Export the Marian graph to a set of models.

    Args:
        marian_model_path: path to Marian model to convert
        marian_vocab_paths: paths of vocab files (normally, this requires 2 entries, which may be identical)
        marian_executable_path: path to Marian executable; will default to THIS_SCRIPT_PATH/../../build/marian
    Returns:
        Dict of ONNX Graph instances corresponding to pieces of Marian models.
    """
    assert isinstance(marian_vocab_paths, list), "marian_vocab_paths must be a list of paths"
    # default marian executable is found relative to location of this script (Linux/CMake only)
    if marian_executable_path is None:
        marian_executable_path = _marian_root_path + "/build/marian"
    # partial models are written to /tmp
    output_path_stem = "/tmp/" + os.path.basename(marian_model_path)
    # exporting is done via invoking Marian via its command-line interface; models are written to tmp files
    command = marian_executable_path
    args = [
        "convert",
        "--from", marian_model_path,
        "--vocabs", *marian_vocab_paths,
        "--to", output_path_stem,
        "--export-as", "onnx-encode"
    ]
    subprocess.run([command] + args, check=True)  # check=True: raise CalledProcessError if Marian fails
    # load the tmp files into Graph instances (Marian writes one file per sub-graph)
    graph_names = ["encode_source", "decode_first", "decode_next"]  # Marian generates graphs with these names
    output_paths = [output_path_stem + "." + graph_name + ".onnx" for graph_name in graph_names]  # form pathnames under which Marian wrote the files
    res = { graph_name: Graph.load(output_path) for graph_name, output_path in zip(graph_names, output_paths) }
    # clean up after ourselves
    for output_path in output_paths:
        os.unlink(output_path)
    return res


def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes], num_decoder_layers: int):
    """
    Create an ONNX model that implements greedy search over the exported Marian pieces.

    Args:
        partial_models: models returned from export_marian_model_components()
        num_decoder_layers: must be specified, since it cannot be inferred from the model files presently (e.g. 6)
    Returns:
        ONNX model that can be called as
        result_ids = greedy_search_fn(np.array(source_ids, dtype=np.int64), np.array([target_eos_id], dtype=np.int64))[0]
    """
    # NOTE(review): the hint says Dict[str,bytes], but export_marian_model_components()
    # returns Dict[str,Graph] and the code below calls Graph.save() -- confirm and fix.
    # load our partial functions
    # ONNX graph inputs and outputs are named but not ordered. Therefore, we must define the parameter order here.
    def define_parameter_order(graph, inputs, outputs):
        tmppath = "/tmp/tmpmodel.onnx"
        graph.save(tmppath)  # unfortunately, Graph.load() cannot load from bytes, so use a tmp file
        graph = Graph.load(tmppath, inputs=inputs, outputs=outputs)
        os.unlink(tmppath)
        return graph
    encode_source = define_parameter_order(partial_models["encode_source"],
                                           inputs=['data_0', 'data_0_mask', 'data_0_posrange'],  # define the order of arguments
                                           outputs=['encoder_context_0'])
    decode_first = define_parameter_order(partial_models["decode_first"],
                                          inputs=['data_1_posrange', 'encoder_context_0', 'data_0_mask'],
                                          outputs=['first_logits'] +
                                                  [f"first_decoder_state_{i}" for i in range(num_decoder_layers)])
    decode_next = define_parameter_order(partial_models["decode_next"],
                                         inputs=['prev_word', 'data_1_posrange', 'encoder_context_0', 'data_0_mask'] +
                                                [f"decoder_state_{i}" for i in range(num_decoder_layers)],
                                         outputs=['next_logits'] +
                                                 [f"next_decoder_state_{i}" for i in range(num_decoder_layers)])

    # create an ONNX graph that implements full greedy search
    # The greedy search is implemented via the @Graph.trace decorator, which allows us to
    # author the greedy search in Python, similar to @CNTK.Function and PyTorch trace-based jit.
    # The decorator executes greedy_search() below on a dummy input in order to generate an ONNX graph
    # via invoking operators from the onnx.fx library.
    # The partial functions exported from Marian are invoked (=inlined) by this.
    # The result is a full ONNX graph that implements greedy search using the Marian model.
    @Graph.trace(
        input_types=[_Ty.I(shape=['N']), _Ty.I([1])],
        output_types=[_Ty.I(shape=['T'])],
        outputs="Y")
    def greedy_search(X, eos_id):
        """
        Args:
            X: sequence of input tokens, including EOS symbol, as integer indices into the input vocabulary
            eos_id: id of the EOS symbol in the output vocabulary
        """
        ox = X.ox
        data_0 = X
        data_0_shape = data_0.shape()
        data_0_mask = ox.constant_of_shape(data_0_shape, value=1.0)  # all-ones mask: single sequence, no padding
        seq_len = data_0_shape[-1]
        # position range 0..seq_len-1 as float, unsqueezed to the rank-3 layout the Marian graphs expect
        data_0_index_range = ox.range([ox.constant(value=0), seq_len, ox.constant(value=1)]).cast(to=ox.float)
        data_0_index_range = ox.unsqueeze(data_0_index_range, axes=[1, 2])
        max_len = seq_len * 3  # cap the output length at 3x the input length

        encoder_context_0 = encode_source(data_0=data_0, data_0_mask=data_0_mask,
                                          data_0_posrange=data_0_index_range)

        y_len_0 = ox.constant(value=0.0)  # position of the first output token
        logp, *out_decoder_states = decode_first(data_1_posrange=y_len_0,
                                                 encoder_context_0=encoder_context_0, data_0_mask=data_0_mask)

        y_t = logp[0, 0, 0].argmax(axis=-1, keepdims=True)  # note: rank-1 tensor, not a scalar
        eos_token = eos_id + 0  # NOTE(review): presumably the +0 materializes eos_id as a graph node -- verify
        test_y_t = (y_t != eos_token)  # loop-continuation condition: last token was not EOS

        # NOTE(review): output_types hard-codes 6 state outputs ([_Ty.f] * 6) while input_types
        # uses num_decoder_layers -- these disagree whenever num_decoder_layers != 6 (see assert below).
        @Graph.trace(outputs=['ty_t', 'y_t_o', *(f'ods_{i}' for i in range(num_decoder_layers)), 'y_t_o2'],
                     output_types=[_Ty.b, _Ty.i] + [_Ty.f] * 6 + [_Ty.i],
                     input_types=[_Ty.I([1]), _Ty.b, _Ty.i] + [_Ty.f] * num_decoder_layers)
        def loop_body(iteration_count, condition,  # these are not actually used inside
                      y_t,
                      out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5):
            # Currently, we do not support variable number of arguments to the callable.
            # @TODO: We have the information from the type signature in Graph.trace(), so this should be possible.
            assert num_decoder_layers == 6, "Currently, decoder layers other than 6 require a manual code change"
            out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5]
            """
            Loop body follows the requirements of ONNX Loop:

            "The graph run each iteration.
            It has 2+N inputs: (iteration_num, condition, loop carried dependencies...).
            It has 1+N+K outputs: (condition, loop carried dependencies..., scan_outputs...).
            Each scan_output is created by concatenating the value of the specified output value at the end of each iteration of the loop.
            It is an error if the dimensions or data type of these scan_outputs change across loop iterations."

            Inputs:
                iteration_num (not used by our function)
                test_y_t: condition (not used as an input)
                y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies

            Outputs:
                test_y_t: condition, return True if there is more to decode
                y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies (same as in the Inputs section)
                y_t: K=1 outputs
            """
            pos = iteration_count + 1
            data_1_posrange = pos.cast(to=1).unsqueeze(axes=[0, 1, 2])  # NOTE(review): to=1 looks like ONNX TensorProto FLOAT -- confirm
            logp, *out_decoder_states = decode_next(
                prev_word=y_t, data_1_posrange=data_1_posrange,
                encoder_context_0=encoder_context_0, data_0_mask=data_0_mask,
                **{f"decoder_state_{i}": out_decoder_states[i] for i in range(len(out_decoder_states))})
            y_t = logp[0, 0, 0].argmax(axis=-1, keepdims=True)  # greedy choice for this step
            test_y_t = (y_t != eos_token)
            return [test_y_t, y_t] + out_decoder_states + [y_t]

        # "Final N loop carried dependency values then K scan_outputs"
        ret_vals = ox.loop(max_len, test_y_t, loop_body,
                           inputs=[y_t] + out_decoder_states,
                           outputs=['gy_t_o', *[f"gods_{i}" for i in range(len(out_decoder_states))], 'greedy_out'])
        y = ret_vals[-1]  # scan_output

        # we must prepend the very first token
        Y = ox.concat([ox.unsqueeze(y_t), y], axis=0)  # note: y_t is a rank-1 tensor, not a scalar (ORT concat fails with scalars)
        return ox.squeeze(Y, axes=[1])
    greedy_search.to_model()  # this triggers the model tracing (which is lazy)
    return greedy_search


def apply_model(greedy_search_fn: Callable, source_ids: List[int], target_eos_id: int) -> List[int]:
    """
    Apply model to an input sequence, e.g. run translation.

    Args:
        greedy_search_fn: ONNX model created with combine_model_components_with_greedy_search()
        source_ids: list of source tokens, as indices into source vocabulary, ending in EOS symbol
        target_eos_id: id of EOS symbol in target vocabulary
    Returns:
        Result as list of ids into target vocabulary
    """
    import numpy as np  # local import: numpy is only needed when actually running the model
    Y = greedy_search_fn(
        np.array(source_ids, dtype=np.int64),
        np.array([target_eos_id], dtype=np.int64))[0]
    return Y


# ---------------------------------------------------------------------------
# scripts/onnx/marian_to_onnx_example.py (a separate file added by this commit)
# ---------------------------------------------------------------------------

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Example program demonstrating how to convert a Marian model using the marian_to_onnx library
to a self-contained ONNX model that implements greedy search.
"""

import os, sys
import marian_to_onnx as mo

# The following variables would normally be command-line arguments.
# We use constants here to keep it simple. Please just adjust these as needed.
# Paths and model constants for the conversion example; adjust to your local setup.
my_dir = os.path.expanduser("~/")
marian_npz = my_dir + "model.npz.best-ce-mean-words.npz"        # path to the Marian model to convert
num_decoder_layers = 6                                          # number of decoder layers
marian_vocs = [my_dir + "vocab_v1.wl"] * 2                      # path to the vocabularies for source and target
onnx_model_path = my_dir + "model.npz.best-ce-mean-words.onnx"  # resulting model gets written here

# export Marian model as multiple ONNX models
partial_models = mo.export_marian_model_components(marian_npz, marian_vocs)

# use the ONNX models in a greedy-search
# The result is a fully self-contained model that implements greedy search.
onnx_model = mo.combine_model_components_with_greedy_search(partial_models, num_decoder_layers)

# save as ONNX file
onnx_model.save(onnx_model_path)

# run a test sentence (token ids into the source vocabulary, ending in the EOS id 0)
Y = mo.apply_model(greedy_search_fn=onnx_model,
                   source_ids=[274, 35, 52, 791, 59, 4060, 6, 2688, 2, 7744, 9, 2128, 7, 2, 4695, 9, 950, 2561, 3, 0],
                   target_eos_id=0)
print(Y.shape, Y)