Add a custom operator to the C++ export#

The main objective of this tutorial is to demonstrate the toolchain for detecting unsupported operators and adding them to an export module. We use the CPP export module aidge_export_cpp as the running example.

Install requirements#

Ensure that the Aidge modules are properly installed in the current environment. If they are, the following setup step can be skipped.
Note: When running this notebook on Binder, all required components are pre-installed.
[ ]:
%pip install aidge-core \
    aidge-backend-cpu \
    aidge-export-cpp \
    aidge-onnx \
    aidge-model-explorer

Import Aidge#

[ ]:
import aidge_core
import aidge_backend_cpu
import aidge_export_cpp
import aidge_onnx
import aidge_model_explorer
import numpy as np

Load ONNX model#

[ ]:
file_url = "https://huggingface.co/EclipseAidge/LeNet/resolve/main/lenet_mnist.onnx?download=true"
file_path = "lenet_mnist.onnx"

aidge_core.utils.download_file(file_path, file_url)
[ ]:
model = aidge_onnx.load_onnx("lenet_mnist.onnx")
[ ]:
# Remove Flatten node, useless in the CPP export
aidge_core.remove_flatten(model)

# Load an example input digit (a 28x28 MNIST image provided alongside this notebook)
digit = np.load("digit.npy", allow_pickle=True)

# Create Producer Node for the Graph
# Note: This means the graph will have no input!
input_tensor = aidge_core.Tensor(digit)
input_tensor.set_data_format(aidge_core.dformat.nchw)
input_node = aidge_core.Producer(input_tensor, "input")
input_node.add_child(model)
model.add(input_node)

Replace ReLU operators by Swish operators#

Let’s say you want to replace ReLU with another activation such as Swish.
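
For reference, Swish is defined as swish(x) = x * sigmoid(beta * x), which is equivalent to x / (1 + exp(-beta * x)); this tutorial uses one beta value per channel. The cell below is only a NumPy illustration of that formula (the function name and broadcasting layout are ours, not part of Aidge):

[ ]:
def swish_reference(x: np.ndarray, betas) -> np.ndarray:
    # One beta per channel, broadcast over an NCHW tensor
    beta = np.asarray(betas).reshape(1, -1, 1, 1)
    return x / (1.0 + np.exp(-beta * x))  # equivalent to x * sigmoid(beta * x)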

[ ]:
# Forward the dimensions in the graph in order to get the size for the beta vector of the Swish
model.forward_dims()

# Use GraphMatching to replace ReLU with Swish
matches = aidge_core.SinglePassGraphMatching(model).match("ReLU")
print('Number of matches: ', len(matches))

swish_id = 0
for i, match in enumerate(matches):
    print('Match ', i)
    node_ReLU = match.graph.root_node()

    print('Replacing ', node_ReLU.type(), ' : ', node_ReLU.name())

    # We instantiate Swish as a generic operator
    node_swish = aidge_core.GenericOperator("Swish", nb_data=1, nb_param=0, nb_out=1, name=f"swish_{swish_id}")
    node_swish.get_operator().attr.betas = [1.0]*node_ReLU.get_operator().get_input(0).dims()[1]

    print('Replacing ', node_ReLU.type(), ' : ', node_ReLU.name(), ' with ' , node_swish.name())

    # Note: ignore new outputs to avoid adding MaxPooling optional output to the graph outputs
    aidge_core.GraphView.replace(set([node_ReLU]), set([node_swish]), ignore_new_outputs=True)
    swish_id += 1
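
As a quick sanity check (not part of the original flow), you can verify that no ReLU node remains in the graph after the replacement, using the same get_nodes()/type() API as in the rest of this tutorial:

[ ]:
remaining_relu = [n.name() for n in model.get_nodes() if n.type() == "ReLU"]
print("Remaining ReLU nodes:", remaining_relu)  # expected: []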
[ ]:
aidge_model_explorer.visualize(model, "myModel", embed=True)

Schedule the graph#

Add a function specifying how the Swish activation transforms its input dimensions. This forward_dims function is required to perform sequential scheduling.

[ ]:
class GenericImpl(aidge_core.OperatorImpl): # Inherit OperatorImpl to interface with Aidge !
    def __init__(self, op: aidge_core.Operator):
        aidge_core.OperatorImpl.__init__(self, op, 'cpu')
    # no need to define a forward() function in Python as we do not intend to run inference on the model

for node in model.get_nodes():
    if node.type() == "Swish":
        node.get_operator().set_forward_dims(lambda x: x) # to propagate dimensions in the model
        node.get_operator().set_impl(GenericImpl(node.get_operator())) # Setting implementation
[ ]:
model.compile("cpu", aidge_core.dtype.float32, dims=[[1, 1, 28, 28]])
scheduler = aidge_core.SequentialScheduler(model)
scheduler.generate_scheduling()
s = scheduler.get_sequential_static_scheduling()
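
Optionally, you can print the static scheduling to check that the Swish nodes appear in it. This sketch only assumes that the scheduling returned above is an iterable of nodes (name() and type() are the same node methods used earlier):

[ ]:
for node in s:
    print(node.name(), '->', node.type())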

Add Swish to the CPP export support#

[ ]:
# Note: we register a GenericOperator so we need to use ``register_generic``.
# For registering an existing operator use ``register``
# For registering a MetaOperator use ``register_metaop``
@aidge_export_cpp.ExportLibCpp.register_generic("Swish", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class SwishCPP(aidge_core.export_utils.ExportNodeCpp):
    def __init__(self, node, mem_info):
        super().__init__(node, mem_info)
        self.config_template = "swish_export_files/swish_config.jinja"
        self.forward_template = "swish_export_files/swish_forward.jinja"
        self.include_list = []
        self.kernels_to_copy = [
            "swish_export_files/swish_kernel.hpp",
        ]

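The export below assumes that the swish_export_files/ directory provided alongside this notebook contains the two Jinja templates and the kernel header listed above. An optional, purely illustrative check before exporting:

[ ]:
import os
for f in ["swish_export_files/swish_config.jinja",
          "swish_export_files/swish_forward.jinja",
          "swish_export_files/swish_kernel.hpp"]:
    print(f, "found" if os.path.exists(f) else "MISSING")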
[ ]:
aidge_export_cpp.export("my_export", model, scheduler)
[ ]:
!tree my_export
[ ]:
!cd my_export && make
[ ]:
!./my_export/bin/run_export