Performing inference with a real CIFAR-10 input image

Importing libraries

[ ]:
import urllib.request
import tarfile
import os
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

import aidge_core
import aidge_onnx
import aidge_backend_cpu

Loading the saved ONNX model into Aidge

[ ]:
# Exported MarwaNet ONNX file; the path is relative to the current working directory.
onnx_model_path = "./examples/tutorials/Dropout_custom_implementation/ONNX_files/MarwaNet.onnx"
aidge_model = aidge_onnx.load_onnx(onnx_model_path)

Defining the CIFAR-10 class names and the image label

[ ]:
# CIFAR-10 class names, indexed by the dataset's numeric label (0-9).
cifar10_classes = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]

# Class of the image extracted from the test batch in the next cell
# (label 8 = "Ship"), so that image_label names the image actually used.
class_index = 8
image_label = cifar10_classes[class_index]

Downloading CIFAR-10 and extracting a test image as a NumPy array

[ ]:
# Step 1: Download the CIFAR-10 binary dataset (only if the archive is missing)
cifar_url = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
data_dir = "./examples/tutorials/Dropout_custom_implementation/Datasets/CIFAR-10"
cifar_file = os.path.join(data_dir, "cifar-10-binary.tar.gz")
numpy_dir = "./examples/tutorials/Dropout_custom_implementation/NumPy_images"
npy_path = os.path.join(numpy_dir, "image_ship.npy")
png_path = os.path.join(numpy_dir, "image_ship.png")

# Create the dataset and output directories (np.save and PIL do not create them)
os.makedirs(data_dir, exist_ok=True)
os.makedirs(numpy_dir, exist_ok=True)

if not os.path.exists(cifar_file):
    print("Downloading CIFAR-10 dataset...")
    urllib.request.urlretrieve(cifar_url, cifar_file)

# Step 2: Extract the archive, unless the test batch is already on disk
batch_file = os.path.join(data_dir, "cifar-10-batches-bin", "test_batch.bin")
if not os.path.exists(batch_file):
    with tarfile.open(cifar_file, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Reject members with absolute paths, ".." components or links
            # pointing outside data_dir (extraction filters, Python 3.12+/backports).
            tar.extractall(path=data_dir, filter="data")
        else:
            tar.extractall(path=data_dir)

# Step 3: Read the whole test batch at once. Each record is 1 label byte
# followed by 3072 pixel bytes stored channel-first (hence reshape(3, 32, 32)).
num_images = 10000
image_size = 32 * 32 * 3  # 32x32 image with 3 RGB channels
record_size = 1 + image_size  # 1 byte for label + image data
target_label = 8  # Label 8 is "ship"

records = np.fromfile(batch_file, dtype=np.uint8)
if records.size != num_images * record_size:
    raise ValueError(
        f"{batch_file} has {records.size} bytes, expected {num_images * record_size}"
    )
records = records.reshape(num_images, record_size)

# First record whose label byte matches; fail loudly instead of leaving
# image_numpy undefined when no such record exists.
matches = np.flatnonzero(records[:, 0] == target_label)
if matches.size == 0:
    raise ValueError(f"No image with label {target_label} found in {batch_file}")
first_match = matches[0]
image_uint8 = records[first_match, 1:].reshape(3, 32, 32)
image_label = int(records[first_match, 0])
image_numpy = image_uint8.astype(np.float32) / 255.0  # Scale to [0, 1]

# Step 4: Save as .npy, then load it back to verify
np.save(npy_path, image_numpy)
loaded_image = np.load(npy_path)
print(
    f"Label: {image_label}, Image shape: {loaded_image.shape}"
)  # Should print (3, 32, 32)

# Step 5 (optional): save a PNG directly from the original uint8 pixels
# (HWC layout). This avoids the truncation of casting image_numpy * 255 back to uint8.
Image.fromarray(np.ascontiguousarray(image_uint8.transpose(1, 2, 0))).save(png_path)
print("Image saved as 'NumPy_images/image_ship.png'")

Normalizing the NumPy image

[ ]:
# Per-channel mean and std, shaped (3, 1, 1) to broadcast over the (3, 32, 32)
# channel-first image. These are the commonly used CIFAR-10 statistics;
# presumably the ones MarwaNet was trained with (TODO confirm).
mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)[:, None, None]
std = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)[:, None, None]

# image_numpy holds raw pixels scaled to [0, 1] (pixel / 255), so one
# standardization step is all that is needed. Keeping the statistics in
# float32 keeps the result float32, matching the model datatype.
image_numpy_norm = (image_numpy - mean) / std

# Check the shape again after normalization
print(f"image_numpy_norm shape: {image_numpy_norm.shape}")

Displaying the image

[ ]:
# image_numpy is the raw channel-first image scaled to [0, 1]. Undoing the
# normalization of image_numpy_norm gives exactly this image back, so plot it
# directly. Matplotlib expects channel-last (H, W, C).
image_numpy_plot = np.transpose(image_numpy, (1, 2, 0))

