Skip to content
Snippets Groups Projects
Commit 9e0298bc authored by michael.divia's avatar michael.divia
Browse files

Dynamic ONNX converter

parent cd94e0ed
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf import tensorflow as tf
import tf2onnx import tf2onnx
import argparse
# WHAT ?
parser = argparse.ArgumentParser(description="WHAT ?!")
parser.add_argument("--model", choices=["1", "2"], default="1",
help="1 = ResNet50, 2 = Xception")
args = parser.parse_args()
# Load Sequential model # Load Sequential model
seq_model = tf.keras.models.load_model("../models/resnet50.h5", compile=False) if args.model == "1":
seq_model = tf.keras.models.load_model("../models/ResNet50/pokemon_resnet50.h5", compile=False)
elif args.model == "2":
seq_model = tf.keras.models.load_model("../models/ResNet50/pokemon_xception.h5", compile=False)
# Create input layer with same shape # Create input layer with same shape
inputs = tf.keras.Input(shape=(224, 224, 3), name="input") inputs = tf.keras.Input(shape=(224, 224, 3), name="input")
...@@ -15,9 +25,18 @@ model = tf.keras.Model(inputs=inputs, outputs=outputs) ...@@ -15,9 +25,18 @@ model = tf.keras.Model(inputs=inputs, outputs=outputs)
# Convert to ONNX # Convert to ONNX
spec = (tf.TensorSpec((1, 224, 224, 3), tf.float32, name="input"),) spec = (tf.TensorSpec((1, 224, 224, 3), tf.float32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(
model, if args.model == "1":
input_signature=spec, onnx_model, _ = tf2onnx.convert.from_keras(
opset=13, model,
output_path="../models/resnet50.onnx" input_signature=spec,
) opset=13,
\ No newline at end of file output_path="../models/ResNet50/pokemon_resnet50.onnx"
)
elif args.model == "2":
onnx_model, _ = tf2onnx.convert.from_keras(
model,
input_signature=spec,
opset=13,
output_path="../models/ResNet50/pokemon_xception.onnx"
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment