diff --git a/Python/convert_onnx.py b/Python/convert_onnx.py index 29f305027091e6d6a990151ff15a79adb04c1e5e..d59ebffea8b581a0a1419592af7dd12566132a9b 100644 --- a/Python/convert_onnx.py +++ b/Python/convert_onnx.py @@ -1,8 +1,18 @@ import tensorflow as tf import tf2onnx +import argparse + +# WHAT ? +parser = argparse.ArgumentParser(description="WHAT ?!") +parser.add_argument("--model", choices=["1", "2"], default="1", + help="1 = ResNet50, 2 = Xception") +args = parser.parse_args() # Load Sequential model -seq_model = tf.keras.models.load_model("../models/resnet50.h5", compile=False) +if args.model == "1": + seq_model = tf.keras.models.load_model("../models/ResNet50/pokemon_resnet50.h5", compile=False) +elif args.model == "2": + seq_model = tf.keras.models.load_model("../models/ResNet50/pokemon_xception.h5", compile=False) # Create input layer with same shape inputs = tf.keras.Input(shape=(224, 224, 3), name="input") @@ -15,9 +25,18 @@ model = tf.keras.Model(inputs=inputs, outputs=outputs) # Convert to ONNX spec = (tf.TensorSpec((1, 224, 224, 3), tf.float32, name="input"),) -onnx_model, _ = tf2onnx.convert.from_keras( - model, - input_signature=spec, - opset=13, - output_path="../models/resnet50.onnx" -) \ No newline at end of file + +if args.model == "1": + onnx_model, _ = tf2onnx.convert.from_keras( + model, + input_signature=spec, + opset=13, + output_path="../models/ResNet50/pokemon_resnet50.onnx" + ) +elif args.model == "2": + onnx_model, _ = tf2onnx.convert.from_keras( + model, + input_signature=spec, + opset=13, + output_path="../models/ResNet50/pokemon_xception.onnx" + ) \ No newline at end of file