Skip to content
Snippets Groups Projects
Commit e15bb05c authored by michael.divia's avatar michael.divia
Browse files

Migrated from SoftMax to Logits

parent cc235fa3
Branches
No related tags found
No related merge requests found
......@@ -69,15 +69,11 @@ print(f"Detected {len(class_names)} Pokémon classes.")
# --- Compute class weights ---
print("Computing class weights...")
all_labels = []
for _, labels in raw_train_ds.unbatch():
all_labels.append(labels.numpy())
all_labels = np.array(all_labels)
all_labels = [label.numpy() for _, label in raw_train_ds.unbatch()]
class_weights = compute_class_weight(
class_weight="balanced",
classes=np.unique(all_labels),
y=all_labels
y=np.array(all_labels)
)
class_weight_dict = dict(enumerate(class_weights))
print("Class weights ready.")
......@@ -92,26 +88,29 @@ with strategy.scope():
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
outputs = Dense(len(class_names), activation="softmax")(x)
outputs = Dense(len(class_names), activation=None)(x)
model = Model(inputs=base_model.input, outputs=outputs)
# Freeze some layers
for layer in base_model.layers[:100]:
layer.trainable = False
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
# --- Callbacks ---
callbacks = [
EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
]
# --- Train the model with class weights ---
# --- Train the model ---
model.fit(
train_ds,
validation_data=val_ds,
epochs=1,
epochs=20,
callbacks=callbacks,
class_weight=class_weight_dict
)
......@@ -121,7 +120,7 @@ model_h5_path = os.path.join(model_output_path, "pokedex_ResNet50.h5")
model.save(model_h5_path)
print(f"Model saved to {model_h5_path}")
# --- Save as TensorFlow SavedModel (for ONNX export) ---
# --- Save as TensorFlow SavedModel ---
saved_model_path = os.path.join(model_output_path, "saved_model")
tf.saved_model.save(model, saved_model_path)
print(f"SavedModel exported to {saved_model_path}")
print(f"SavedModel exported to {saved_model_path}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment