MeLOn
Loading...
Searching...
No Matches
kernel.h
Go to the documentation of this file.
1/***********************************************************************************
2* Copyright (c) 2020 Process Systems Engineering (AVT.SVT), RWTH Aachen University
3*
4* This program and the accompanying materials are made available under the
5* terms of the Eclipse Public License 2.0 which is available at
6* http://www.eclipse.org/legal/epl-2.0.
7*
8* SPDX-License-Identifier: EPL-2.0
9*
10* @file kernel.h
11*
12* @brief File containing declaration of kernel classes
13*
14**********************************************************************************/
15
16#pragma once
17
#include <cmath>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
20
namespace melon {
	namespace kernel {

		/**
		* @class Kernel
		* @brief Abstract parent class for kernel implementations.
		*
		* @tparam T type of the entries of the first point x1
		* @tparam V type of the entries of the second point x2
		*/
		template <typename T, typename V>
		class Kernel {
		public:

			/** @brief Kernel return type, deduced as the type of (T + V). */
			using RET = decltype(std::declval<T>() + std::declval<V>());

			/** @brief Destructor. */
			virtual ~Kernel() = default;

			/**
			* @brief Function for evaluating the kernel for the points x1 and x2.
			*
			* @param[in] x1 first point
			* @param[in] x2 second point
			*
			* @return kernel value for the pair (x1, x2)
			*/
			virtual RET evaluate_kernel(std::vector<T> x1, std::vector<V> x2) = 0;
		};


		/**
		* @class StationaryKernel
		* @brief Abstract parent class for kernels that are evaluated through a distance between the two points.
		*/
		template <typename T, typename V>
		class StationaryKernel : public Kernel<T, V> {
		public:

			using typename Kernel<T, V>::RET;

			// Keep the point-based overload visible; declaring evaluate_kernel(RET) below would otherwise hide it.
			using Kernel<T, V>::evaluate_kernel;

			/**
			* @brief Calculates the quadratic distance sum_i (x1[i] - x2[i])^2 between two points x1 and x2.
			*
			* The sum runs over the dimensions of x1. If x2 has more entries they are ignored;
			* if it has fewer, x2.at() throws std::out_of_range.
			*
			* @param[in] x1 first point
			* @param[in] x2 second point
			*
			* @return quadratic distance between x1 and x2
			*/
			virtual RET _quadratic_distance(std::vector<T> x1, std::vector<V> x2) {
				// Makes std::pow available for arithmetic types while still letting
				// argument-dependent lookup pick up pow overloads of custom numeric types.
				using std::pow;

				RET distance = 0;
				for (std::size_t i = 0; i < x1.size(); ++i) { // i: dimension index of x1 and x2
					distance += pow(x1.at(i) - x2.at(i), 2);
				}
				return distance;
			}

			/**
			* @brief Function for evaluating the kernel for a given distance.
			*
			* @param[in] distance distance between two points (as returned by calculate_distance)
			*
			* @return kernel value for the given distance
			*/
			virtual RET evaluate_kernel(RET distance) = 0; // works only for stationary kernels

			/**
			* @brief Function for calculating the distance used in the kernel (the type of distance can vary among kernel implementations).
			*
			* @param[in] x1 first point
			* @param[in] x2 second point
			*
			* @return distance between x1 and x2
			*/
			virtual RET calculate_distance(std::vector<T> x1, std::vector<V> x2) = 0;
		};


		/**
		* @class KernelCompositeAdd
		* @brief Composite kernel which on evaluation adds the evaluation results of its subkernels.
		*/
		template <typename T, typename V>
		class KernelCompositeAdd : public Kernel<T, V> {
		public:

			using typename Kernel<T, V>::RET;

			/**
			* @brief Function for adding another subkernel to the composite kernel.
			*
			* @param[in] kernel subkernel to be added
			*/
			void add(std::shared_ptr<Kernel<T, V>> kernel) { children.push_back(kernel); }

			/**
			* @brief Function for evaluating the kernel.
			*
			* @param[in] x1 first point
			* @param[in] x2 second point
			*
			* @return sum of the subkernel values (0 if no subkernel was added)
			*/
			RET evaluate_kernel(std::vector<T> x1, std::vector<V> x2) override {
				RET value = 0;
				for (const auto& kernel : children) { // by reference: avoids a refcount update per child
					value += kernel->evaluate_kernel(x1, x2);
				}
				return value;
			}

		private:
			std::vector<std::shared_ptr<Kernel<T, V>>> children;
		};


		/**
		* @class KernelCompositeMultiply
		* @brief Composite kernel which on evaluation multiplies the evaluation results of its subkernels.
		*/
		template <typename T, typename V>
		class KernelCompositeMultiply : public Kernel<T, V> {
		public:

			using typename Kernel<T, V>::RET;

			/**
			* @brief Function for adding another subkernel to the composite kernel.
			*
			* @param[in] kernel subkernel to be added
			*/
			void add(std::shared_ptr<Kernel<T, V>> kernel) { children.push_back(kernel); }

			/**
			* @brief Function for evaluating the kernel.
			*
			* @param[in] x1 first point
			* @param[in] x2 second point
			*
			* @return product of the subkernel values (1 if no subkernel was added)
			*/
			RET evaluate_kernel(std::vector<T> x1, std::vector<V> x2) override {
				RET value = 1;
				for (const auto& kernel : children) {
					value *= kernel->evaluate_kernel(x1, x2);
				}
				return value;
			}

		private:
			std::vector<std::shared_ptr<Kernel<T, V>>> children;
		};


		/**
		* @class KernelConstant
		* @brief Kernel which always returns a constant value.
		*/
		template <typename T, typename V>
		class KernelConstant : public Kernel<T, V> {
		public:

			using typename Kernel<T, V>::RET;

