diff --git a/Python/pokedex_ResNet50.py b/Python/pokedex_ResNet50.py
index 80160bc0c7d1903db4b99b96e1df20b3aa394d86..57240386a7cbebc0be2feddffd6b7e94d4baa610 100644
--- a/Python/pokedex_ResNet50.py
+++ b/Python/pokedex_ResNet50.py
@@ -9,16 +9,25 @@ from sklearn.utils.class_weight import compute_class_weight
 import numpy as np
 import json
 import os
+import argparse
 
 # --- GPU Strategy ---
 strategy = tf.distribute.MirroredStrategy()
 print("Number of GPUs:", strategy.num_replicas_in_sync)
 
 # --- Paths ---
-dataset_path = "/home/padi/Git/pokedex/Combined_Dataset"
-model_output_path = "/home/padi/Git/pokedex/models/ResNet50"
-#dataset_path = "/home/users/d/divia/scratch/Combined_Dataset"
-#model_output_path = "/home/users/d/divia/pokedex/models/ResNet50"
+parser = argparse.ArgumentParser(description="WHERE ?!")
+parser.add_argument("--hpc", choices=["yes", "no"], default="no",
+                    help="Use HPC paths if 'yes', otherwise local paths.")
+args = parser.parse_args()
+
+if args.hpc == "yes":
+    dataset_path = "/home/users/d/divia/scratch/Combined_Dataset"
+    model_output_path = "/home/users/d/divia/pokedex/models/ResNet50"
+else:
+    dataset_path = "/home/padi/Git/pokedex/Combined_Dataset"
+    model_output_path = "/home/padi/Git/pokedex/models/ResNet50"
+
 os.makedirs(model_output_path, exist_ok=True)
 
 # --- Image settings ---