# Clip before the uint8 cast so out-of-range values cannot wrap around, and
# round rather than truncate so pixel values survive the /255 * 255 round trip.
image_numpy_plot = np.rint(np.clip(image_numpy_plot, 0.0, 1.0) * 255).astype(np.uint8)
print(f"image_numpy_plot shape: {image_numpy_plot.shape}")

plt.imshow(image_numpy_plot)
plt.show()

Converting the NumPy image to an Aidge tensor

[ ]:
# Build the Aidge tensor from the normalized image, with a batch dimension
# matching the [1, 3, 32, 32] producer. Use contiguous float32 to match the
# float32 datatype the model is configured with.
image_aidge = aidge_core.Tensor(
    np.ascontiguousarray(image_numpy_norm.reshape(1, 3, 32, 32), dtype=np.float32)
)

# dims() and size() are methods on the tensor, so they must be called.
tensor_dims = image_aidge.dims()
tensor_size = image_aidge.size()

print(f"Tensor dimensions: {tensor_dims}, Tensor size: {tensor_size}")

Setting up the backend

[ ]:
# Configure the model to compute in float32 on the "cpu" backend
# (presumably registered by the aidge_backend_cpu import above).
aidge_model.set_datatype(aidge_core.dtype.float32)
aidge_model.set_backend("cpu")

Creating the scheduler

[ ]:
# Sequential scheduler over aidge_model; inference() calls scheduler.forward().
# NOTE(review): the "DataProvider" producer is only added to aidge_model in a
# later cell. Confirm the scheduler sees graph changes made after it is built.
scheduler = aidge_core.SequentialScheduler(aidge_model)

Creating the input producer

[ ]:
# Producer node named "DataProvider", created with dims [1, 3, 32, 32]
# (matching the (1, 3, 32, 32) reshape in inference(), which sets its output tensor).
input_node = aidge_core.Producer([1, 3, 32, 32], "DataProvider")

Connecting input_node to the first input of the model graph

[ ]:
# Connect output 0 of input_node to the model's first ordered input.
# get_ordered_inputs()[0] is passed as-is; presumably a (node, input_index)
# pair (TODO confirm against the aidge_core API).
input_node.add_child(aidge_model, 0, aidge_model.get_ordered_inputs()[0])

Adding input_node to the Aidge model graph

[ ]:
# Register the producer in the graph so it can be looked up with
# aidge_model.get_node("DataProvider").
# NOTE(review): set_datatype/set_backend were applied before this node was
# added. Confirm the producer does not need its own backend set.
aidge_model.add(input_node)

Retrieving the input node and displaying its name and type

[ ]:
# Look up the producer by the name it was created with; print its name, then its type.
in_node = aidge_model.get_node("DataProvider")
for node_info in (in_node.name(), in_node.type()):
    print(node_info)

Performing inference with a real CIFAR-10 input image in Aidge

[ ]:
# Inference function definition
def inference(model, scheduler, image_numpy):
    """Run one forward pass of ``model`` on a single 3x32x32 image.

    Args:
        model: Aidge graph view; output 0 of its first ordered output node is returned.
        scheduler: Scheduler for ``model``; ``forward()`` is called once.
        image_numpy: Image data with 3 * 32 * 32 elements, either a NumPy
            array or anything NumPy can convert (e.g. an Aidge tensor).

    Returns:
        NumPy array copy of the model output.

    Note:
        The input is written into the module-level ``input_node`` producer,
        not into ``model`` directly.
    """
    # Add the batch dimension and force contiguous float32. A float64 array
    # (e.g. after NumPy arithmetic with float64 constants) would otherwise
    # produce a float64 tensor for a model configured as float32.
    batch = np.ascontiguousarray(
        np.asarray(image_numpy, dtype=np.float32).reshape(1, 3, 32, 32)
    )
    input_node.get_operator().set_output(0, aidge_core.Tensor(batch))

    # Run the inference
    scheduler.forward()

    # Gather the results
    output_node = model.get_ordered_outputs()[0][0]
    output_aidge = output_node.get_operator().get_output(0)
    return np.array(output_aidge)


# Perform inference on the normalized image. It is passed as a float32 NumPy
# array; inference() adds the batch dimension and wraps it in an Aidge tensor.
output_array = inference(
    aidge_model, scheduler, np.asarray(image_numpy_norm, dtype=np.float32)
)

# Print output_array shape, data type and inference result
print(f"\nOutput_array shape: {output_array.shape}")
print(f"\nOutput_array data type: {output_array.dtype}")

# Print inference result
print("\nEXAMPLE INFERENCES:")
print(f"\nInference result: {output_array}")

# Flatten to a single row of class scores
output_array = output_array.reshape(1, -1)

# Index of the highest score, as a plain int
predicted_class_index = int(np.argmax(output_array))

# Map the index to a name, reusing cifar10_classes from the class-definition
# cell instead of maintaining a second, differently-cased copy of the list.
if predicted_class_index < len(cifar10_classes):
    predicted_class_label = cifar10_classes[predicted_class_index]
    print(
        f"\nPredicted Class Index: {predicted_class_index}, Predicted Class Label: {predicted_class_label}"
    )
else:
    print(
        f"\nPredicted index {predicted_class_index} is out of range for CIFAR-10 classes."
    )