Usage

Use it as an (almost) drop-in replacement for baryrat. Note that, instead of a callable barycentric rational, aaa returns the nodes z_j, values f_j and weights w_j of the barycentric approximation, together with its poles z_n.

from diffaaable import aaa
import jax
import jax.numpy as jnp

# the outputs shown below are in double precision
jax.config.update("jax_enable_x64", True)

### sample points ###
z_k_r = z_k_i = jnp.linspace(0, 3, 20)
Z_k_r, Z_k_i = jnp.meshgrid(z_k_r, z_k_i)
z_k = Z_k_r + 1j*Z_k_i

### function to be approximated ###
def f(x, a):
    return jnp.tan(a*x)
# bind a = pi by keyword (a positional argument would bind x instead)
f_pi = jax.tree_util.Partial(f, a=jnp.pi)

### approximate from pre-calculated function values ###
z_j, f_j, w_j, z_n = aaa(z_k, f_pi(z_k))

z_n
array([ 7.54559084-9.21200407e+00j, -4.38189497-9.35422547e+00j,
        7.76089436-2.40065767e+00j,  6.29895054-4.10633458e-01j,
        5.47871471+1.67060988e-02j, -4.7336582 -2.47428599e+00j,
        4.50004886-2.93464617e-05j,  3.5       +8.46654299e-10j,
       -3.2801036 -4.33698881e-01j,  2.5       -4.03534066e-16j,
       -2.47548883+2.00153822e-02j, -1.50006439-4.48153236e-05j,
        1.5       -6.47451255e-17j, -0.5       +1.64584750e-09j,
        0.5       +6.99525463e-16j])
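Because aaa returns the barycentric data rather than a callable, you can evaluate the approximant yourself. The sketch below assumes that z_j, f_j and w_j follow the standard barycentric form used by the AAA algorithm, r(z) = sum_j w_j f_j / (z - z_j) / sum_j w_j / (z - z_j); the helper barycentric_eval is hypothetical and not part of diffaaable's API.

```python
import jax.numpy as jnp


def barycentric_eval(z, z_j, f_j, w_j):
    """Evaluate r(z) = sum_j(w_j f_j / (z - z_j)) / sum_j(w_j / (z - z_j)).

    Hypothetical helper (not part of diffaaable); assumes the standard
    barycentric convention of the AAA algorithm.
    """
    z = jnp.asarray(z)
    # pairwise differences between evaluation points and support points
    diff = z[..., None] - z_j
    at_node = diff == 0
    # replace zero differences so the division below stays finite
    c = w_j / jnp.where(at_node, 1.0, diff)
    r = jnp.sum(c * f_j, axis=-1) / jnp.sum(c, axis=-1)
    # at a support point the barycentric form interpolates: r(z_j) = f_j
    node_val = jnp.sum(jnp.where(at_node, f_j, 0.0), axis=-1)
    return jnp.where(jnp.any(at_node, axis=-1), node_val, r)


# toy check: nodes 0 and 1 with weights (1, -1) reproduce the line f(z) = z + 1
z_j = jnp.array([0.0, 1.0])
f_j = jnp.array([1.0, 2.0])
w_j = jnp.array([1.0, -1.0])
print(barycentric_eval(jnp.array([0.5, 2.0, 0.0, 1.0]), z_j, f_j, w_j))
```

With the arrays returned by aaa above, barycentric_eval(z, z_j, f_j, w_j) can then be compared against jnp.tan(jnp.pi*z) away from the poles z_n. Since the helper uses only jax.numpy operations, it can also be differentiated with jax.grad away from the support points.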

Gradients

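diffaaable is differentiable with JAX, so aaa can be composed freely with other JAX functionality such as jax.grad. The example below differentiates, with respect to the parameter a, the (shifted) real part of the pole with the smallest positive real part.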

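def loss(a):
    # sample tan(a*x) on the grid for the current parameter a
    f_k = f(z_k, a)

    z_j, f_j, w_j, z_n = aaa(z_k, f_k)

    # among poles with positive real part, pick the one with the smallest real part
    selected_poles = z_n[z_n.real>1e-2]
    relevant_pole = selected_poles[jnp.argmin(selected_poles.real)]
    # jax.grad requires a real scalar output
    return jnp.real(relevant_pole - 2)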

g = jax.grad(loss)
g(jnp.pi/2)
Array(-7.25252718, dtype=float64, weak_type=True)
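One simple way to sanity-check a gradient such as g(jnp.pi/2) is to compare it with a central finite difference, (loss(a + h) - loss(a - h)) / (2h). A minimal sketch, with a hypothetical helper central_difference, checked here on a toy function rather than on loss:

```python
import jax
import jax.numpy as jnp


def central_difference(fun, x, h=1e-3):
    """Approximate d fun / dx at a scalar x by (fun(x + h) - fun(x - h)) / (2h)."""
    return (fun(x + h) - fun(x - h)) / (2 * h)


# toy scalar function standing in for `loss`
def toy(a):
    return a * jnp.sin(a)


a0 = 1.0
print(jax.grad(toy)(a0), central_difference(toy, a0))
```

Applied to the loss above, central_difference(loss, jnp.pi/2) can be compared with g(jnp.pi/2). The comparison is only meaningful if h is small enough that the same pole is selected on both sides and the AAA fit keeps the same support points; otherwise loss is not smooth across the step.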