Loss functions are functions of two arguments: they take an input tensor (typically the model's predictions) and a target tensor, and return the corresponding loss. You can use either class-based loss functions such as `torch.nn.CrossEntropyLoss` or plain Python functions. The framework expects the configuration entry `loss_fn` to resolve to a callable that can be invoked with these two arguments.
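The contract can be sketched in plain PyTorch: a loss-module instance and a functional loss are both callables that take `(input, target)`. The helper `compute_loss` is hypothetical and only stands in for wherever the framework actually invokes `loss_fn`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_loss(loss_fn, prediction, target):
    # Hypothetical stand-in for the framework's call site: any callable
    # accepting (input, target) and returning a loss tensor fits the contract.
    return loss_fn(prediction, target)


torch.manual_seed(0)
logits = torch.randn(4, 3)           # batch of 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])  # one class index per sample

# Class-based loss: instantiate the class, then call the instance.
class_based = compute_loss(nn.CrossEntropyLoss(), logits, target)

# Plain function: pass the function object itself without calling it.
# nll_loss expects log-probabilities, so apply log_softmax first.
function_based = compute_loss(F.nll_loss, F.log_softmax(logits, dim=1), target)

print(class_based.item(), function_based.item())
```

With the default mean reduction, `CrossEntropyLoss` equals `nll_loss` applied to `log_softmax` outputs, so both calls yield the same value here.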
If you use one of PyTorch's built-in loss classes from the `torch.nn` module, you can point to the class directly via the `loss_fn._target_` key:
```yaml
loss_fn:
  _target_: torch.nn.CrossEntropyLoss
```
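If you want to use a plain Python function as a loss function, you have to go through Hydra's `hydra.utils.get_method` utility. Pointing `_target_` directly at the function would make Hydra call it with no arguments when the configuration is instantiated; `get_method` instead imports the function given by `path` and returns the function object itself: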
```yaml
loss_fn:
  _target_: hydra.utils.get_method
  path: torch.nn.functional.nll_loss
```