From c801ffeb750b55aa0a95c121f14e8edf4a9e235d Mon Sep 17 00:00:00 2001
From: "michael.divia" <michael.divia@etu.hesge.ch>
Date: Wed, 2 Apr 2025 10:09:17 +0200
Subject: [PATCH] Added Freezing and Fine Tuning for EfficientNetV2M

---
 Python/pokedex_EfficientNetV2M.py | 62 +++++++++++++++++++------------
 Python/pokedex_ResNet50.py        |  2 -
 2 files changed, 38 insertions(+), 26 deletions(-)

diff --git a/Python/pokedex_EfficientNetV2M.py b/Python/pokedex_EfficientNetV2M.py
index d6207c1..87d8274 100644
--- a/Python/pokedex_EfficientNetV2M.py
+++ b/Python/pokedex_EfficientNetV2M.py
@@ -1,10 +1,12 @@
 import os
-import gc
 import keras
 import tensorflow as tf
 from keras import layers
 from tensorflow import data as tf_data
 
+# --- Silence TensorFlow logs ---
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+
 # --- GPU Strategy ---
 strategy = tf.distribute.MirroredStrategy()
 print("Number of GPUs:", strategy.num_replicas_in_sync)
@@ -14,9 +16,8 @@ data_dir = "/home/users/d/divia/scratch/Combined_Dataset"
 image_size = (240, 240)
 num_classes = 151
 base_batch_size = 32
-base_lr = 1e-3
-
 global_batch_size = 32
+base_lr = 1e-3
 scaled_lr = min(base_lr * (global_batch_size / base_batch_size), 1e-3)
 
 # --- Load Dataset ---
@@ -54,10 +55,13 @@ def preprocess_val(img, label):
 train_ds = train_ds.map(preprocess_train, num_parallel_calls=tf_data.AUTOTUNE)
 val_ds = val_ds.map(preprocess_val, num_parallel_calls=tf_data.AUTOTUNE)
 
-train_ds = train_ds.prefetch(buffer_size=tf_data.AUTOTUNE)
-val_ds = val_ds.prefetch(buffer_size=tf_data.AUTOTUNE)
+# Add auto-shard options to suppress Grappler warning
+options = tf.data.Options()
+options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
+train_ds = train_ds.with_options(options).prefetch(buffer_size=tf_data.AUTOTUNE)
+val_ds = val_ds.with_options(options).prefetch(buffer_size=tf_data.AUTOTUNE)
 
-# --- Build & Compile Model ---
+# --- Build model ---
 with strategy.scope():
     base_model = tf.keras.applications.EfficientNetV2M(
         include_top=False,
@@ -65,31 +69,41 @@ with strategy.scope():
         input_shape=(240, 240, 3)
     )
 
-    model = keras.Sequential([
-        base_model,
-        layers.GlobalAveragePooling2D(),
-        layers.Dense(256, activation='relu'),
-        layers.Dropout(0.5),
-        layers.Dense(num_classes, activation='softmax')
-    ])
+    x = layers.GlobalAveragePooling2D()(base_model.output)
+    x = layers.Dense(256, activation='relu')(x)
+    x = layers.Dropout(0.5)(x)
+    predictions = layers.Dense(num_classes, activation='softmax')(x)
+
+    model = keras.Model(inputs=base_model.input, outputs=predictions)
 
-    optimizer = tf.keras.optimizers.Adam(learning_rate=scaled_lr)
+    # PHASE 1: Freeze the base model
+    base_model.trainable = False
 
     model.compile(
-        optimizer=optimizer,
-        loss='categorical_crossentropy',
-        metrics=['accuracy']
+        optimizer=tf.keras.optimizers.Adam(learning_rate=scaled_lr),
+        loss="categorical_crossentropy",
+        metrics=["accuracy"]
    )
 
-# --- Train ---
+# --- Callbacks ---
 callbacks = [
     keras.callbacks.ModelCheckpoint("/home/users/d/divia/EfficientNetV2M/save_at_{epoch}.keras"),
     keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)
 ]
 
-model.fit(
-    train_ds,
-    validation_data=val_ds,
-    epochs=10,
-    callbacks=callbacks
-)
\ No newline at end of file
+# --- Train head only ---
+model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=callbacks)
+
+# PHASE 2: Fine-tune the base model
+# Unfreeze the whole base:
+base_model.trainable = True
+
+# Recompile with lower LR for fine-tuning
+model.compile(
+    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-5, momentum=0.9),
+    loss="categorical_crossentropy",
+    metrics=["accuracy"]
+)
+
+# Fine-tune the full model
+model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=callbacks)
\ No newline at end of file
diff --git a/Python/pokedex_ResNet50.py b/Python/pokedex_ResNet50.py
index 9659d76..925618a 100644
--- a/Python/pokedex_ResNet50.py
+++ b/Python/pokedex_ResNet50.py
@@ -1,5 +1,3 @@
-import os
-import gc
 import keras
 import tensorflow as tf
 from keras import layers
-- 
GitLab
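
Reviewer note: once the whole EfficientNetV2M base is unfrozen in PHASE 2, its
BatchNormalization layers also start updating their moving statistics, which can
destabilize fine-tuning. The Keras transfer-learning guide recommends calling the
base model with training=False so BN stays in inference mode even while its weights
train. The sketch below is a minimal, hypothetical variant of the patch's two-phase
recipe under that recommendation; it reuses the patch's names (num_classes,
train_ds, val_ds) as assumptions and is not part of the commit.

import keras
import tensorflow as tf
from keras import layers

num_classes = 151  # as in the patch; train_ds / val_ds are assumed to exist

base_model = tf.keras.applications.EfficientNetV2M(
    include_top=False, weights="imagenet", input_shape=(240, 240, 3)
)

inputs = keras.Input(shape=(240, 240, 3))
# training=False keeps BatchNormalization in inference mode even after
# base_model.trainable is later flipped to True for fine-tuning.
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs, outputs)

# PHASE 1: freeze the base and train only the new head.
base_model.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=5)

# PHASE 2: unfreeze and recompile with a much lower learning rate.
# Recompiling is required for the trainable change to take effect.
base_model.trainable = True
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-5, momentum=0.9),
              loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=5)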