From e15bb05ca685215a7625be4b9f05cdc519bf1988 Mon Sep 17 00:00:00 2001
From: "michael.divia" <michael.divia@etu.hesge.ch>
Date: Wed, 9 Apr 2025 18:16:38 +0200
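Subject: [PATCH] Migrated from SoftMax to Logits

- The final Dense layer now uses activation=None and returns raw
  logits; the loss is SparseCategoricalCrossentropy(from_logits=True)
  to match.
- The saved .h5 and SavedModel artifacts therefore output logits, so
  consumers must apply softmax themselves to obtain probabilities.
- Training runs for up to 20 epochs instead of 1 (EarlyStopping on
  val_loss with patience=3 is unchanged).
- Label collection for the class weights is rewritten as a list
  comprehension; the computed weights are unchanged.
- The trailing newline at the end of the file is removed.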

---
 python/pokedex_ResNet50.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/python/pokedex_ResNet50.py b/python/pokedex_ResNet50.py
index dad6c96..971ac69 100644
--- a/python/pokedex_ResNet50.py
+++ b/python/pokedex_ResNet50.py
@@ -69,15 +69,11 @@ print(f"Detected {len(class_names)} Pokémon classes.")
 
 # --- Compute class weights ---
 print("Computing class weights...")
-all_labels = []
-for _, labels in raw_train_ds.unbatch():
-    all_labels.append(labels.numpy())
-
-all_labels = np.array(all_labels)
+all_labels = [label.numpy() for _, label in raw_train_ds.unbatch()]
 class_weights = compute_class_weight(
     class_weight="balanced",
     classes=np.unique(all_labels),
-    y=all_labels
+    y=np.array(all_labels)
 )
 class_weight_dict = dict(enumerate(class_weights))
 print("Class weights ready.")
@@ -92,26 +88,29 @@ with strategy.scope():
     base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
     x = base_model.output
     x = GlobalAveragePooling2D()(x)
-    outputs = Dense(len(class_names), activation="softmax")(x)
+    outputs = Dense(len(class_names), activation=None)(x)
 
     model = Model(inputs=base_model.input, outputs=outputs)
 
-    # Freeze some layers
     for layer in base_model.layers[:100]:
         layer.trainable = False
 
-    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+    model.compile(
+        optimizer='adam',
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=['accuracy']
+    )
 
 # --- Callbacks ---
 callbacks = [
     EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
 ]
 
-# --- Train the model with class weights ---
+# --- Train the model ---
 model.fit(
     train_ds,
     validation_data=val_ds,
-    epochs=1,
+    epochs=20,
     callbacks=callbacks,
     class_weight=class_weight_dict
 )
@@ -121,7 +120,7 @@ model_h5_path = os.path.join(model_output_path, "pokedex_ResNet50.h5")
 model.save(model_h5_path)
 print(f"Model saved to {model_h5_path}")
 
-# --- Save as TensorFlow SavedModel (for ONNX export) ---
+# --- Save as TensorFlow SavedModel ---
 saved_model_path = os.path.join(model_output_path, "saved_model")
 tf.saved_model.save(model, saved_model_path)
-print(f"SavedModel exported to {saved_model_path}")
+print(f"SavedModel exported to {saved_model_path}")
\ No newline at end of file
-- 
GitLab