Skip to content
Snippets Groups Projects
Commit 55ed35f4 authored by michael.divia's avatar michael.divia
Browse files

Working Simple CNN

parent c15ae6cc
No related branches found
No related tags found
No related merge requests found
import os
import numpy as np
import keras
from keras import layers
import matplotlib.pyplot as plt
from tensorflow import data as tf_data
import random
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Load training and validation dataset
train_ds, val_ds = keras.utils.image_dataset_from_directory(
"Combined_Dataset",
labels="inferred",
label_mode="int",
image_size=(256, 256),
batch_size=20,
shuffle=True,
validation_split=0.2,
subset="both",
seed=random.randint(0,8000)
)
# Get class (Pokémon) names
class_names = train_ds.class_names
# Introduce artificial sample diversity
data_augmentation_layers = [
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1),
]
def data_augmentation(images):
for layer in data_augmentation_layers:
images = layer(images)
return images
# Apply `data_augmentation` to the training images.
train_ds = train_ds.map(
lambda img, label: (data_augmentation(img), label),
num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching samples in GPU memory helps maximize GPU utilization.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)
# MODEL
def simple_xception_netowkr(input_shape, num_classes):
inputs = keras.Input(shape=input_shape)
# Entry block
x = layers.Rescaling(1.0 / 255)(inputs)
x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
previous_block_activation = x # Set aside residual
for size in [256, 512, 728]:
x = layers.Activation("relu")(x)
x = layers.SeparableConv2D(size, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.SeparableConv2D(size, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
# Project residual
residual = layers.Conv2D(size, 1, strides=2, padding="same")(
previous_block_activation
)
x = layers.add([x, residual]) # Add back residual
previous_block_activation = x # Set aside next residual
x = layers.SeparableConv2D(1024, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.GlobalAveragePooling2D()(x)
if num_classes == 2:
units = 1
else:
units = num_classes
x = layers.Dropout(0.25)(x)
# We specify activation=None so as to return logits
outputs = layers.Dense(units, activation=None)(x)
return keras.Model(inputs, outputs)
model = simple_xception_netowkr(input_shape=(256, 256) + (3,), num_classes=156)
# Train
epochs = 25
callbacks = [
keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
optimizer=keras.optimizers.Adam(3e-4),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)
model.fit(
train_ds,
epochs=epochs,
callbacks=callbacks,
validation_data=val_ds,
)
\ No newline at end of file
test.py 0 → 100644
import keras
import matplotlib.pyplot as plt
import numpy as np
# Load just enough to get class_names
temp_ds = keras.utils.image_dataset_from_directory(
"Combined_Dataset",
labels="inferred",
label_mode="int",
image_size=(256, 256),
batch_size=1,
shuffle=False
)
# Load class names
class_names = temp_ds.class_names
# Load model
model = keras.models.load_model("save_at_3.keras")
# Load and show image
img = keras.utils.load_img(
"Combined_Dataset/Charmeleon/28d58b5e8c68f76d7986aac99b571377cccac3b6f831fc223ad6123f55fcb001.jpg",
target_size=(256, 256)
)
plt.imshow(img)
plt.axis("off")
# Preprocess image
img_array = keras.utils.img_to_array(img)
img_array = keras.ops.expand_dims(img_array, 0)
# Predict
predictions = model.predict(img_array)
probabilities = keras.ops.softmax(predictions[0])
predicted_class_index = np.argmax(probabilities)
# Output result
print(f"Predicted Pokémon: {class_names[predicted_class_index]}")
print(f"Confidence: {100 * probabilities[predicted_class_index]:.2f}%")
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment