diff --git a/Python/pokedex_Xception.py b/Python/pokedex_Xception.py new file mode 100644 index 0000000000000000000000000000000000000000..98785406e4e610930769c15a36db245d788d65c8 --- /dev/null +++ b/Python/pokedex_Xception.py @@ -0,0 +1,145 @@ +import tensorflow as tf +from tensorflow.keras import layers, models, Input +from tensorflow.keras.callbacks import EarlyStopping +from tensorflow.keras.preprocessing import image_dataset_from_directory +from sklearn.utils.class_weight import compute_class_weight +import numpy as np +import json +import os +import argparse + +# --- GPU Strategy --- +strategy = tf.distribute.MirroredStrategy() +print("Number of GPUs:", strategy.num_replicas_in_sync) + +# --- Paths --- +parser = argparse.ArgumentParser(description="WHERE ?!") +parser.add_argument("--hpc", choices=["yes", "no"], default="no", + help="Use HPC paths if 'yes', otherwise local paths.") +args = parser.parse_args() + +if args.hpc == "yes": + dataset_path = "/home/users/d/divia/scratch/Combined_Dataset" + model_output_path = "/home/users/d/divia/pokedex/models/Xception" +else: + dataset_path = "/home/padi/Git/pokedex/Combined_Dataset" + model_output_path = "/home/padi/Git/pokedex/models/Xception" + +os.makedirs(model_output_path, exist_ok=True) + +# --- Custom Xception-like model --- +def simple_xception(input_shape, num_classes): + inputs = Input(shape=input_shape) + + x = layers.Rescaling(1.0 / 255)(inputs) + x = layers.Conv2D(128, 3, strides=2, padding="same")(x) + x = layers.BatchNormalization()(x) + x = layers.Activation("relu")(x) + + previous_block_activation = x + + for size in [256, 512, 728]: + x = layers.Activation("relu")(x) + x = layers.SeparableConv2D(size, 3, padding="same")(x) + x = layers.BatchNormalization()(x) + + x = layers.Activation("relu")(x) + x = layers.SeparableConv2D(size, 3, padding="same")(x) + x = layers.BatchNormalization()(x) + + x = layers.MaxPooling2D(3, strides=2, padding="same")(x) + residual = layers.Conv2D(size, 1, strides=2, padding="same")(previous_block_activation) + x = layers.add([x, residual]) + previous_block_activation = x + + x = layers.SeparableConv2D(1024, 3, padding="same")(x) + x = layers.BatchNormalization()(x) + x = layers.Activation("relu")(x) + + x = layers.GlobalAveragePooling2D()(x) + x = layers.Dropout(0.25)(x) + outputs = layers.Dense(num_classes, activation='softmax')(x) + + return models.Model(inputs, outputs) + +# --- Image settings --- +img_size = (256, 256) +batch_size = 32 + +# --- Data Augmentation --- +data_augmentation = tf.keras.Sequential([ + layers.RandomFlip("horizontal"), + layers.RandomRotation(0.1), + layers.RandomZoom(0.1), + layers.RandomContrast(0.1), +]) + +# --- Load datasets --- +raw_train_ds = image_dataset_from_directory( + dataset_path, + image_size=img_size, + batch_size=batch_size, + validation_split=0.2, + subset="training", + seed=123, +) + +raw_val_ds = image_dataset_from_directory( + dataset_path, + image_size=img_size, + batch_size=batch_size, + validation_split=0.2, + subset="validation", + seed=123, +) + +# Save class names +class_names = raw_train_ds.class_names +with open(os.path.join(model_output_path, "class_names.json"), "w") as f: + json.dump(class_names, f) +print(f"Detected {len(class_names)} Pokémon classes.") + +# --- Compute class weights --- +print("Computing class weights...") +all_labels = [label.numpy() for _, label in raw_train_ds.unbatch()] +class_weights = compute_class_weight( + class_weight="balanced", + classes=np.unique(all_labels), + y=np.array(all_labels) +) +class_weight_dict = dict(enumerate(class_weights)) +print("Class weights ready.") + +# --- Performance improvements --- +AUTOTUNE = tf.data.AUTOTUNE +train_ds = raw_train_ds.map(lambda x, y: (data_augmentation(x), y)).prefetch(AUTOTUNE) +val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE) + +# --- Build and compile model --- +with strategy.scope(): + model = simple_xception((*img_size, 3), num_classes=len(class_names)) + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + +# --- Callbacks --- +callbacks = [ + EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True) +] + +# --- Train the model --- +model.fit( + train_ds, + validation_data=val_ds, + epochs=20, + callbacks=callbacks, + class_weight=class_weight_dict +) + +# --- Save the model --- +model_h5_path = os.path.join(model_output_path, "pokemon_xception.h5") +model.save(model_h5_path) +print(f"Model saved to {model_h5_path}") + +# --- Save as TensorFlow SavedModel --- +saved_model_path = os.path.join(model_output_path, "saved_model") +tf.saved_model.save(model, saved_model_path) +print(f"SavedModel exported to {saved_model_path}")