Updated: 2023-03-05 Sun 19:01

JAX Study

1. For loop

JAX provides a structured for loop through jax.lax.fori_loop. Its semantics match the following Python code:

def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

The value val must keep a fixed shape and dtype across all iterations. Note that val can also be a nested tuple/list/dict container, as long as its structure is fixed and its leaves are arrays of fixed shape and dtype. There is usually no need to wrap body_fun in jit(), since fori_loop() already traces and compiles body_fun.
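
As a minimal sketch of the container case, the loop below carries a dict with a fixed set of keys. The names body, init, "total" and "squares" are hypothetical and chosen only for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical body function: the carry is a dict whose keys, shapes and
# dtypes stay the same on every iteration.
def body(i, carry):
    return {
        "total": carry["total"] + i,                   # running sum of i
        "squares": carry["squares"].at[i].set(i * i),  # store i**2 at index i
    }

init = {
    "total": jnp.array(0.0),   # float32 scalar
    "squares": jnp.zeros(5),   # float32 vector of length 5
}
out = lax.fori_loop(0, 5, body, init)
```

The returned out has the same dict structure as init, so individual entries can be read back by key.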

A quick way to use this is to have a step function that implements a single iteration of the intended loop and then use that as the argument for body_fun. For example:

import jax.numpy as jnp
from jax import lax

def forloop_example():
    result = jnp.zeros(10)
    def step(i, val):
        val = val.at[i].set(i)
        return val

    result = lax.fori_loop(0, len(result), step, result)
    return result

forloop_example()
Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)

If we instead want to update multiple values in the for loop, we can carry them as a tuple. For example:

def forloop_example_multi():
    res1 = jnp.zeros(10)
    res2 = jnp.zeros(10)
    def step(i, val):
        res1, res2 = val
        res1 = res1.at[i].set(i)
        res2 = res2.at[i].set(len(res2)-i)
        return (res1, res2)

    res1, res2 = lax.fori_loop(0, len(res1), step, (res1, res2))
    return res1, res2

forloop_example_multi()
(Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32),
 Array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.], dtype=float32))

Note that all carried arrays are updated over the same iteration range, since a single fori_loop() call has only one (lower, upper) range.
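
If one array should only be updated over part of that range, one possible workaround is to mask the update on the loop index. The sketch below is an assumption about how this could be done, not part of the original study; the function name masked_update_example and the cutoff of 5 are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

def masked_update_example():
    a = jnp.zeros(10)
    b = jnp.zeros(10)

    def step(i, val):
        a, b = val
        # a is updated on every iteration
        a = a.at[i].set(i)
        # b is only changed while i < 5; otherwise the old value is written back
        b = b.at[i].set(jnp.where(i < 5, 1.0, b[i]))
        return (a, b)

    return lax.fori_loop(0, len(a), step, (a, b))

a, b = masked_update_example()
```

Both arrays still pass through all 10 iterations; the mask only decides whether the stored value changes.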

2. Gradients

To calculate gradients of scalar-valued functions in JAX, we use jax.grad(), or jax.value_and_grad() to get the function value as well. For example:

from jax import grad, value_and_grad
def f(x):
    return x**2

value_and_grad(f)(3.0)
(Array(9., dtype=float32, weak_type=True), Array(6., dtype=float32, weak_type=True))

As we can see, value_and_grad() returns both the value and the gradient of the function at a point x. The gradient is only defined for functions with a scalar output, so passing a vector to an elementwise function such as f raises an error; the desired elementwise result can instead be obtained with vmap().

value_and_grad(f)(jnp.arange(4.))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 1
----> 1 value_and_grad(f)(jnp.arange(4.))

    [... skipping hidden 2 frame]

File ~/.local/share/virtualenvs/diff-mpm-zyodPdJl/lib/python3.10/site-packages/jax/_src/api.py:1274, in _check_scalar(x)
   1272 if isinstance(aval, ShapedArray):
   1273   if aval.shape != ():
-> 1274     raise TypeError(msg(f"had shape: {aval.shape}"))
   1275 else:
   1276   raise TypeError(msg(f"had abstract value {aval}"))

TypeError: Gradient only defined for scalar-output functions. Output had shape: (4,).

from jax import vmap
vmap(value_and_grad(f))(jnp.arange(4.))
(Array([0., 1., 4., 9.], dtype=float32), Array([0., 2., 4., 6.], dtype=float32))
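
For an elementwise function like f, another option is to reduce the output to a scalar before differentiating. The sketch below sums the outputs, so the gradient with respect to each input element equals the elementwise derivative. The helper name f_sum is hypothetical.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    return x**2

# Hypothetical helper: summing makes the output a scalar, so grad() applies.
def f_sum(x):
    return jnp.sum(f(x))

g = grad(f_sum)(jnp.arange(4.))  # d/dx_j sum_k x_k**2 = 2 * x_j
```

This only matches the vmap() result because each output element depends on a single input element.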

In short, vmap() can be used to vectorize the gradient calculation of scalar-output functions over a batch of inputs. However, this does not help in calculating the gradients of functions that have a vector output.

def fvec(x):
    return jnp.array([x**2, x**3])

value_and_grad(fvec)(2.0)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[11], line 4
      1 def fvec(x):
      2     return jnp.array([x**2, x**3])
----> 4 value_and_grad(fvec)(2.0)

    [... skipping hidden 2 frame]

File ~/.local/share/virtualenvs/diff-mpm-zyodPdJl/lib/python3.10/site-packages/jax/_src/api.py:1274, in _check_scalar(x)
   1272 if isinstance(aval, ShapedArray):
   1273   if aval.shape != ():
-> 1274     raise TypeError(msg(f"had shape: {aval.shape}"))
   1275 else:
   1276   raise TypeError(msg(f"had abstract value {aval}"))

TypeError: Gradient only defined for scalar-output functions. Output had shape: (2,).

To calculate these gradients, one way is to use jacobian().

from jax import jacobian

jacobian(fvec)(2.)
Array([ 4., 12.], dtype=float32, weak_type=True)
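
As a side note, jax.jacobian is an alias of jax.jacrev (reverse mode), and jax.jacfwd computes the same Jacobian in forward mode. The short sketch below only checks that the three agree for fvec; the variable names are hypothetical.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev, jacobian

def fvec(x):
    return jnp.array([x**2, x**3])

j_default = jacobian(fvec)(2.)  # same as jacrev
j_rev = jacrev(fvec)(2.)        # reverse-mode Jacobian
j_fwd = jacfwd(fvec)(2.)        # forward-mode Jacobian
```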

jacobian() can also accept vector inputs but it is important to understand the difference between using jacobian() with vector inputs and using vmap() with jacobian().

jacobian(fvec)(jnp.array([2., 3.]))
Array([[[ 4.,  0.],
        [ 0.,  6.]],

       [[12.,  0.],
        [ 0., 27.]]], dtype=float32)

vmap(jacobian(fvec))(jnp.array([2., 3.]))
Array([[ 4., 12.],
       [ 6., 27.]], dtype=float32)

As can be seen, when jacobian() is given a vector input, fvec is treated as a single function of n input elements, and the derivative of every output element with respect to every input element is computed (hence the zeros in the cross terms). If we instead want the derivative of each output element with respect to its own scalar input, evaluated independently for each input, we should use vmap(jacobian(f)).
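
The relation between the two results can be sketched as follows: the batched Jacobian holds the diagonal entries of the full Jacobian, with the axes swapped. The variable names full, batched and diag are hypothetical.

```python
import jax.numpy as jnp
from jax import jacobian, vmap

def fvec(x):
    return jnp.array([x**2, x**3])

x = jnp.array([2., 3.])
full = jacobian(fvec)(x)           # shape (2, 2, 2): full[k, j, m] = d out[k, j] / d x[m]
batched = vmap(jacobian(fvec))(x)  # shape (2, 2):    batched[j, k] = d out_k / d x_j

# Keep only the entries with j == m, then swap axes to match the vmap layout.
diag = jnp.diagonal(full, axis1=1, axis2=2).T
```

The off-diagonal entries of full are all zero here because each output element depends on exactly one input element.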

3. 1D Point Axial Vibration

3.1. Initialization

The computational domain for this problem is set up as:

L = 1                           # Domain size
nodes = jnp.array([0, L])       # Nodal coordinates
nnodes = len(nodes)             # Number of nodes
nelements = 1                   # Number of elements
nparticles = 1                  # Number of particles
el_length = L / nelements       # Element length

Material properties:

E = 4 * jnp.pi**2               # Young's modulus
rho = 1.                        # Density

Initial loading conditions:

v0 = 0.1                        # initial velocity
x_loc = 0.5                     # Location to get analytical solution

The material points in MPM keep track of position, mass, velocity, volume, momentum, stress and strain. Here, the single material point sits at the middle of the element, and its volume equals the entire length of the bar.

x_p = 0.5 * el_length           # position of material point
mass_p = 1.                     # Mass of material point
vol_p = el_length / nparticles  # Volume
vel_p = v0                      # Initial velocity
stress_p = 0.                   # Stress
strain_p = 0.                   # Strain
momentum_p = mass_p * vel_p

3.1.1. Shape functions

For the shape function, we use a single two-node element with linear shape functions.

def shape_fn(x):
    return 1 - abs(x - nodes)/L

For this shape function, we can compute the derivative using jacobian(). The computed value can be checked against the analytical value \(B(x) = [-1/L, 1/L]\).

vmap(jacobian(shape_fn))(jnp.array([0.1, 0.8]))
Array([[-1.,  1.],
       [-1.,  1.]], dtype=float32)

As we see, we get the correct value of the derivatives for 2 different values of x. Hence, we can define the derivative of the shape function as

def shape_fn_grad(x):
    return vmap(jacobian(shape_fn))(x)
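
As a quick sanity check, the sketch below evaluates the shape function and its derivative at a single point and compares them with the expected linear values. The evaluation point x = 0.25 is an arbitrary choice for illustration.

```python
import jax.numpy as jnp
from jax import jacobian

L = 1                      # Domain size
nodes = jnp.array([0, L])  # Nodal coordinates

def shape_fn(x):
    return 1 - abs(x - nodes) / L

x = 0.25                     # arbitrary point inside the element
N = shape_fn(x)              # shape function values at the two nodes
dN = jacobian(shape_fn)(x)   # derivatives with respect to x
```

For a point inside the element, the two shape function values should sum to one and the derivatives should equal [-1/L, 1/L].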

3.2. Solution for a single step of time

During a single timestep, we perform the following actions:

  1. Compute the nodal mass
  2. Compute nodal momentum
  3. Apply boundary conditions
  4. Compute external forces
  5. Compute internal forces
  6. Compute total unbalanced nodal forces
  7. Update nodal momentum
  8. Update particle position and velocities
  9. Update particle momentum
  10. Update nodal velocity
  11. Compute stress and strain

During this entire process, we want to store the evolution of velocity, position and energies with time.

t0, T, dt = 0, 10, 0.01
time = jnp.arange(t0, T, dt)
velocity = jnp.zeros(time.shape)
position = jnp.zeros(time.shape)
strain_energy = jnp.zeros(time.shape)
kinetic_energy = jnp.zeros(time.shape)

position = position.at[0].set(x_p)
velocity = velocity.at[0].set(vel_p)

Now, we will write a function that will perform one timestep update.

def step(i, kwargs):
    # Shape function and its derivative
    N = shape_fn(kwargs["position"][i-1])
    dN = jacobian(shape_fn)(kwargs["position"][i-1])

    # Nodal mass and momentum
    mass_n = N * kwargs["mass_p"]
    momentum_n = N * kwargs["momentum_p"]

    # Boundary conditions
    momentum_n = momentum_n.at[0].set(0)

    # External forces
    f_ext = jnp.array([0., 0.])

    # Internal forces
    f_int = -dN * kwargs["vol_p"] * kwargs["stress_p"]

    # Total nodal forces
    f_total = f_ext + f_int
    f_total = f_total.at[0].set(0)

    momentum_n += f_total * kwargs["dt"]

    # Update particle position and velocity
    vel_p = jnp.sum(kwargs["dt"] * N * f_total / mass_n) + kwargs["velocity"][i-1]
    pos_p = jnp.sum(kwargs["dt"] * N * momentum_n / mass_n) + kwargs["position"][i-1]

    # Update particle momentum
    kwargs["momentum_p"] = kwargs["mass_p"] * vel_p

    # Map nodal velocity
    vel_n = kwargs["mass_p"] * vel_p * jnp.divide(N, mass_n)
    vel_n = vel_n.at[0].set(0)

    # Strain rate at the particle
    strain_rate_p = jnp.dot(dN, vel_n)
    # Strain increment
    dstrain_p = strain_rate_p * kwargs["dt"]
    # Strain
    kwargs["strain_p"] += dstrain_p
    kwargs["stress_p"] += kwargs["E"] * dstrain_p

    kwargs["velocity"] = kwargs["velocity"].at[i].set(vel_p)
    kwargs["position"] = kwargs["position"].at[i].set(pos_p)

    # Compute and store strain energy
    se = 0.5 * kwargs["stress_p"] * kwargs["strain_p"] * kwargs["vol_p"]
    kwargs["strain_energy"] = kwargs["strain_energy"].at[i].set(se)

    # Compute and store kinetic energy
    ke = 0.5 * vel_p**2 * kwargs["mass_p"]
    kwargs["kinetic_energy"] = kwargs["kinetic_energy"].at[i].set(ke)
    return kwargs
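
In equation form, the update performed by step can be summarized as below. This is a sketch read off the code above, not an independent derivation. The notation is introduced here: index \(I\) runs over the two nodes, subscript \(p\) denotes the material point, \(V_p\) is vol_p, and \(\Delta t\) is dt. The shape functions are evaluated at the particle position from the previous step, the external force is zero, and the node-0 entries of \((mv)_I\), \(f_I\) and \(v_I\) are set to zero as the boundary condition.

```latex
\begin{align}
m_I &= N_I(x_p)\, m_p, \qquad (mv)_I = N_I(x_p)\, (mv)_p \\
f_I &= f_I^{\text{ext}} + f_I^{\text{int}}, \qquad
f_I^{\text{int}} = -\frac{dN_I}{dx}(x_p)\, V_p\, \sigma_p \\
(mv)_I &\leftarrow (mv)_I + \Delta t\, f_I \\
v_p &\leftarrow v_p + \Delta t \sum_I \frac{N_I(x_p)\, f_I}{m_I} \\
x_p &\leftarrow x_p + \Delta t \sum_I \frac{N_I(x_p)\, (mv)_I}{m_I} \\
(mv)_p &= m_p\, v_p, \qquad v_I = \frac{N_I(x_p)\, m_p\, v_p}{m_I} \\
\dot{\varepsilon}_p &= \sum_I \frac{dN_I}{dx}(x_p)\, v_I \\
\varepsilon_p &\leftarrow \varepsilon_p + \Delta t\, \dot{\varepsilon}_p, \qquad
\sigma_p \leftarrow \sigma_p + E\, \Delta t\, \dot{\varepsilon}_p
\end{align}
```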

We can now use this function in the jax.lax.fori_loop() function to run the iterations.

kwargs = {
    "mass_p": mass_p,
    "vol_p": vol_p,
    "stress_p": stress_p,
    "strain_p": strain_p,
    "momentum_p": momentum_p,
    "velocity": velocity,
    "position": position,
    "strain_energy": strain_energy,
    "kinetic_energy": kinetic_energy,
    "E": E,
    "dt": dt,
}
result = lax.fori_loop(1, len(time), step, kwargs)

We can compare this result with the analytical solution which can be evaluated as follows:

def analytical_vibration(E, rho, v0, x_loc, duration, dt, L):
    omega = 1 / L * jnp.sqrt(E / rho)
    t = jnp.arange(0, duration, dt)
    v = v0 * jnp.cos(omega * t)
    x = x_loc * jnp.exp(v0 / (L * omega) * jnp.sin(omega * t))
    return x, v


xa, va = analytical_vibration(E, rho, v0, x_loc, T, dt, L)
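
As a small sanity check on the analytical solution, the sketch below evaluates the same formulas at a few hand-picked times, using the parameter values from this section. With E = 4π², rho = 1 and L = 1, the angular frequency is 2π, so the period is 1. The variables period and t are hypothetical names introduced only for this check.

```python
import jax.numpy as jnp

E = 4 * jnp.pi**2   # Young's modulus
rho = 1.0           # Density
L = 1               # Domain size
v0 = 0.1            # Initial velocity
x_loc = 0.5         # Location of the analytical solution

omega = 1 / L * jnp.sqrt(E / rho)   # angular frequency, 2*pi for these values
period = 2 * jnp.pi / omega         # one full oscillation

# Evaluate at the start, a quarter period and a half period.
t = period * jnp.array([0.0, 0.25, 0.5])
v = v0 * jnp.cos(omega * t)
x = x_loc * jnp.exp(v0 / (L * omega) * jnp.sin(omega * t))
```

The velocity starts at v0, passes through zero at a quarter period and reaches -v0 at half a period, while the position starts at x_loc.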
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(16, 6))
ax[0].plot(time, va, "r", label="analytical")
ax[0].plot(time, result["velocity"], "ob", markersize=2, label="mpm")
ax[0].legend()
ax[0].set_title("Velocity")

ax[1].plot(time, xa, "r", label="analytical")
ax[1].plot(time, result["position"], "ob", markersize=2, label="mpm")
ax[1].legend()
ax[1].set_title("Position")
fig.savefig("../assets/mpm_plots.png")
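[Figure: mpm_plots.png. Left panel: Velocity; right panel: Position. Each panel compares the analytical solution with the MPM result.]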