
FEAX

A fully differentiable finite element engine built on JAX, designed for gradient-based optimization and machine learning on PDE simulations.

⚙️

JAX Transformations

All solvers work seamlessly with jax.jit, jax.grad, and jax.vmap, as well as arbitrary compositions such as jit(grad(...)).
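A minimal sketch of what this composability means in practice, using a toy one-degree-of-freedom "solver" in plain JAX (the real FEAX solver composes the same way because it is a pure function; solve and objective here are illustrative stand-ins, not FEAX API):

```python
import jax
import jax.numpy as jnp

# Toy "solver": for a 1-DOF spring k*u = f, the solution is u = f/k.
# Any pure JAX function composes with jit/grad/vmap the same way.
def solve(k, f):
    return f / k

# Compliance-style objective of the solution: J = 0.5 * u * f = 0.5 * f**2 / k.
def objective(k, f):
    u = solve(k, f)
    return 0.5 * u * f

# Arbitrary compositions: jit(grad(...)), then vmap over a batch of stiffnesses.
dJ_dk = jax.jit(jax.grad(objective))
batched = jax.vmap(dJ_dk, in_axes=(0, None))

ks = jnp.array([1.0, 2.0, 4.0])
print(batched(ks, 1.0))  # dJ/dk = -0.5 * f**2 / k**2 → [-0.5, -0.125, -0.03125]
```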

🚀

GPU Direct Solver

Native cuDSS integration for sparse direct solves on GPU, with automatic matrix property detection (General / Symmetric / SPD).
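One common way to implement such property detection, sketched in NumPy for clarity (detect_matrix_type is a hypothetical helper for illustration; FEAX's actual detection logic is not shown here and may differ):

```python
import numpy as np

def detect_matrix_type(A, tol=1e-12):
    """Classify a matrix as 'general', 'symmetric', or 'spd' --
    the three structure categories a sparse direct solver can exploit."""
    if not np.allclose(A, A.T, atol=tol):
        return "general"
    try:
        np.linalg.cholesky(A)  # succeeds iff symmetric positive definite
        return "spd"
    except np.linalg.LinAlgError:
        return "symmetric"     # symmetric but indefinite

K = np.array([[4.0, 1.0], [1.0, 3.0]])  # SPD stiffness block
print(detect_matrix_type(K))             # → spd
```

An SPD classification lets the solver use a Cholesky-type factorization, which is roughly half the cost and memory of a general LU factorization.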

End-to-End Differentiability

Gradients flow through assembly, boundary conditions, linear/nonlinear solvers, and post-processing — enabling topology optimization, inverse problems, and physics-informed learning.
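The core principle behind gradients flowing through a solver can be shown on a tiny dense system in plain JAX (a sketch, not FEAX code: displacement and compliance are illustrative names, and a 2x2 dense solve stands in for the assembled sparse system):

```python
import jax
import jax.numpy as jnp

# Differentiate through a linear solve K(theta) u = f.
# jnp.linalg.solve is differentiable, so gradients of any function of u
# flow back to the parameter theta that enters the stiffness matrix.
def displacement(theta, f):
    K = jnp.array([[theta, -1.0],
                   [-1.0,  2.0]])
    return jnp.linalg.solve(K, f)

def compliance(theta):
    f = jnp.array([1.0, 0.0])
    u = displacement(theta, f)
    return f @ u  # compliance f^T u, a standard objective in topology optimization

# Matches the adjoint formula dC/dtheta = -u^T (dK/dtheta) u = -u[0]**2.
print(jax.grad(compliance)(3.0))  # → -0.16
```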

Quick Example

3D cantilever beam under traction — solve and compute gradients:

import feax as fe
import jax
import jax.numpy as np

mesh = fe.mesh.box_mesh((100, 10, 10), mesh_size=2)
E, nu = 70e3, 0.3

class LinearElasticity(fe.problem.Problem):
    def get_tensor_map(self):
        def stress(u_grad, *args):
            mu = E / (2. * (1. + nu))
            lmbda = E * nu / ((1 + nu) * (1 - 2 * nu))
            eps = 0.5 * (u_grad + u_grad.T)
            return lmbda * np.trace(eps) * np.eye(self.dim) + 2 * mu * eps
        return stress

    def get_surface_maps(self):
        def surface_map(u, x, traction_mag):
            return np.array([0., 0., traction_mag])
        return [surface_map]

left = lambda point: np.isclose(point[0], 0., atol=1e-5)
right = lambda point: np.isclose(point[0], 100., atol=1e-5)

problem = LinearElasticity(mesh, vec=3, dim=3, location_fns=[right])

bc_config = fe.DCboundary.DirichletBCConfig([
    fe.DCboundary.DirichletBCSpec(location=left, component="all", value=0.)
])
bc = bc_config.create_bc(problem)

traction = fe.InternalVars.create_uniform_surface_var(problem, 1e-3)
internal_vars = fe.InternalVars(volume_vars=(), surface_vars=[(traction,)])

solver = fe.create_solver(problem, bc,
                          solver_options=fe.DirectSolverOptions(), iter_num=1,
                          internal_vars=internal_vars)
initial = fe.zero_like_initial_guess(problem, bc)

# Solve
sol = solver(internal_vars, initial)

# Differentiate through the entire solve
grad_fn = jax.grad(lambda iv: np.sum(solver(iv, initial) ** 2))
grads = grad_fn(internal_vars)

Getting Started

New to FEAX? Start here:

  1. Installation - Install FEAX and its dependencies
  2. Basic Tutorials - Learn the fundamentals with hands-on examples

API Reference

For detailed API documentation, see the API Reference section.

License

FEAX is licensed under the GNU General Public License v3.0.