"""Show a 2x2 grid of random dataset images with the Pokedex model's predictions.

Usage: ``python <script> --model 1`` (ResNet50) or ``--model 2`` (Xception).
For each of four randomly chosen images, the top-5 predictions are printed.
Each image is plotted with its predicted label, its true label (the folder
name), a YES/NO correctness flag, and the confidence of the top prediction.
"""
import argparse
import json
import os
import random

import numpy as np

# Per-model settings: saved Keras weights, class-name file, model input size.
MODEL_CONFIGS = {
    "1": {
        "h5_path": "../models/ResNet50/pokedex_ResNet50.h5",
        "json_path": "../models/ResNet50/class_names.json",
        "size": (224, 224),
    },
    "2": {
        "h5_path": "../models/Xception/pokedex_Xception.h5",
        "json_path": "../models/Xception/class_names.json",
        "size": (256, 256),
    },
}

# Dataset root: one sub-folder per class, named exactly like the class label.
BASE_PATH = "../Combined_Dataset"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def class_names_from_json(data):
    """Turn decoded class-name JSON into a list indexed by class id.

    Accepts a list (returned as a copy), an object mapping index -> name
    (JSON object keys are always strings, e.g. ``{"0": "Bulbasaur"}``), or an
    object mapping name -> integer index (Keras ``class_indices`` style).
    """
    if isinstance(data, dict):
        if data and all(isinstance(v, int) for v in data.values()):
            # name -> index mapping: order names by their index.
            return [name for name, _ in sorted(data.items(), key=lambda kv: kv[1])]
        # index -> name mapping; keys are strings after JSON round-tripping.
        return [data[str(i)] for i in range(len(data))]
    return list(data)


def load_class_names(json_path):
    """Read *json_path* and return the class names as a list indexed by class id."""
    with open(json_path, "r", encoding="utf-8") as f:
        return class_names_from_json(json.load(f))


def to_probabilities(raw):
    """Return *raw* model scores as a probability vector (float64).

    If the scores are already non-negative and sum to ~1 (the model ends in a
    softmax), they are returned unchanged. Applying softmax a second time
    would flatten the distribution and under-report confidence. Otherwise
    they are treated as logits and a numerically stable softmax is applied.
    """
    scores = np.asarray(raw, dtype=np.float64)
    if np.all(scores >= 0) and np.isclose(scores.sum(), 1.0, atol=1e-3):
        return scores
    exp = np.exp(scores - scores.max())
    return exp / exp.sum()


def pick_random_image(base_path, class_names):
    """Pick a random class that has images, then a random image within it.

    Returns ``(class_name, image_path)``. Classes whose folder is missing or
    holds no images with an extension in IMAGE_EXTENSIONS are skipped.

    Raises:
        FileNotFoundError: if no class folder contains a usable image.
    """
    # Visiting classes in random order and taking the first non-empty one
    # samples uniformly among the classes that actually have images.
    for label in random.sample(class_names, len(class_names)):
        folder = os.path.join(base_path, label)
        if not os.path.isdir(folder):
            continue
        images = sorted(
            f for f in os.listdir(folder) if f.lower().endswith(IMAGE_EXTENSIONS)
        )
        if images:
            return label, os.path.join(folder, random.choice(images))
    raise FileNotFoundError(
        f"No images with extensions {IMAGE_EXTENSIONS} found under {base_path!r}"
    )


def main():
    """Parse the CLI, load the selected model, and plot 4 random predictions."""
    parser = argparse.ArgumentParser(description="WHAT ?!")
    parser.add_argument("--model", choices=["1", "2"], required=True,
                        help="1 = ResNet50, 2 = Xception")
    args = parser.parse_args()
    config = MODEL_CONFIGS[args.model]

    # Heavy imports are deferred so `--help` and the helpers above do not
    # require TensorFlow / matplotlib to be importable.
    import matplotlib.pyplot as plt
    import tensorflow as tf
    from tensorflow import keras

    class_names = load_class_names(config["json_path"])
    model = keras.models.load_model(config["h5_path"])

    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for ax in axes.flat:  # row-major, same order as subplot(2, 2, i + 1)
        true_label, img_path = pick_random_image(BASE_PATH, class_names)

        # Load, resize to the model's input size, and scale to [0, 1].
        # NOTE(review): stock ResNet50/Xception ImageNet weights expect
        # keras.applications.<model>.preprocess_input instead of /255 —
        # confirm this matches the preprocessing used during training.
        img = keras.utils.load_img(img_path, target_size=config["size"])
        img_array = keras.utils.img_to_array(img) / 255.0
        batch = tf.expand_dims(img_array, 0)

        probabilities = to_probabilities(model.predict(batch, verbose=0)[0])
        predicted_index = int(np.argmax(probabilities))
        predicted_label = class_names[predicted_index]
        confidence = 100 * probabilities[predicted_index]

        print("\nTop 5 predictions:")
        for idx in np.argsort(probabilities)[-5:][::-1]:
            print(f"{class_names[idx]:<20}: {probabilities[idx]:.4f}")

        is_correct = predicted_label == true_label
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(
            f"Pred: {predicted_label}\n"
            f"True: {true_label}\n"
            f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
            fontsize=10,
        )

    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()