diff --git a/python/pokedex_ResNet50.py b/python/pokedex_ResNet50.py
index dad6c96c0b6936f2ff05812d53a25b742e9fee02..971ac693cc771b4cc8bb20a47c41ed54c4979304 100644
--- a/python/pokedex_ResNet50.py
+++ b/python/pokedex_ResNet50.py
@@ -69,15 +69,11 @@ print(f"Detected {len(class_names)} Pokémon classes.")
 
 # --- Compute class weights ---
 print("Computing class weights...")
-all_labels = []
-for _, labels in raw_train_ds.unbatch():
-    all_labels.append(labels.numpy())
-
-all_labels = np.array(all_labels)
+all_labels = [label.numpy() for _, label in raw_train_ds.unbatch()]
 class_weights = compute_class_weight(
     class_weight="balanced",
     classes=np.unique(all_labels),
-    y=all_labels
+    y=np.array(all_labels)
 )
 class_weight_dict = dict(enumerate(class_weights))
 print("Class weights ready.")
@@ -92,26 +88,29 @@ with strategy.scope():
     base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
     x = base_model.output
     x = GlobalAveragePooling2D()(x)
-    outputs = Dense(len(class_names), activation="softmax")(x)
+    outputs = Dense(len(class_names), activation=None)(x)
 
     model = Model(inputs=base_model.input, outputs=outputs)
 
-    # Freeze some layers
     for layer in base_model.layers[:100]:
         layer.trainable = False
 
-    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+    model.compile(
+        optimizer='adam',
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=['accuracy']
+    )
 
 # --- Callbacks ---
 callbacks = [
     EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
 ]
 
-# --- Train the model with class weights ---
+# --- Train the model ---
 model.fit(
     train_ds,
     validation_data=val_ds,
-    epochs=1,
+    epochs=20,
     callbacks=callbacks,
     class_weight=class_weight_dict
 )
@@ -121,7 +120,7 @@ model_h5_path = os.path.join(model_output_path, "pokedex_ResNet50.h5")
 model.save(model_h5_path)
 print(f"Model saved to {model_h5_path}")
 
-# --- Save as TensorFlow SavedModel (for ONNX export) ---
+# --- Save as TensorFlow SavedModel ---
 saved_model_path = os.path.join(model_output_path, "saved_model")
 tf.saved_model.save(model, saved_model_path)
-print(f"SavedModel exported to {saved_model_path}")
+print(f"SavedModel exported to {saved_model_path}")
\ No newline at end of file