diff --git a/python/pokedex_Xception.py b/python/pokedex_Xception.py
index 55d425c25c73dd8b8a8b8a66624949c4efa75ea3..24f0dadeca59814fa142894ef110fe57619bee64 100644
--- a/python/pokedex_Xception.py
+++ b/python/pokedex_Xception.py
@@ -13,9 +13,8 @@ strategy = tf.distribute.MirroredStrategy()
 print("Number of GPUs:", strategy.num_replicas_in_sync)
 
 # --- Paths ---
-parser = argparse.ArgumentParser(description="WHERE ?!")
-parser.add_argument("--hpc", choices=["yes", "no"], default="no",
-                    help="Use HPC paths if 'yes', otherwise local paths.")
+parser = argparse.ArgumentParser(description="Train Xception Pokémon Classifier")
+parser.add_argument("--hpc", choices=["yes", "no"], default="no", help="Use HPC paths if 'yes', otherwise local paths.")
 args = parser.parse_args()
 
 if args.hpc == "yes":
@@ -30,14 +29,12 @@ os.makedirs(model_output_path, exist_ok=True)
 # --- Custom Xception-like model ---
 def simple_xception(input_shape, num_classes):
     inputs = Input(shape=input_shape)
-
     x = layers.Rescaling(1.0 / 255)(inputs)
     x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
     x = layers.BatchNormalization()(x)
     x = layers.Activation("relu")(x)
 
     previous_block_activation = x
-
     for size in [256, 512, 728]:
         x = layers.Activation("relu")(x)
         x = layers.SeparableConv2D(size, 3, padding="same")(x)
@@ -55,11 +52,11 @@ def simple_xception(input_shape, num_classes):
     x = layers.SeparableConv2D(1024, 3, padding="same")(x)
     x = layers.BatchNormalization()(x)
     x = layers.Activation("relu")(x)
-
     x = layers.GlobalAveragePooling2D()(x)
     x = layers.Dropout(0.25)(x)
-    outputs = layers.Dense(num_classes, activation='softmax')(x)
 
+    # Output raw logits; the loss is compiled with from_logits=True, so no softmax here
+    outputs = layers.Dense(num_classes, activation=None)(x)
     return models.Model(inputs, outputs)
 
 # --- Image settings ---
@@ -83,7 +80,6 @@ raw_train_ds = image_dataset_from_directory(
     subset="training",
     seed=123,
 )
-
 raw_val_ds = image_dataset_from_directory(
     dataset_path,
     image_size=img_size,
@@ -110,6 +106,12 @@ class_weights = compute_class_weight(
 class_weight_dict = dict(enumerate(class_weights))
 print("Class weights ready.")
 
+# --- Debug: inspect class labels and index-to-name mapping ---
+print("Unique labels in training set:", np.unique(all_labels))
+print("Class names (index -> name):")
+for i, name in enumerate(class_names):
+    print(f"{i}: {name}")
+
 # --- Performance improvements ---
 AUTOTUNE = tf.data.AUTOTUNE
 train_ds = raw_train_ds.map(lambda x, y: (data_augmentation(x), y)).prefetch(AUTOTUNE)
@@ -118,7 +120,11 @@ val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE)
 # --- Build and compile model ---
 with strategy.scope():
     model = simple_xception((*img_size, 3), num_classes=len(class_names))
-    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+    model.compile(
+        optimizer='adam',
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # expects raw logits from the model
+        metrics=['accuracy']
+    )
 
 # --- Callbacks ---
 callbacks = [
diff --git a/python/pokedex_test.py b/python/pokedex_test.py
index bccf30f21ec69c9a3d9c849f73f79c1e9e0d59dd..7cefe5741d7a96e5f820f7b7f1ca27a987139639 100644
--- a/python/pokedex_test.py
+++ b/python/pokedex_test.py
@@ -54,10 +54,16 @@ for i in range(4):
 
     # --- Predict ---
     predictions = model.predict(img_array, verbose=0)
-    probabilities = tf.nn.softmax(predictions[0])
+    probabilities = tf.nn.softmax(predictions[0]).numpy()  # model outputs logits; convert to probabilities
     predicted_class_index = np.argmax(probabilities)
     predicted_label = class_names[predicted_class_index]
     confidence = 100 * probabilities[predicted_class_index]
+
+    top_5_indices = np.argsort(probabilities)[-5:][::-1]
+    print("\nTop 5 predictions:")
+    for idx in top_5_indices:
+        print(f"{class_names[idx]:<20}: {probabilities[idx]:.4f}")
+
 
     # Compare with actual
     is_correct = predicted_label == random_class