Add a custom operator to the C++ export
The main objective of this tutorial is to demonstrate the toolchain for detecting unsupported operators and adding them to an export module. For this tutorial, we use the C++ export module aidge_export_cpp to demonstrate the toolchain.
Install requirements
Ensure that the Aidge modules are properly installed in the current environment. If they are, the following setup step can be skipped.
Note: When running this notebook on Binder, all required components are pre-installed.
[ ]:
# Install the Aidge packages used by this tutorial: core, CPU backend,
# C++ export, ONNX import and the model explorer (skip if already installed).
%pip install aidge-core \
aidge-backend-cpu \
aidge-export-cpp \
aidge-onnx \
aidge-model-explorer
Import Aidge
[ ]:
import aidge_core
import aidge_backend_cpu
import aidge_export_cpp
import aidge_onnx
import aidge_model_explorer
import numpy as np
Load the ONNX model
[ ]:
# Fetch the LeNet MNIST ONNX model from the EclipseAidge Hugging Face repository
# and save it locally as ``file_path``.
file_path = "lenet_mnist.onnx"
file_url = (
    "https://huggingface.co/EclipseAidge/LeNet/resolve/main/lenet_mnist.onnx"
    "?download=true"
)
aidge_core.utils.download_file(file_path, file_url)
[ ]:
# Load the downloaded ONNX file into an Aidge graph. Reuse ``file_path`` rather
# than repeating the literal so the load always matches the download target.
model = aidge_onnx.load_onnx(file_path)
[ ]:
# Remove the Flatten node, which is useless in the CPP export
aidge_core.remove_flatten(model)
# Load the sample digit that will be fed to the model.
# NOTE(review): nothing in this notebook downloads "digit.npy"; it is assumed
# to already be present in the working directory — confirm.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; it
# is probably unnecessary for a plain numeric array — consider dropping it.
digit = np.load("digit.npy", allow_pickle=True)
# Create a Producer node holding the digit and connect it to the graph.
# Note: this means the graph will have no input!
input_tensor = aidge_core.Tensor(digit)
# Tag the tensor layout as NCHW (presumably digit has shape (1, 1, 28, 28) — TODO confirm).
input_tensor.set_data_format(aidge_core.dformat.nchw)
input_node = aidge_core.Producer(input_tensor, "input")
# Link the producer to the model, then register it in the GraphView itself.
input_node.add_child(model)
model.add(input_node)
Replace ReLU operators with Swish operators
Suppose you want to replace ReLU with another activation function, such as Swish.
[ ]:
# Forward the dimensions through the graph so that each ReLU input shape is
# known: the size of the Swish ``betas`` vector is derived from it below.
model.forward_dims()

# Use graph matching to find every ReLU node, then replace each one with Swish.
matches = aidge_core.SinglePassGraphMatching(model).match("ReLU")
print('Number of match : ', len(matches))
for i, match in enumerate(matches):
    print('Match ', i)
    node_relu = match.graph.root_node()
    # Swish is not a built-in operator here, so it is instantiated as a
    # GenericOperator with 1 data input, 0 parameter inputs and 1 output.
    # The enumerate index keeps node names unique (swish_0, swish_1, ...).
    node_swish = aidge_core.GenericOperator(
        "Swish", nb_data=1, nb_param=0, nb_out=1, name=f"swish_{i}"
    )
    # One beta (initialised to 1.0) per element along dimension 1 of the ReLU
    # input (presumably the channel/feature axis — TODO confirm for all matches).
    nb_betas = node_relu.get_operator().get_input(0).dims()[1]
    node_swish.get_operator().attr.betas = [1.0] * nb_betas
    print('Replacing ', node_relu.type(), ' : ', node_relu.name(), ' with ', node_swish.name())
    # Note: ignore new outputs to avoid adding MaxPooling optional output to the graph outputs
    aidge_core.GraphView.replace({node_relu}, {node_swish}, ignore_new_outputs=True)
[ ]:
# Visualize the modified graph to check that the ReLU nodes were replaced by
# Swish nodes (embed=True presumably shows the viewer inside the notebook).
aidge_model_explorer.visualize(model, "myModel", embed=True)
Schedule the graph
Provide a function that specifies how the Swish activation transforms dimensions. This forward_dims function is required to compute a sequential schedule. Each Swish node also receives a placeholder implementation.
[ ]:
class GenericImpl(aidge_core.OperatorImpl):
    """Placeholder 'cpu' implementation for a generic operator.

    Subclassing ``aidge_core.OperatorImpl`` is what lets a Python object be
    attached to an operator with ``set_impl``. No ``forward()`` is provided:
    the model is only scheduled in Python, never run through a scheduler.
    """

    def __init__(self, op: aidge_core.Operator):
        super().__init__(op, 'cpu')
# Give every Swish node a dimension-propagation rule and an implementation so
# the graph can be compiled and scheduled.
for node in model.get_nodes():
    if node.type() != "Swish":
        continue
    swish_op = node.get_operator()
    # Identity mapping: output dims are the same as the input dims.
    swish_op.set_forward_dims(lambda x: x)
    swish_op.set_impl(GenericImpl(swish_op))
[ ]:
# Set the CPU backend and float32 data type on the graph. dims presumably
# gives the input shape (1x1x28x28, one MNIST digit) — TODO confirm it is still
# needed now that the graph input is fed by a Producer.
model.compile("cpu", aidge_core.dtype.float32, dims=[[1, 1, 28, 28]])
# Compute a static sequential schedule; this needs forward_dims on every node,
# which is why the Swish nodes were given one above.
scheduler = aidge_core.SequentialScheduler(model)
scheduler.generate_scheduling()
# NOTE(review): `s` is not used by the following cells (the export receives
# the scheduler itself); kept for inspection.
s = scheduler.get_sequential_static_scheduling()
Add Swish support to the C++ export
[ ]:
# Swish is a GenericOperator, so it is registered with ``register_generic``.
# An existing operator would use ``register``; a MetaOperator ``register_metaop``.
@aidge_export_cpp.ExportLibCpp.register_generic("Swish", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class SwishCPP(aidge_core.export_utils.ExportNodeCpp):
    """C++ export node for the generic ``Swish`` operator (float32).

    Points the export at the Swish Jinja templates (configuration and forward
    call) and at the kernel header to copy into the generated project.
    NOTE(review): these paths are relative and the ``swish_export_files``
    directory is not created by this notebook; it must be provided alongside it.
    """

    def __init__(self, node, mem_info):
        super().__init__(node, mem_info)
        files_dir = "swish_export_files"
        self.config_template = f"{files_dir}/swish_config.jinja"
        self.forward_template = f"{files_dir}/swish_forward.jinja"
        # No extra headers to include in the generated code.
        self.include_list = []
        self.kernels_to_copy = [f"{files_dir}/swish_kernel.hpp"]
[ ]:
# Generate the C++ export of `model` into the "my_export" directory, passing the
# scheduler computed above (presumably used to order the generated layer calls).
aidge_export_cpp.export("my_export", model, scheduler)
[ ]:
# List the generated files (requires the `tree` command on the host).
!tree my_export
[ ]:
# Build the generated project by running make inside my_export.
!cd my_export && make
[ ]:
# Run the binary produced by the build (assumes make outputs my_export/bin/run_export).
!./my_export/bin/run_export