import os
import random

import keras
import matplotlib.pyplot as plt
import numpy as np
from keras import layers
from tensorflow import data as tf_data

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # uncomment to force CPU-only runs

# Load the image folders once and let Keras carve out an 80/20
# train/validation split. A single call with subset="both" uses one seed
# for both subsets, so no sample can land in both splits.
train_ds, val_ds = keras.utils.image_dataset_from_directory(
    "Combined_Dataset",
    labels="inferred",
    label_mode="int",
    image_size=(256, 256),
    batch_size=20,
    shuffle=True,
    validation_split=0.2,
    subset="both",
    seed=random.randint(0, 8000),
)

# Class (Pokémon) names, in the order the integer labels were assigned.
class_names = train_ds.class_names

# Introduce artificial sample diversity: random horizontal flips plus
# small random rotations.
data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(images):
    """Pipe `images` through each augmentation layer in sequence."""
    for augmentation_layer in data_augmentation_layers:
        images = augmentation_layer(images)
    return images


# Augment the training images only; labels pass through untouched and the
# validation set stays pristine.
train_ds = train_ds.map(
    lambda img, label: (data_augmentation(img), label),
    num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching overlaps input-pipeline work with training to help keep the
# GPU fed.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)


# MODEL
def simple_xception_network(input_shape, num_classes):
    """Build a simplified Xception-style CNN that outputs raw logits.

    Args:
        input_shape: Shape of one input image, e.g. ``(256, 256, 3)``.
            Pixel values may be raw [0, 255]; rescaling happens inside.
        num_classes: Number of target classes. A value of 2 produces a
            single logit; any other value produces one logit per class.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = keras.Input(shape=input_shape)

    # Entry block: normalize to [0, 1] and downsample once.
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Three downsampling stages of separable convolutions with projected
    # residual connections (Xception entry-flow style).
    for size in [256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project the residual to the new spatial/channel shape before adding.
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        # NOTE(review): a lone logit implies a sigmoid/BinaryCrossentropy
        # setup, which is incompatible with the SparseCategoricalCrossentropy
        # loss compiled below — revisit before training a 2-class dataset.
        units = 1
    else:
        units = num_classes

    x = layers.Dropout(0.25)(x)
    # activation=None so the model returns logits; the loss below is
    # configured with from_logits=True to match.
    outputs = layers.Dense(units, activation=None)(x)
    return keras.Model(inputs, outputs)


# Backward-compatible alias preserving the original (misspelled) name.
simple_xception_netowkr = simple_xception_network

# Derive the class count from the dataset instead of hard-coding 156, so the
# output layer always matches the number of class folders actually on disk.
model = simple_xception_network(
    input_shape=(256, 256) + (3,), num_classes=len(class_names)
)

# Train
epochs = 25

callbacks = [
    # Snapshot the full model after every epoch.
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)
import keras
import matplotlib.pyplot as plt
import numpy as np

# Open the dataset directory just long enough to recover the class-name
# ordering that training used for its integer labels.
label_source = keras.utils.image_dataset_from_directory(
    "Combined_Dataset",
    labels="inferred",
    label_mode="int",
    image_size=(256, 256),
    batch_size=1,
    shuffle=False,
)
class_names = label_source.class_names

# Restore the checkpoint written after epoch 3.
model = keras.models.load_model("save_at_3.keras")

# Read the sample image and queue it for display.
sample = keras.utils.load_img(
    "Combined_Dataset/Charmeleon/28d58b5e8c68f76d7986aac99b571377cccac3b6f831fc223ad6123f55fcb001.jpg",
    target_size=(256, 256),
)
plt.imshow(sample)
plt.axis("off")

# Convert the PIL image into a single-image batch of shape (1, 256, 256, 3).
img_array = keras.ops.expand_dims(keras.utils.img_to_array(sample), 0)

# Run inference; softmax turns the raw logits into per-class probabilities.
predictions = model.predict(img_array)
probabilities = keras.ops.softmax(predictions[0])
predicted_class_index = np.argmax(probabilities)

# Report the top-1 class and its confidence.
print(f"Predicted Pokémon: {class_names[predicted_class_index]}")
print(f"Confidence: {100 * probabilities[predicted_class_index]:.2f}%")

plt.show()