Usage

Use aaa as an (almost) drop-in replacement for the aaa function in baryrat. The difference is in the return value: instead of a callable barycentric rational, you get the nodes (z_j), values (f_j) and weights (w_j) of the barycentric approximation, together with its poles (z_n).

from diffaaable import aaa
import jax
import jax.numpy as jnp

### sample points ###
z_k_r = z_k_i = jnp.linspace(0, 3, 20)
Z_k_r, Z_k_i = jnp.meshgrid(z_k_r, z_k_i)
z_k = Z_k_r + 1j*Z_k_i

### function to be approximated ###
def f(x, a):
    return jnp.tan(a*x)
f_pi = jax.tree_util.Partial(f, jnp.pi)

### approximate from pre-calculated function values at the sample points ###
z_j, f_j, w_j, z_n = aaa(z_k, f_pi(z_k))

z_n
array([ 7.54560591-9.21209430e+00j, -4.38199026-9.35421717e+00j,
        7.76091437-2.40067271e+00j,  6.29895508-4.10638898e-01j,
        5.47871396+1.67056246e-02j, -4.73367679-2.47426507e+00j,
        4.50004886-2.93450413e-05j,  3.5       +8.46691784e-10j,
       -3.28011054-4.33694069e-01j,  2.5       -9.97797541e-17j,
       -2.47548937+2.00143333e-02j, -1.50006439-4.48103191e-05j,
        1.5       +1.25893302e-16j, -0.5       +1.64574177e-09j,
        0.5       +4.86626531e-16j])
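
Because aaa returns the raw nodes, values and weights rather than a callable, the approximant has to be evaluated by hand when it is needed. The following is a minimal sketch, assuming the returned z_j, f_j and w_j plug into the standard barycentric formula r(z) = sum_j(w_j f_j / (z - z_j)) / sum_j(w_j / (z - z_j)). The helper bary_eval is hypothetical and not part of diffaaable, and the toy nodes, values and weights are made up for illustration.

```python
import jax.numpy as jnp

def bary_eval(z, z_j, f_j, w_j):
    """Evaluate r(z) = sum_j(w_j f_j / (z - z_j)) / sum_j(w_j / (z - z_j))."""
    z = jnp.asarray(z)
    # Cauchy matrix 1 / (z - z_j); the trailing axis runs over the nodes.
    # Evaluating exactly at a node z_j divides by zero and is not handled here.
    C = 1.0 / (z[..., None] - z_j)
    return (C @ (w_j * f_j)) / (C @ w_j)

# Toy data (an assumption for illustration, not output of aaa):
# two nodes with weights -1 and 1 reproduce the straight line 2*z.
z_j = jnp.array([0.0, 1.0])
f_j = jnp.array([0.0, 2.0])
w_j = jnp.array([-1.0, 1.0])
r_half = bary_eval(0.5, z_j, f_j, w_j)
```

With the real output of aaa, the same call would be bary_eval(z, z_j, f_j, w_j) for any array z of evaluation points that avoids the nodes.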

Gradients

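diffaaable is differentiable with JAX, so you can compose it with other JAX functionality and take gradients through it. The example below differentiates the position of a selected pole with respect to the parameter a.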

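def loss(a):
    # sample f at the points z_k for the current parameter a
    f_k = f(z_k, a)

    z_j, f_j, w_j, z_n = aaa(z_k, f_k)

    # keep only poles whose real part exceeds 1e-2
    selected_poles = z_n[z_n.real>1e-2]
    # of those, pick the pole with the smallest real part
    relevant_pole = selected_poles[jnp.argmin(selected_poles.real)]
    # real part of that pole's offset from 2
    return jnp.real(relevant_pole - 2)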

g = jax.grad(loss)
g(jnp.pi/2)
Array(-7.25252718, dtype=float64, weak_type=True)
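
The boolean-mask indexing used to select poles in loss produces an array whose length depends on the data, which JAX cannot trace under jax.jit. The following is a sketch of a fixed-shape alternative for the selection step only, assuming you want to jit it. The helper select_pole is hypothetical and not part of diffaaable, the poles array is made up for illustration, and nothing here implies that aaa itself can be jitted.

```python
import jax
import jax.numpy as jnp

def select_pole(z_n, threshold=1e-2):
    """Return the pole with the smallest real part above `threshold`."""
    # Replace excluded real parts by +inf instead of dropping them,
    # so the array shape stays static.
    # If no pole passes the threshold, argmin falls back to index 0.
    masked_real = jnp.where(z_n.real > threshold, z_n.real, jnp.inf)
    return z_n[jnp.argmin(masked_real)]

# Made-up poles for illustration, not output of aaa.
poles = jnp.array([-0.5 + 0.0j, 1.5 + 0.1j, 0.5 - 0.2j, 2.5 + 0.0j])
picked = jax.jit(select_pole)(poles)
```

The gradient still flows through the selected entry of z_n, since the index from argmin is treated as a constant during differentiation.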