# pokedex_test.py
from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import json
import argparse

# --- Command-line interface ---
# One required flag selects which trained network to evaluate. argparse
# rejects any value other than "1" or "2" and exits before the rest of the
# script runs. The description is what `--help` prints, so it states what the
# script actually does instead of a placeholder.
parser = argparse.ArgumentParser(
    description=(
        "Classify four random dataset images with a trained Pokedex model "
        "and plot the predictions."
    )
)
parser.add_argument(
    "--model",
    choices=["1", "2"],
    required=True,
    help="1 = ResNet50, 2 = Xception",
)
args = parser.parse_args()

# --- Per-model artefacts ---
# Lookup table keyed by the --model choice. Each entry holds, in order:
# the saved weights file, the class-name JSON file, and the image size that
# pictures are resized to before being fed to that network.
_model_settings = {
    "1": (
        "../models/ResNet50/pokedex_ResNet50.h5",
        "../models/ResNet50/class_names.json",
        (224, 224),
    ),
    "2": (
        "../models/Xception/pokedex_Xception.h5",
        "../models/Xception/class_names.json",
        (256, 256),
    ),
}
if args.model in _model_settings:
    h5_path, json_path, size = _model_settings[args.model]

# --- Load class names from JSON ---
# The file is accepted in two shapes:
#   * a JSON array of names, already in class-index order, or
#   * a JSON object mapping class index -> name.
# JSON object keys are always strings ("0", "1", ...), so such an object
# cannot be indexed with integers; its entries are ordered by numeric key
# instead. An explicit encoding keeps the read independent of the platform's
# default locale.
with open(json_path, "r", encoding="utf-8") as f:
    class_names = json.load(f)

if isinstance(class_names, dict):
    class_names = [
        name
        for _, name in sorted(class_names.items(), key=lambda item: int(item[0]))
    ]
else:
    class_names = list(class_names)

# --- Load trained model ---
# Loaded once, up front, from the weights file chosen via --model.
model = keras.models.load_model(h5_path)

# --- Paths ---
# Dataset root. Each class name is looked up as a sub-folder of this directory.
base_path = "../Combined_Dataset"

# File extensions (matched case-insensitively) that count as images.
image_extensions = ('.png', '.jpg', '.jpeg')

# --- Prepare 2x2 Plot ---
plt.figure(figsize=(10, 10))

for i in range(4):
    # Pick a random class, then a random image file inside its folder.
    random_class = random.choice(class_names)
    class_folder = os.path.join(base_path, random_class)
    candidates = [
        name for name in os.listdir(class_folder)
        if name.lower().endswith(image_extensions)
    ]
    if not candidates:
        # random.choice([]) would otherwise fail with an IndexError that does
        # not say which folder is the problem.
        raise FileNotFoundError(f"No image files found in {class_folder!r}")
    random_image = random.choice(candidates)
    img_path = os.path.join(class_folder, random_image)

    # --- Load & Preprocess Image ---
    img = keras.utils.load_img(img_path, target_size=size)  # resize to match model input
    img_array = keras.utils.img_to_array(img)
    # NOTE(review): assumes the model was trained on pixel values scaled to
    # [0, 1] -- confirm against the training pipeline.
    img_array = img_array / 255.0
    img_array = tf.expand_dims(img_array, 0)  # add the batch dimension

    # --- Predict ---
    predictions = model.predict(img_array, verbose=0)
    # NOTE(review): softmax is applied to the raw model output. If the saved
    # model already ends in a softmax layer, the ranking below is unaffected
    # but the reported confidences are flattened -- confirm the model emits
    # logits.
    # Converted to NumPy once so all indexing/formatting below uses one type.
    probabilities = tf.nn.softmax(predictions[0]).numpy()
    predicted_class_index = int(np.argmax(probabilities))
    predicted_label = class_names[predicted_class_index]
    confidence = 100 * probabilities[predicted_class_index]

    # argsort is ascending: take the last five and reverse for best-first.
    top_5_indices = np.argsort(probabilities)[-5:][::-1]
    print("\nTop 5 predictions:")
    for idx in top_5_indices:
        print(f"{class_names[idx]:<20}: {probabilities[idx]:.4f}")

    # Compare with actual (the folder name is treated as the true label).
    is_correct = predicted_label == random_class

    # --- Plot ---
    # Draw through the subplot's own Axes rather than pyplot's implicit
    # "current axes" state.
    ax = plt.subplot(2, 2, i + 1)
    ax.imshow(img)
    ax.axis("off")
    ax.set_title(
        f"Pred: {predicted_label}\n"
        f"True: {random_class}\n"
        f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
        fontsize=10
    )

plt.tight_layout()
plt.show()