			/** @brief Constructor. Initializes the kernel's return value to 1. */
			KernelConstant() : _f(1) {}

			/**
			* @brief Constructor.
			*
			* @param[in] f constant value returned by the kernel
			*/
			KernelConstant(const T f) : _f(f) {}

			/**
			* @brief Constructor.
			*
			* Only enabled when V differs from T; for T == V it would redeclare the constructor above.
			*
			* @param[in] f constant value returned by the kernel
			*/
			template <typename U = V, typename std::enable_if<!std::is_same<U, T>::value, int>::type = 0>
			KernelConstant(const V f) : _f(f) {}

			/**
			* @brief Function for evaluating the kernel.
			*
			* @param[in] x1 first point (unused)
			* @param[in] x2 second point (unused)
			*
			* @return the constant value set at construction
			*/
			RET evaluate_kernel(std::vector<T> x1, std::vector<V> x2) override {
				return _f;
			}

		private:
			const RET _f;
		};


		/**
		* @class KernelRBF
		* @brief Implementation of the Radial Basis Function kernel k(x1, x2) = exp(-gamma * sum_i (x1[i] - x2[i])^2).
		*/
		template <typename T, typename V>
		class KernelRBF : public StationaryKernel<T, V> {
		public:

			using typename Kernel<T, V>::RET;

			/**
			* @brief Constructor.
			*
			* @param[in] gamma factor multiplying the quadratic distance in the exponent
			*/
			KernelRBF(const double gamma) : _gamma(gamma) {}

			/**
			* @brief Function for evaluating the kernel.
			*
			* @param[in] x1 first point
			* @param[in] x2 second point
			*
			* @return exp(-gamma * quadratic distance between x1 and x2)
			*/
			RET evaluate_kernel(std::vector<T> x1, std::vector<V> x2) override {
				// The points are taken by value and not used afterwards, so hand them on without copying again.
				return evaluate_kernel(calculate_distance(std::move(x1), std::move(x2)));
			}

			/**
			* @brief Function for calculating the distance used in the kernel (quadratic distance for the RBF kernel).
			*
			* @param[in] x1 first point
			* @param[in] x2 second point
			*
			* @return quadratic distance between x1 and x2
			*/
			RET calculate_distance(std::vector<T> x1, std::vector<V> x2) override {
				return this->_quadratic_distance(std::move(x1), std::move(x2));
			}

			/**
			* @brief Function for evaluating the kernel for a given distance.
			*
			* @param[in] distance quadratic distance between two points
			*
			* @return exp(-gamma * distance)
			*/
			RET evaluate_kernel(RET distance) override {
				// std::exp for arithmetic types; ADL still finds exp overloads of custom numeric types.
				using std::exp;
				return exp(-_gamma * distance);
			}

		private:
			const double _gamma;
		};

	} // namespace kernel

} // namespace melon
Composite kernel which on evaluation adds the evaluation results of its subkernels.
Definition kernel.h:107
void add(std::shared_ptr< Kernel< T, V > > kernel)
Function for adding another subkernel to the composite kernel.
Definition kernel.h:117
std::vector< std::shared_ptr< Kernel< T, V > > > children
Definition kernel.h:137
RET evaluate_kernel(std::vector< T > x1, std::vector< V > x2)
Function for evaluating the kernel.
Definition kernel.h:128
Composite kernel which on evaluation multiplies the evaluation results of its subkernels.
Definition kernel.h:145
RET evaluate_kernel(std::vector< T > x1, std::vector< V > x2)
Function for evaluating the kernel.
Definition kernel.h:166
void add(std::shared_ptr< Kernel< T, V > > kernel)
Function for adding another subkernel to the composite kernel.
Definition kernel.h:155
std::vector< std::shared_ptr< Kernel< T, V > > > children
Definition kernel.h:175
Kernel which always returns a constant value.
Definition kernel.h:183
const RET _f
Definition kernel.h:221
RET evaluate_kernel(std::vector< T > x1, std::vector< V > x2)
Function for evaluating the kernel.
Definition kernel.h:216
KernelConstant(const T f)
Constructor.
Definition kernel.h:198
KernelConstant()
Constructor. Initializes the kernel's return value to 1.
Definition kernel.h:191
KernelConstant(const V f)
Constructor.
Definition kernel.h:205
Abstract parent class for kernel implementations.
Definition kernel.h:29
virtual ~Kernel()=default
Destructor.
virtual RET evaluate_kernel(std::vector< T > x1, std::vector< V > x2)=0
Function for evaluating the kernel for the points x1 and x2.
decltype(std::declval< T >()+std::declval< V >()) RET
Definition kernel.h:33
Implementation of Radial Basis Function kernel.
Definition kernel.h:230
RET calculate_distance(std::vector< T > x1, std::vector< V > x2) override
Function for calculating the distance used in the kernel (type of distance used can vary among kernel...
Definition kernel.h:267
const double _gamma
Definition kernel.h:285
RET evaluate_kernel(RET distance) override
Function for evaluating the kernel for a given distance.
Definition kernel.h:278
KernelRBF(const double gamma)
Constructor.
Definition kernel.h:240
RET evaluate_kernel(std::vector< T > x1, std::vector< V > x2) override
Function for evaluating the kernel.
Definition kernel.h:252
Definition kernel.h:58
virtual RET _quadratic_distance(std::vector< T > x1, std::vector< V > x2)
Calculates the quadratic distance between two points x1 and x2.
Definition kernel.h:70
virtual RET calculate_distance(std::vector< T > x1, std::vector< V > x2)=0
Function for calculating the distance used in the kernel (type of distance used can vary among kernel...
virtual RET evaluate_kernel(RET distance)=0
Function for evaluating the kernel for a given distance.
Definition kernel.h:21