You can use your own custom models inside the framework. The minimal requirement is a Python class that extends torch.nn.Module and describes your model. You can either define the client and server parts of the model yourself, or rely on a model provider that splits your complete model at a specified cut layer. Note that the cut layer approach has a restriction: your model must define every layer as an attribute of the class itself. Otherwise, PyTorch will not register the layers as children() of the underlying torch.nn.Module instance (layers kept in a plain Python list, for example, are not registered), and the cut layer code relies on this invariant.
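The sketch below illustrates why the restriction matters. It makes no claim about how the framework itself splits models: split_at_cut_layer is a hypothetical helper that cuts the list returned by children() at an index, which is only one plausible way a cut layer provider could work. It also shows that layers stored in a plain Python list never appear in children().

```python
import torch
from torch import nn


class AttributeNet(nn.Module):
    """Every layer is an attribute, so PyTorch registers each one as a child."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class ListNet(nn.Module):
    """Layers kept in a plain Python list are NOT registered as children."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = [nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


def split_at_cut_layer(
    model: nn.Module, cut_layer: int
) -> tuple[nn.Sequential, nn.Sequential]:
    """Hypothetical helper (not the framework's API): split the direct
    children of `model` into a client part and a server part."""
    children = list(model.children())  # registration order
    if not 0 < cut_layer < len(children):
        raise ValueError(
            f"cut_layer must be in [1, {len(children) - 1}], got {cut_layer}"
        )
    return nn.Sequential(*children[:cut_layer]), nn.Sequential(*children[cut_layer:])


model = AttributeNet()
print(len(list(model.children())))      # 3
print(len(list(ListNet().children())))  # 0, nothing to split
client_part, server_part = split_at_cut_layer(model, cut_layer=2)
x = torch.randn(3, 8)
print(torch.allclose(server_part(client_part(x)), model(x)))  # True
```

Wrapping the children in nn.Sequential reproduces the original model only when forward applies each child once, in registration order. Models with skip connections or functional operations between layers would need a different splitting strategy.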
To use your custom models, specify them via a model provider of your choice. You can either save your model provider configuration in the configs/model_provider folder and reference it by its filename, or nest the configuration directly inside your root experiment configuration file. Either way, the content under the model_provider key can look like this:
    _target_: edml.models.provider.base.ModelProvider
    client:
      _target_: your.custom.ClientNet
    server:
      _target_: your.custom.ServerNet
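In the configuration above, your.custom.ClientNet and your.custom.ServerNet are placeholders for your own import paths. A minimal sketch of what the two classes could look like follows. It assumes that the model provider instantiates each _target_ without extra arguments (so every constructor parameter has a default) and that the server part consumes the client part's output directly; the layer sizes are illustrative only.

```python
import torch
from torch import nn


class ClientNet(nn.Module):
    """Hypothetical client part: maps raw inputs to intermediate activations."""

    def __init__(self, in_features: int = 8, hidden: int = 16) -> None:
        super().__init__()
        self.fc = nn.Linear(in_features, hidden)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.fc(x))


class ServerNet(nn.Module):
    """Hypothetical server part: its input size must match ClientNet's output."""

    def __init__(self, hidden: int = 16, num_classes: int = 4) -> None:
        super().__init__()
        self.fc = nn.Linear(hidden, num_classes)

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        return self.fc(activations)


# Chain the two parts the way a client/server pipeline would.
client, server = ClientNet(), ServerNet()
logits = server(client(torch.randn(2, 8)))
print(logits.shape)  # torch.Size([2, 4])
```

Keeping the shared dimension (hidden here) as a default on both classes is a simple way to keep the two halves compatible when the configuration passes no arguments.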