diff --git a/python/pokedex_ResNet50.py b/python/pokedex_ResNet50.py index dad6c96c0b6936f2ff05812d53a25b742e9fee02..971ac693cc771b4cc8bb20a47c41ed54c4979304 100644 --- a/python/pokedex_ResNet50.py +++ b/python/pokedex_ResNet50.py @@ -69,15 +69,11 @@ print(f"Detected {len(class_names)} Pokémon classes.") # --- Compute class weights --- print("Computing class weights...") -all_labels = [] -for _, labels in raw_train_ds.unbatch(): - all_labels.append(labels.numpy()) - -all_labels = np.array(all_labels) +all_labels = [label.numpy() for _, label in raw_train_ds.unbatch()] class_weights = compute_class_weight( class_weight="balanced", classes=np.unique(all_labels), - y=all_labels + y=np.array(all_labels) ) class_weight_dict = dict(enumerate(class_weights)) print("Class weights ready.") @@ -92,26 +88,29 @@ with strategy.scope(): base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3)) x = base_model.output x = GlobalAveragePooling2D()(x) - outputs = Dense(len(class_names), activation="softmax")(x) + outputs = Dense(len(class_names), activation=None)(x) model = Model(inputs=base_model.input, outputs=outputs) - # Freeze some layers for layer in base_model.layers[:100]: layer.trainable = False - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + model.compile( + optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy'] + ) # --- Callbacks --- callbacks = [ EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True) ] -# --- Train the model with class weights --- +# --- Train the model --- model.fit( train_ds, validation_data=val_ds, - epochs=1, + epochs=20, callbacks=callbacks, class_weight=class_weight_dict ) @@ -121,7 +120,7 @@ model_h5_path = os.path.join(model_output_path, "pokedex_ResNet50.h5") model.save(model_h5_path) print(f"Model saved to {model_h5_path}") -# --- Save as TensorFlow SavedModel (for ONNX export) --- +# --- Save as TensorFlow SavedModel --- saved_model_path = os.path.join(model_output_path, "saved_model") tf.saved_model.save(model, saved_model_path) -print(f"SavedModel exported to {saved_model_path}") +print(f"SavedModel exported to {saved_model_path}") \ No newline at end of file