diff --git a/Python/pokedex_xception.py b/Python/pokedex_xception.py
deleted file mode 100644
index 8584135740b475299f44b6d92427dd2090b321ef..0000000000000000000000000000000000000000
--- a/Python/pokedex_xception.py
+++ /dev/null
@@ -1,136 +0,0 @@
-import tensorflow as tf
-from tensorflow.keras import layers, models, Input
-from tensorflow.keras.callbacks import EarlyStopping
-from tensorflow.keras.preprocessing import image_dataset_from_directory
-from sklearn.utils.class_weight import compute_class_weight
-import numpy as np
-import json
-import os
-
-# --- GPU Strategy ---
-strategy = tf.distribute.MirroredStrategy()
-print("Number of GPUs:", strategy.num_replicas_in_sync)
-
-# --- Paths ---
-dataset_path = "/home/padi/Git/pokedex/Combined_Dataset"
-model_output_path = "/home/padi/Git/pokedex/models/Xception"
-#dataset_path = "/home/users/d/divia/scratch/Combined_Dataset"
-#model_output_path = "/home/users/d/divia/pokedex/models7Xception"
-os.makedirs(model_output_path, exist_ok=True)
-
-# --- Custom Xception-like model ---
-def simple_xception(input_shape, num_classes):
-    inputs = Input(shape=input_shape)
-
-    x = layers.Rescaling(1.0 / 255)(inputs)
-    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
-    x = layers.BatchNormalization()(x)
-    x = layers.Activation("relu")(x)
-
-    previous_block_activation = x
-
-    for size in [256, 512, 728]:
-        x = layers.Activation("relu")(x)
-        x = layers.SeparableConv2D(size, 3, padding="same")(x)
-        x = layers.BatchNormalization()(x)
-
-        x = layers.Activation("relu")(x)
-        x = layers.SeparableConv2D(size, 3, padding="same")(x)
-        x = layers.BatchNormalization()(x)
-
-        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
-        residual = layers.Conv2D(size, 1, strides=2, padding="same")(previous_block_activation)
-        x = layers.add([x, residual])
-        previous_block_activation = x
-
-    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
-    x = layers.BatchNormalization()(x)
-    x = layers.Activation("relu")(x)
-
-    x = layers.GlobalAveragePooling2D()(x)
-    x = layers.Dropout(0.25)(x)
-    outputs = layers.Dense(num_classes, activation='softmax')(x)
-
-    return models.Model(inputs, outputs)
-
-# --- Image settings ---
-img_size = (256, 256)
-batch_size = 32
-
-# --- Data Augmentation ---
-data_augmentation = tf.keras.Sequential([
-    layers.RandomFlip("horizontal"),
-    layers.RandomRotation(0.1),
-    layers.RandomZoom(0.1),
-    layers.RandomContrast(0.1),
-])
-
-# --- Load datasets ---
-raw_train_ds = image_dataset_from_directory(
-    dataset_path,
-    image_size=img_size,
-    batch_size=batch_size,
-    validation_split=0.2,
-    subset="training",
-    seed=123,
-)
-
-raw_val_ds = image_dataset_from_directory(
-    dataset_path,
-    image_size=img_size,
-    batch_size=batch_size,
-    validation_split=0.2,
-    subset="validation",
-    seed=123,
-)
-
-# Save class names
-class_names = raw_train_ds.class_names
-with open(os.path.join(model_output_path, "class_names.json"), "w") as f:
-    json.dump(class_names, f)
-print(f"Detected {len(class_names)} Pokémon classes.")
-
-# --- Compute class weights ---
-print("Computing class weights...")
-all_labels = [label.numpy() for _, label in raw_train_ds.unbatch()]
-class_weights = compute_class_weight(
-    class_weight="balanced",
-    classes=np.unique(all_labels),
-    y=np.array(all_labels)
-)
-class_weight_dict = dict(enumerate(class_weights))
-print("Class weights ready.")
-
-# --- Performance improvements ---
-AUTOTUNE = tf.data.AUTOTUNE
-train_ds = raw_train_ds.map(lambda x, y: (data_augmentation(x), y)).prefetch(AUTOTUNE)
-val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE)
-
-# --- Build and compile model ---
-with strategy.scope():
-    model = simple_xception((*img_size, 3), num_classes=len(class_names))
-    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
-
-# --- Callbacks ---
-callbacks = [
-    EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
-]
-
-# --- Train the model ---
-model.fit(
-    train_ds,
-    validation_data=val_ds,
-    epochs=20,
-    callbacks=callbacks,
-    class_weight=class_weight_dict
-)
-
-# --- Save the model ---
-model_h5_path = os.path.join(model_output_path, "pokemon_xception.h5")
-model.save(model_h5_path)
-print(f"Model saved to {model_h5_path}")
-
-# --- Save as TensorFlow SavedModel ---
-saved_model_path = os.path.join(model_output_path, "saved_model")
-tf.saved_model.save(model, saved_model_path)
-print(f"SavedModel exported to {saved_model_path}")
diff --git a/slurm/train_ResNet50.sh b/slurm/train_ResNet50.sh
index bc5f0a07b5df47a771fce34021c57b5d09b4c053..db868da4750f093a0a3bfee4f465e7743425e16c 100644
--- a/slurm/train_ResNet50.sh
+++ b/slurm/train_ResNet50.sh
@@ -15,4 +15,4 @@ module load cuDNN/8.4.1.50-CUDA-11.7.0
 module load scikit-learn/1.1.2
 
 # Run your script
-srun python ../Python/pokedex_ResNet50.py
+srun python ../Python/pokedex_ResNet50.py --hpc yes
diff --git a/slurm/train_Xception.sh b/slurm/train_Xception.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bd6e73801f13b0109f013ada8ee5f5aa6ac04b00
--- /dev/null
+++ b/slurm/train_Xception.sh
@@ -0,0 +1,18 @@
+#!/bin/sh
+#SBATCH --job-name=Xception
+#SBATCH --output=Xception_%j.out
+#SBATCH --partition=shared-gpu
+#SBATCH --gres=gpu:1,VramPerGpu:80G
+#SBATCH --cpus-per-task=2
+#SBATCH --mem=16G
+#SBATCH --time=05:00:00
+#SBATCH --mail-type=FAIL
+
+# Load modules
+module purge
+module load GCC/11.3.0 OpenMPI/4.1.4 TensorFlow/2.11.0-CUDA-11.7.0
+module load cuDNN/8.4.1.50-CUDA-11.7.0
+module load scikit-learn/1.1.2
+
+# Run your script
+srun python ../Python/pokedex_Xception.py --hpc yes