Source code for matchernet.fn
"""
fn.py
=====
This module contains the function handler classes required by the BundleNet
architecture. These handlers are meant to be overridden when Chainer/TensorFlow is used to implement arbitrary parametric functions.
"""
import autograd.numpy as np
from autograd import jacobian
class Fn(object):
    """Abstract base class for the numerical functions that BundleNet uses.

    Subclasses override :meth:`value`. At construction time the Jacobian of
    ``value`` with respect to its first argument ``x`` is built with
    autograd's ``jacobian``.

    Attributes:
        A: The parameters passed to the constructor (see :meth:`get_params`).
        dx: Callable computing the Jacobian of :meth:`value` w.r.t. ``x``.
        x: Backward-compatible alias of ``dx`` (the original attribute name).
    """

    def __init__(self, A):
        self.A = A
        # ``self.value`` is a bound method resolved on the instance, so the
        # Jacobian is taken of the subclass's override, and argument 0 is ``x``.
        # The other function classes in this module name this derivative
        # ``dx``; use the same name here so all handlers expose ``.dx``.
        self.dx = jacobian(self.value, 0)
        # Keep the original attribute name working for existing callers.
        self.x = self.dx

    def get_params(self):
        """Return the parameters ``A`` given to the constructor."""
        return self.A

    def value(self, x):
        """Numerically calculates f(x).

        x should be a numpy array of shape (dim_in, 1);
        outputs a numpy array of shape (dim_out, 1).

        This base implementation is a placeholder that returns 0;
        subclasses are expected to override it.
        """
        return 0
class LinearFn(Fn):
    """Linear map ``y = np.dot(A, x)`` together with its Jacobian.

    After construction, ``self.dx`` holds autograd's Jacobian of
    :meth:`value` with respect to ``x``.
    """

    def __init__(self, A):
        # The base-class constructor already records ``A`` as ``self.A``.
        super(LinearFn, self).__init__(A)
        jac_wrt_x = jacobian(self.value, 0)
        self.dx = jac_wrt_x

    def value(self, x):
        """Return the matrix product of the stored ``A`` with ``x``."""
        matrix = self.A
        return np.dot(matrix, x)
class LinearFnXU(object):
    """Linear function y = np.dot(A, x) + np.dot(B, u) and its derivatives.

    .. note:: ``np.dot(A, x)`` and ``np.dot(B, u)`` are added together, so
       their results must have compatible shapes. With column-vector inputs
       this means A and B must have the same number of rows (the output
       dimension); their column counts are the dimensions of ``x`` and
       ``u`` respectively and need not match.

    Attributes:
        A: Matrix applied to ``x``.
        B: Matrix applied to ``u``.
        dx: Callable computing the Jacobian of :meth:`value` w.r.t. ``x``.
        du: Callable computing the Jacobian of :meth:`value` w.r.t. ``u``.
    """

    def __init__(self, A, B):
        self.A = A
        self.B = B
        # ``self.value`` is a bound method, so argument 0 is ``x`` and
        # argument 1 is ``u``.
        self.dx = jacobian(self.value, 0)
        self.du = jacobian(self.value, 1)

    def value(self, x, u):
        """Return ``np.dot(A, x) + np.dot(B, u)``.

        NOTE(review): presumably ``x`` and ``u`` are column vectors, as in
        ``Fn.value`` — confirm against callers.
        """
        return np.dot(self.A, x) + np.dot(self.B, u)