# pokedex_test.py
# Author: michael.divia
from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import json
import argparse
# --- Command-line interface & model selection ---
# Each --model choice maps to (Keras weights file, class-name JSON file,
# input image size). The size is used as the resize target when images are
# loaded, so it must match the resolution the network was trained with.
# Paths are relative to the current working directory.
_MODEL_CONFIGS = {
    "1": (
        "../models/ResNet50/pokedex_ResNet50.h5",
        "../models/ResNet50/class_names.json",
        (224, 224),
    ),
    "2": (
        "../models/Xception/pokedex_Xception.h5",
        "../models/Xception/class_names.json",
        (256, 256),
    ),
}

# NOTE(review): the description "WHAT ?!" is what --help prints; consider a
# more descriptive text.
parser = argparse.ArgumentParser(description="WHAT ?!")
parser.add_argument(
    "--model",
    choices=list(_MODEL_CONFIGS),
    required=True,
    help="1 = ResNet50, 2 = Xception",
)
args = parser.parse_args()

# argparse restricts --model to the table's keys, so this lookup cannot fail.
h5_path, json_path, size = _MODEL_CONFIGS[args.model]
# --- Load class names from JSON ---
# The file may contain either a JSON list (["Abra", ...]) or an
# index -> name object ({"0": "Abra", ...}). JSON object keys are always
# strings, so an object must be indexed with str(i); indexing it with the
# int i raises KeyError. Position i in the resulting list is the class name
# for model output index i.
with open(json_path, "r", encoding="utf-8") as f:
    class_names = json.load(f)
if isinstance(class_names, dict):
    # Missing indices raise KeyError on purpose: silently shifting names
    # would mislabel every prediction after the gap.
    class_names = [class_names[str(i)] for i in range(len(class_names))]
else:
    class_names = list(class_names)

# --- Load trained model ---
model = keras.models.load_model(h5_path)
# --- Paths ---
base_path = "../Combined_Dataset"

# Lower-case file extensions treated as images when sampling a class folder.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _pick_random_sample(names, root):
    """Return ``(class_name, image_path)`` for a random image under ``root/<class_name>``.

    Classes are tried in random order. A class whose folder contains no
    image file is skipped instead of crashing ``random.choice`` with an
    empty list. When every folder has images, this samples exactly as
    before: a uniformly random class, then a uniformly random image in it.

    Raises:
        FileNotFoundError: if a tried class folder does not exist
            (this signals a mismatch between the class list and the dataset).
        RuntimeError: if none of the classes in *names* has any image.
    """
    candidates = list(names)
    random.shuffle(candidates)
    for name in candidates:
        folder = os.path.join(root, name)
        images = [
            entry for entry in os.listdir(folder)
            if entry.lower().endswith(_IMAGE_EXTENSIONS)
        ]
        if images:
            return name, os.path.join(folder, random.choice(images))
    raise RuntimeError(f"No images found for any class under {root!r}")


def _to_probabilities(raw):
    """Return a NumPy probability vector for one row of model output.

    ``tf.nn.softmax`` is applied only when *raw* is not already a
    probability distribution (all values >= 0 and summing to ~1).
    Re-applying softmax to a softmax output keeps the argmax but squashes
    the scores toward uniform, which makes the reported confidence far too
    low.
    """
    # NOTE(review): whether the saved model ends in a softmax layer is not
    # visible here; this check makes the script correct for either case.
    scores = np.asarray(raw)
    if np.all(scores >= 0) and np.isclose(scores.sum(), 1.0, atol=1e-3):
        return scores
    return tf.nn.softmax(scores).numpy()


# --- Prepare 2x2 Plot ---
plt.figure(figsize=(10, 10))
for i in range(4):
    # Pick a random class (with at least one image) and a random image in it
    random_class, img_path = _pick_random_sample(class_names, base_path)

    # --- Load & Preprocess Image ---
    img = keras.utils.load_img(img_path, target_size=size)  # resize to match model input
    img_array = keras.utils.img_to_array(img)
    # NOTE(review): scaling to [0, 1] must match the preprocessing used at
    # training time (keras.applications ResNet50/Xception define their own
    # preprocess_input); confirm against the training script.
    img_array = img_array / 255.0
    img_array = tf.expand_dims(img_array, 0)  # add a leading batch axis of size 1

    # --- Predict ---
    predictions = model.predict(img_array, verbose=0)
    probabilities = _to_probabilities(predictions[0])
    predicted_class_index = int(np.argmax(probabilities))
    predicted_label = class_names[predicted_class_index]
    confidence = 100 * float(probabilities[predicted_class_index])

    top_5_indices = np.argsort(probabilities)[-5:][::-1]
    print("\nTop 5 predictions:")
    for idx in top_5_indices:
        print(f"{class_names[idx]:<20}: {probabilities[idx]:.4f}")

    # Compare with actual
    is_correct = predicted_label == random_class

    # --- Plot ---
    plt.subplot(2, 2, i + 1)
    plt.imshow(img)
    plt.axis("off")
    plt.title(
        f"Pred: {predicted_label}\n"
        f"True: {random_class}\n"
        f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
        fontsize=10,
    )

# Lay out and show the whole grid once, after all four panels are drawn.
plt.tight_layout()
plt.show()