diff --git a/python/pokedex_Xception.py b/python/pokedex_Xception.py index 55d425c25c73dd8b8a8b8a66624949c4efa75ea3..24f0dadeca59814fa142894ef110fe57619bee64 100644 --- a/python/pokedex_Xception.py +++ b/python/pokedex_Xception.py @@ -13,9 +13,8 @@ strategy = tf.distribute.MirroredStrategy() print("Number of GPUs:", strategy.num_replicas_in_sync) # --- Paths --- -parser = argparse.ArgumentParser(description="WHERE ?!") -parser.add_argument("--hpc", choices=["yes", "no"], default="no", - help="Use HPC paths if 'yes', otherwise local paths.") +parser = argparse.ArgumentParser(description="Train Xception Pokémon Classifier") +parser.add_argument("--hpc", choices=["yes", "no"], default="no", help="Use HPC paths if 'yes', otherwise local paths.") args = parser.parse_args() if args.hpc == "yes": @@ -30,14 +29,12 @@ os.makedirs(model_output_path, exist_ok=True) # --- Custom Xception-like model --- def simple_xception(input_shape, num_classes): inputs = Input(shape=input_shape) - x = layers.Rescaling(1.0 / 255)(inputs) x = layers.Conv2D(128, 3, strides=2, padding="same")(x) x = layers.BatchNormalization()(x) x = layers.Activation("relu")(x) previous_block_activation = x - for size in [256, 512, 728]: x = layers.Activation("relu")(x) x = layers.SeparableConv2D(size, 3, padding="same")(x) @@ -55,11 +52,11 @@ def simple_xception(input_shape, num_classes): x = layers.SeparableConv2D(1024, 3, padding="same")(x) x = layers.BatchNormalization()(x) x = layers.Activation("relu")(x) - x = layers.GlobalAveragePooling2D()(x) x = layers.Dropout(0.25)(x) - outputs = layers.Dense(num_classes, activation='softmax')(x) + # Output logits + outputs = layers.Dense(num_classes, activation=None)(x) return models.Model(inputs, outputs) # --- Image settings --- @@ -83,7 +80,6 @@ raw_train_ds = image_dataset_from_directory( subset="training", seed=123, ) - raw_val_ds = image_dataset_from_directory( dataset_path, image_size=img_size, @@ -110,6 +106,12 @@ class_weights = compute_class_weight( 
class_weight_dict = dict(enumerate(class_weights)) print("Class weights ready.") +# --- Debug print for class balance --- +print("Unique labels in training set:", np.unique(all_labels)) +print("Class Names (index -> name):") +for i, name in enumerate(class_names): + print(f"{i}: {name}") + # --- Performance improvements --- AUTOTUNE = tf.data.AUTOTUNE train_ds = raw_train_ds.map(lambda x, y: (data_augmentation(x), y)).prefetch(AUTOTUNE) @@ -118,7 +120,11 @@ val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE) # --- Build and compile model --- with strategy.scope(): model = simple_xception((*img_size, 3), num_classes=len(class_names)) - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + model.compile( + optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy'] + ) # --- Callbacks --- callbacks = [ @@ -142,4 +148,4 @@ print(f"Model saved to {model_h5_path}") # --- Save as TensorFlow SavedModel --- saved_model_path = os.path.join(model_output_path, "saved_model") tf.saved_model.save(model, saved_model_path) -print(f"SavedModel exported to {saved_model_path}") +print(f"SavedModel exported to {saved_model_path}") \ No newline at end of file diff --git a/python/pokedex_test.py b/python/pokedex_test.py index bccf30f21ec69c9a3d9c849f73f79c1e9e0d59dd..7cefe5741d7a96e5f820f7b7f1ca27a987139639 100644 --- a/python/pokedex_test.py +++ b/python/pokedex_test.py @@ -54,10 +54,16 @@ for i in range(4): # --- Predict --- predictions = model.predict(img_array, verbose=0) - probabilities = tf.nn.softmax(predictions[0]) + probabilities = tf.nn.softmax(predictions[0]).numpy()  # model now outputs raw logits; softmax converts them to probabilities predicted_class_index = np.argmax(probabilities) predicted_label = class_names[predicted_class_index] confidence = 100 * probabilities[predicted_class_index] + + top_5_indices = np.argsort(probabilities)[-5:][::-1] + print("\nTop 5 predictions:") + for idx in top_5_indices: + print(f"{class_names[idx]:<20}: {probabilities[idx]:.4f}") + # 
Compare with actual is_correct = predicted_label == random_class