Performing inference with a real CIFAR-10 input image#

Importing libraries#

[1]:
import os
import tarfile
import urllib

import matplotlib.pyplot as plt
import numpy as np
import requests
from huggingface_hub import hf_hub_download
from PIL import Image

import aidge_core
import aidge_onnx

Loading the saved ONNX model on Aidge#

[ ]:
local_path = hf_hub_download(
    repo_id="EclipseAidge/CustomCNN",
    filename="MarwaNet.onnx",
)

aidge_model = aidge_onnx.load_onnx(local_path)

print("\n Model loaded!")

Definition of the CIFAR-10 classes names and images labels#

[3]:
# import random

# CIFAR-10 class names
cifar10_classes = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]

# Select a class index
# class_index = random.randint(0, len(cifar10_classes) - 1)
class_index = 0
image_label = cifar10_classes[class_index]

Loading the NumPy array from the npy file#

[ ]:
# Paths
cifar_file = (
    "./examples/tutorials/Dropout_tutorials/Datasets/CIFAR-10/cifar-10-binary.tar.gz"
)
data_dir = "./examples/tutorials/Dropout_tutorials/Datasets/CIFAR-10"

# Extract CIFAR-10
with tarfile.open(cifar_file, "r:gz") as tar:
    tar.extractall(path=data_dir)

# Read CIFAR-10 binary data
batch_file = os.path.join(data_dir, "cifar-10-batches-bin", "test_batch.bin")

num_images = 10000
image_size = 32 * 32 * 3

image_numpy = None
image_label = None

with open(batch_file, "rb") as f:
    for i in range(num_images):
        label = np.frombuffer(f.read(1), dtype=np.uint8)[0]
        image_data = np.frombuffer(f.read(image_size), dtype=np.uint8)

        if label == 8:  # ship
            image_numpy = image_data.reshape(3, 32, 32).astype(np.float32) / 255.0
            image_label = label
            break

# Save .npy
np.save(
    "./examples/tutorials/Dropout_tutorials/NumPy_images/image_ship.npy",
    image_numpy,
)

# Load and verify
loaded_image = np.load(
    "./examples/tutorials/Dropout_tutorials/NumPy_images/image_ship.npy"
)

print(f"Label: {image_label}, Shape: {loaded_image.shape}")

# Save PNG
image_pil = Image.fromarray((image_numpy.transpose(1, 2, 0) * 255).astype(np.uint8))

image_pil.save("./examples/tutorials/Dropout_tutorials/NumPy_images/image_ship.png")

print("Saved image_ship.png")
/tmp/ipykernel_146176/2131495271.py:7: DeprecationWarning: Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata. Use the filter argument to control this behavior.
  tar.extractall(path=data_dir)
Label: 8, Shape: (3, 32, 32)
Saved image_ship.png

Normalizing the image numpy#

[ ]:
# Define mean and std
mean = np.array([0.4914, 0.4822, 0.4465])[:, None, None]
std = np.array([0.2023, 0.1994, 0.2010])[:, None, None]

# Denormalize the image
image_numpy_denorm = std * image_numpy + mean

# Normalize it again
image_numpy_norm = (image_numpy_denorm - mean) / std

# Check the shape again after normalization
print(f"image_numpy_norm shape: {image_numpy_norm.shape}")

Displaying the image numpy after normalization#

[ ]:
image_numpy_plot = np.transpose(image_numpy, (1, 2, 0))

mean = np.array([0.4914, 0.4822, 0.4465])
std = np.array([0.2023, 0.1994, 0.2010])

image_numpy_plot = (image_numpy_plot * std + mean) * 255

image_numpy_plot = image_numpy_plot.astype(np.uint8)
print(f"image_numpy_plot shape: {image_numpy_plot.shape}")

plt.imshow(image_numpy_plot)
plt.show()

Converting the image numpy array to Aidge tensor#

[ ]:
image_aidge = aidge_core.Tensor(image_numpy)

# Tensor dimensions and size
tensor_dims = getattr(image_aidge, "dims", "Attribute not found")
tensor_size = (
    image_aidge.size if hasattr(image_aidge, "size") else "Attribute not found"
)

print(f"Tensor dimensions: {tensor_dims}, Tensor size: {tensor_size}")

Setting up the backend#

[8]:
aidge_model.set_datatype(aidge_core.dtype.float32)
aidge_model.set_backend("cpu")

Creating the scheduler#

[9]:
scheduler = aidge_core.SequentialScheduler(aidge_model)

Inserting the input producer#

[10]:
input_node = aidge_core.Producer([1, 3, 32, 32], "DataProvider")

Adding the model graph as a child of input_node#

[11]:
input_node.add_child(aidge_model, 0, aidge_model.get_ordered_inputs()[0])

Adding the input_node to Aidge model graph#

[12]:
aidge_model.add(input_node)

Retrieving nodes and displaying information in the ONNX model#

[13]:
in_node = aidge_model.get_node("DataProvider")
print(in_node.name())
print(in_node.type())
DataProvider
Producer

Performing inference with a real CIFAR-10 input image in Aidge#

[ ]:
# Inference function definition
def inference(model, scheduler, image_numpy):
    # Set up the input
    image_numpy = np.reshape(image_numpy, (1, 3, 32, 32))
    image_aidge = aidge_core.Tensor(image_numpy)
    input_node.get_operator().set_output(0, image_aidge)

    # Run the inference
    scheduler.forward()

    # Gather the results
    output_node = model.get_ordered_outputs()[0][0]
    output_aidge = output_node.get_operator().get_output(0)
    return np.array(output_aidge)


# Perform inference on the single image_aidge
output_array = inference(aidge_model, scheduler, image_aidge)

# Print output_array shape, data type and inference result
print(f"\nOutput_array shape: {output_array.shape}")
print(f"\nOutput_array data type: {output_array.dtype}")

# Print inference result
print("\nEXAMPLE INFERENCES:")
print(f"\nInference result: {output_array}")

# Ensure the output is in the expected shape for classification
output_array = output_array.reshape(1, -1)

# Determine the predicted class index
predicted_class_index = np.argmax(output_array)

# Map the class index to the class name
cifar10_classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Check if the predicted class index is within valid range
if predicted_class_index < len(cifar10_classes):
    predicted_class_label = cifar10_classes[predicted_class_index]
    print(
        f"\nPredicted Class Index: {predicted_class_index}, Predicted Class Label: {predicted_class_label}"
    )
else:
    print(
        f"\nPredicted index {predicted_class_index} is out of range for CIFAR-10 classes."
    )