You can use your own custom models inside the framework. The minimal requirement is a Python class that extends `torch.nn.Module` and describes your model. You can either define separate models for the client and the server part yourself, or rely on a model provider that splits your complete model at a specified cut layer. Note that the cut-layer approach comes with a restriction: your model must define every layer as an attribute of the class itself. Otherwise, PyTorch will not register them as `children()` of the underlying `torch.nn.Module` instance, and the cut-layer code relies on this invariant.
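For the cut-layer approach, every layer must be assigned as an attribute so that it shows up in `children()`. A minimal sketch of a split-compatible model (the class and layer names here are illustrative, not part of edml):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitNet(nn.Module):
    """Example model whose layers are all attributes of the class."""

    def __init__(self):
        super().__init__()
        # Each layer is an attribute, so PyTorch registers it as a child
        # module -- the cut-layer logic iterates over exactly these children.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 28 * 28, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# All four layers appear as children, so a cut layer can split between them.
print(len(list(SplitNet().children())))  # 4
```

Had the layers been created inside `forward()` or stored in a plain list instead, they would be invisible to `children()` and the split would silently miss them.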
To use your custom models, specify them via a model provider of your choice. You can either save your model provider configuration inside the `configs/model_provider` folder and import it via its filename, or nest the configuration inside your root experiment configuration file. Either way, the content under the `model_provider` key can look like this:
```yaml
_target_: edml.models.provider.base.ModelProvider
client:
  _target_: your.custom.ClientNet
server:
  _target_: your.custom.ServerNet
```
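If you define the client and server parts yourself, the two `_target_` entries above point at two ordinary `nn.Module` classes. A minimal sketch of such a pair (the import path `your.custom` and the layer sizes are placeholders, not part of edml):

```python
import torch
import torch.nn as nn


class ClientNet(nn.Module):
    """Client-side part: maps raw inputs to smashed activations."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 128)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.fc(x))


class ServerNet(nn.Module):
    """Server-side part: maps smashed activations to predictions."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc(x)


# The server must consume exactly what the client produces.
smashed = ClientNet()(torch.randn(2, 784))
logits = ServerNet()(smashed)
print(logits.shape)  # torch.Size([2, 10])
```

Assuming the `_target_` keys follow the usual Hydra-style convention, each class is then instantiated from its import path when the experiment starts; the only contract between the two halves is that the client's output shape matches the server's input shape.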