FEAX
A fully differentiable finite element engine built on JAX, designed for gradient-based optimization and machine learning with PDE-based simulations.
JAX Transformations
All solvers are compatible with jax.jit, jax.grad, and jax.vmap, as well as arbitrary compositions of them such as jit(grad(...)).
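The composition claim can be illustrated with a minimal sketch in plain JAX, with no FEAX calls. `toy_solve` and `compliance` are hypothetical stand-ins: a two-spring chain whose stiffness depends on a parameter `k`, solved with `jnp.linalg.solve`. Its compliance is exactly 2/k, so the gradient is -2/k², which gives an easy check on `jit(grad(...))` and on `vmap` over a batch of parameters.

```python
import jax
import jax.numpy as jnp


def toy_solve(k):
    """Hypothetical stand-in for a solver: two springs of stiffness k in series."""
    # Node 0 is tied to a wall by one spring and to node 1 by the other.
    K = k * jnp.array([[2.0, -1.0],
                       [-1.0, 1.0]])
    f = jnp.array([0.0, 1.0])  # unit load on the free end (node 1)
    return jnp.linalg.solve(K, f)


def compliance(k):
    """Post-processing: compliance f . u, which equals 2 / k for this system."""
    f = jnp.array([0.0, 1.0])
    return jnp.dot(f, toy_solve(k))


# Compose transformations: differentiate, JIT-compile the derivative, then batch it.
dcompliance = jax.jit(jax.grad(compliance))
batched_dcompliance = jax.vmap(dcompliance)

ks = jnp.array([1.0, 2.0, 4.0])
print(compliance(2.0))           # approximately 1.0
print(dcompliance(2.0))          # approximately -0.5, i.e. -2 / k**2
print(batched_dcompliance(ks))   # approximately [-2.0, -0.5, -0.125]
```

Because each transformation returns an ordinary function, the order of composition is up to the caller; here `grad` is applied first so that the compiled and batched functions return derivatives directly.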
GPU Direct Solver
Native integration with NVIDIA cuDSS for sparse direct solves on the GPU, with automatic detection of the matrix type (general, symmetric, or SPD, i.e. symmetric positive definite).
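The idea behind matrix-type detection can be shown with a small dense sketch. `classify_matrix` below is hypothetical and is not taken from FEAX; it assumes a square matrix small enough to handle densely, tests symmetry against a relative tolerance, and then checks positive definiteness through the eigenvalues returned by `jnp.linalg.eigvalsh`.

```python
import jax.numpy as jnp


def classify_matrix(A, tol=None):
    """Hypothetical helper: label a square matrix 'SPD', 'Symmetric' or 'General'.

    Dense illustration only; it is not FEAX's actual detection logic.
    """
    A = jnp.asarray(A, dtype=jnp.float32)
    scale = float(jnp.max(jnp.abs(A)))
    if tol is None:
        # Relative tolerance of 100 machine epsilons (a judgment call).
        tol = 100.0 * float(jnp.finfo(A.dtype).eps) * scale
    # Symmetry: every entry of A - A^T must be negligible.
    if float(jnp.max(jnp.abs(A - A.T))) > tol:
        return "General"
    # A symmetric matrix is positive definite iff all its eigenvalues are positive.
    if float(jnp.min(jnp.linalg.eigvalsh(A))) > tol:
        return "SPD"
    return "Symmetric"


print(classify_matrix([[4.0, 1.0], [1.0, 3.0]]))  # SPD
print(classify_matrix([[1.0, 2.0], [2.0, 1.0]]))  # Symmetric (eigenvalues 3 and -1)
print(classify_matrix([[1.0, 2.0], [0.0, 1.0]]))  # General
```

The label matters because an SPD matrix can be factorized with Cholesky (LLᵀ) rather than a general LU; for dense matrices this takes about half the floating-point operations.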
End-to-End Differentiability
Gradients flow through assembly, boundary conditions, linear and nonlinear solvers, and post-processing, which enables topology optimization, inverse problems, and physics-informed learning.
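The same chain (assembly, Dirichlet boundary condition, linear solve, post-processing) can be reproduced at toy scale in plain JAX. The sketch below is not FEAX code, and all names and constants in it are illustrative assumptions: it assembles a 1D bar of `N_ELEM` two-node elements with per-element Young's moduli `E`, clamps node 0 by eliminating its row and column, solves for the displacements, and returns the compliance. Its gradient with respect to the element moduli is the kind of sensitivity that compliance-based topology optimization relies on. For elements in series the exact compliance is C = Σₑ P²h / (A·Eₑ), so dC/dEₑ = -P²h / (A·Eₑ²), which the code uses as a reference.

```python
import jax
import jax.numpy as jnp

# Illustrative problem data (assumptions for this sketch, not FEAX defaults).
N_ELEM = 4           # number of two-node bar elements
LENGTH = 1.0         # bar length
AREA = 1.0           # cross-sectional area
LOAD = 1.0           # axial point load at the free end
H = LENGTH / N_ELEM  # element length


def assemble(E):
    """Assemble the global stiffness matrix from per-element moduli E."""
    k_unit = jnp.array([[1.0, -1.0],
                        [-1.0, 1.0]])
    K = jnp.zeros((N_ELEM + 1, N_ELEM + 1))
    for e in range(N_ELEM):
        # Scatter-add the element matrix into rows/columns e and e + 1.
        K = K.at[e:e + 2, e:e + 2].add(E[e] * AREA / H * k_unit)
    return K


def compliance(E):
    """Assembly -> Dirichlet BC -> linear solve -> post-processing."""
    K = assemble(E)
    f = jnp.zeros(N_ELEM + 1).at[-1].set(LOAD)
    # Dirichlet BC u(0) = 0: drop the clamped node's row and column.
    K_free, f_free = K[1:, 1:], f[1:]
    u_free = jnp.linalg.solve(K_free, f_free)
    # Post-processing: compliance C = f . u
    return jnp.dot(f_free, u_free)


E = jnp.array([1.0, 2.0, 4.0, 8.0])
C = compliance(E)
dC_dE = jax.grad(compliance)(E)

# Closed-form reference for bar elements in series.
C_exact = jnp.sum(LOAD**2 * H / (AREA * E))
dC_exact = -LOAD**2 * H / (AREA * E**2)
print(C)                                         # approximately 0.46875
print(jnp.allclose(dC_dE, dC_exact, atol=1e-5))  # True
```

The Dirichlet condition is applied here by row and column elimination because it keeps the sketch short; other enforcement strategies are equally differentiable in JAX.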
Quick Example
A 3D cantilever beam, clamped at x = 0 and loaded by a uniform surface traction at x = 100. Solve it and compute gradients:
import feax as fe
import jax
import jax.numpy as np

# Box domain of size 100 x 10 x 10
mesh = fe.mesh.box_mesh((100, 10, 10), mesh_size=2)
# Young's modulus and Poisson's ratio
E, nu = 70e3, 0.3

class LinearElasticity(fe.problem.Problem):
    def get_tensor_map(self):
        def stress(u_grad, *args):
            # Lamé parameters
            mu = E / (2. * (1. + nu))
            lmbda = E * nu / ((1 + nu) * (1 - 2 * nu))
            # Small-strain tensor and isotropic linear elastic stress
            eps = 0.5 * (u_grad + u_grad.T)
            return lmbda * np.trace(eps) * np.eye(self.dim) + 2 * mu * eps
        return stress

    def get_surface_maps(self):
        # Traction in z whose magnitude comes from the surface internal variable
        def surface_map(u, x, traction_mag):
            return np.array([0., 0., traction_mag])
        return [surface_map]

# Clamped face (x = 0) and loaded face (x = 100)
left = lambda point: np.isclose(point[0], 0., atol=1e-5)
right = lambda point: np.isclose(point[0], 100., atol=1e-5)
problem = LinearElasticity(mesh, vec=3, dim=3, location_fns=[right])

# Fix all displacement components on the left face
bc_config = fe.DCboundary.DirichletBCConfig([
    fe.DCboundary.DirichletBCSpec(location=left, component="all", value=0.)
])
bc = bc_config.create_bc(problem)

# Uniform traction magnitude of 1e-3 on the right face
traction = fe.InternalVars.create_uniform_surface_var(problem, 1e-3)
internal_vars = fe.InternalVars(volume_vars=(), surface_vars=[(traction,)])

solver = fe.create_solver(problem, bc,
                          solver_options=fe.DirectSolverOptions(), iter_num=1,
                          internal_vars=internal_vars)
initial = fe.zero_like_initial_guess(problem, bc)

# Solve
sol = solver(internal_vars, initial)

# Differentiate through the entire solve; grads has the same structure as internal_vars
grad_fn = jax.grad(lambda iv: np.sum(solver(iv, initial) ** 2))
grads = grad_fn(internal_vars)
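The tensor map above is isotropic Hooke's law, σ = λ tr(ε) I + 2μ ε with ε = ½(∇u + ∇uᵀ), μ = E / (2(1 + ν)) and λ = Eν / ((1 + ν)(1 - 2ν)). With E = 70e3 and ν = 0.3 this gives μ ≈ 2.69e4 and λ ≈ 4.04e4. A standalone check that needs no FEAX is to feed in the strain of a uniaxial stress state σ_xx = s, namely ε_xx = s/E and ε_yy = ε_zz = -νs/E, and confirm that only σ_xx = s survives. The `stress` function below copies the quick example's formula but takes `dim` as an argument instead of reading `self.dim`.

```python
import jax.numpy as jnp

E, nu = 70e3, 0.3  # same material constants as the quick example


def stress(u_grad, dim=3):
    """Isotropic linear elasticity, copied from the quick example's tensor map."""
    mu = E / (2. * (1. + nu))
    lmbda = E * nu / ((1 + nu) * (1 - 2 * nu))
    eps = 0.5 * (u_grad + u_grad.T)  # small-strain tensor
    return lmbda * jnp.trace(eps) * jnp.eye(dim) + 2 * mu * eps


# Strain of a uniaxial stress state sigma_xx = s (lateral contraction via nu).
s = 100.0
u_grad = jnp.diag(jnp.array([s / E, -nu * s / E, -nu * s / E]))
sigma = stress(u_grad)
print(sigma)  # approximately diag(100, 0, 0)
```

The lateral stresses cancel because λ(1 - 2ν) = 2μν = Eν / (1 + ν), which is a quick way to spot a mistyped Lamé parameter.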
Getting Started
New to FEAX? Start here:
- Installation: install FEAX and its dependencies
- Basic Tutorials: learn the fundamentals through hands-on examples
API Reference
For detailed API documentation, see the API Reference section.
License
FEAX is licensed under the GNU General Public License v3.0.