Lattice Structure Homogenization
This tutorial demonstrates computational homogenization of lattice structures using FEAX's flat toolkit. We compute the effective stiffness tensor of a BCC (Body-Centered Cubic) lattice using periodic boundary conditions and graph-based structure definition.
Overview
Computational homogenization determines effective material properties of periodic microstructures by:
- Defining a representative unit cell with periodic boundary conditions
- Applying prescribed macroscopic strain states
- Computing volume-averaged stress response
- Assembling the homogenized stiffness tensor
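The homogenized stiffness tensor $\mathbb{C}^{\text{hom}}$ relates the volume-averaged stress and strain, where $\langle \cdot \rangle = \frac{1}{|V|}\int_V (\cdot)\,dV$ is the average over the unit cell $V$:

$$\langle \boldsymbol{\sigma} \rangle = \mathbb{C}^{\text{hom}} : \langle \boldsymbol{\varepsilon} \rangle$$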
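In Voigt notation the average stress-strain relation becomes a 6×6 matrix-vector product. The NumPy sketch below uses the stiffness matrix of an isotropic solid (E = 210 GPa, ν = 0.3) as a stand-in for $\mathbb{C}^{\text{hom}}$. The helper `isotropic_voigt_stiffness` is hypothetical, and the component order (11, 22, 33, 23, 13, 12) with engineering shear strains is an assumption, not necessarily the convention FEAX uses.

```python
import numpy as np

def isotropic_voigt_stiffness(E, nu):
    """6x6 isotropic stiffness in Voigt order (11, 22, 33, 23, 13, 12),
    assuming engineering shear strains (gamma_ij = 2 * eps_ij)."""
    lam = E * nu / ((1 + nu) * (1 - 2 * nu))
    mu = E / (2 * (1 + nu))
    C = np.zeros((6, 6))
    C[:3, :3] = lam                              # lambda couples all normal components
    C[np.arange(3), np.arange(3)] += 2 * mu      # normal diagonal: lambda + 2 mu
    C[np.arange(3, 6), np.arange(3, 6)] = mu     # shear diagonal: mu
    return C

C = isotropic_voigt_stiffness(210e9, 0.3)

# Average strain: 0.1% stretch along x, all other components zero
eps_avg = np.array([1e-3, 0.0, 0.0, 0.0, 0.0, 0.0])
sigma_avg = C @ eps_avg        # <sigma> = C_hom : <eps> in Voigt form
print(sigma_avg[:3] / 1e6)     # normal stresses in MPa: about 282.7, 121.2, 121.2
```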
The feax.flat Toolkit
FEAX provides the flat module for periodic structures and homogenization:
```python
import feax.flat as flat
```
Key modules:
- `flat.unitcell` - Unit cell base class with boundary detection
- `flat.graph` - Graph-based lattice structure generation
- `flat.pbc` - Periodic boundary condition utilities
- `flat.solver` - Specialized homogenization solvers
- `flat.utils` - Visualization tools (stiffness sphere, etc.)
Problem Setup
Material Properties
```python
import feax as fe
import feax.flat as flat
import jax.numpy as np

E_base = 210e9  # Pa (steel)
nu = 0.3
mesh_size = 0.1
```
Linear Elasticity Problem
```python
class LinearElasticity(fe.problem.Problem):
    def get_tensor_map(self):
        def stress(u_grad, E, nu_val):
            # Lamé parameters from the (per-cell) Young's modulus and Poisson's ratio
            mu = E / (2.0 * (1.0 + nu_val))
            lmbda = E * nu_val / ((1 + nu_val) * (1 - 2 * nu_val))
            # Small-strain tensor and isotropic Hooke's law
            epsilon = 0.5 * (u_grad + u_grad.T)
            return lmbda * np.trace(epsilon) * np.eye(self.dim) + 2 * mu * epsilon
        return stress
```
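A quick way to check the constitutive law in `stress()` is a uniaxial-stress strain state, ε = ε₀ diag(1, -ν, -ν), which should give σ₁₁ = Eε₀ and zero lateral stress. The sketch copies the same formula into plain NumPy (the name `lame_stress` is hypothetical) so it runs without building a mesh or a FEAX problem.

```python
import numpy as np

def lame_stress(u_grad, E, nu_val, dim=3):
    """Same isotropic law as LinearElasticity.get_tensor_map(), outside FEAX."""
    mu = E / (2.0 * (1.0 + nu_val))
    lmbda = E * nu_val / ((1 + nu_val) * (1 - 2 * nu_val))
    epsilon = 0.5 * (u_grad + u_grad.T)
    return lmbda * np.trace(epsilon) * np.eye(dim) + 2 * mu * epsilon

E, nu, eps0 = 210e9, 0.3, 1e-4
u_grad = np.diag([eps0, -nu * eps0, -nu * eps0])  # uniaxial-stress strain state
sigma = lame_stress(u_grad, E, nu)
print(sigma[0, 0] / (E * eps0))   # close to 1.0
print(sigma[1, 1], sigma[2, 2])   # close to 0.0 (up to round-off)
```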
Unit Cell Definition
Use flat.unitcell.UnitCell to define the computational domain:
```python
class BCCUnitCell(flat.unitcell.UnitCell):
    """BCC lattice unit cell."""

    def mesh_build(self, mesh_size):
        return fe.mesh.box_mesh(size=1.0, mesh_size=mesh_size, element_type='HEX8')

# Create unit cell
unitcell = BCCUnitCell(mesh_size=mesh_size)
mesh = unitcell.mesh
```
Key features of UnitCell:
- Automatically computes the bounding box (`unitcell.lb`, `unitcell.ub`)
- Provides boundary detection methods (`is_corner`, `is_edge`, `is_face`)
- Generates mapping functions for periodic pairings
- Compatible with `flat.pbc.periodic_bc_3D()`
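One simple way to picture the boundary detection is to count how many coordinates of a node sit on the bounding box: one for a face, two for an edge, three for a corner. The NumPy sketch below illustrates that idea only; `classify_point` and its tolerance are hypothetical, and FEAX's `is_corner`/`is_edge`/`is_face` may be implemented differently.

```python
import numpy as np

def classify_point(x, lb, ub, tol=1e-8):
    """Classify a point of the box [lb, ub] by how many coordinates touch the boundary.

    0 -> interior, 1 -> face, 2 -> edge, 3 -> corner (in 3D).
    """
    x, lb, ub = map(np.asarray, (x, lb, ub))
    on_bound = (np.isclose(x, lb, rtol=0.0, atol=tol)
                | np.isclose(x, ub, rtol=0.0, atol=tol))
    n = int(on_bound.sum())
    return {0: "interior", 1: "face", 2: "edge", 3: "corner"}[n]

lb, ub = np.zeros(3), np.ones(3)
for p in ([0.5, 0.5, 0.5], [0.0, 0.3, 0.7], [1.0, 0.0, 0.4], [1.0, 1.0, 0.0]):
    print(p, classify_point(p, lb, ub))
```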
Graph-Based Lattice Structure
Use flat.graph to define strut-based lattice structures:
```python
# Define BCC lattice: 8 corners + 1 center node
corners = np.array([[i, j, k] for i in [0, 1] for j in [0, 1] for k in [0, 1]], dtype=np.float32)
center = np.array([[0.5, 0.5, 0.5]], dtype=np.float32)
nodes = np.vstack([corners, center])

# BCC edges: all corners connect to the center (node index 8)
edges = np.array([[i, 8] for i in range(8)])

# Create the problem first (the density field is evaluated on its mesh)
problem = LinearElasticity(mesh=mesh, vec=3, dim=3, ele_type='HEX8', location_fns=[])

# Create lattice density field using graph
lattice_func = flat.graph.create_lattice_function(nodes, edges, radius=0.1)
rho = flat.graph.create_lattice_density_field(problem, lattice_func,
                                              density_solid=1.0, density_void=0.01)
```
How flat.graph works:
- `create_lattice_function(nodes, edges, radius)` - Creates a function that evaluates whether a point is near any strut
- `create_lattice_density_field(problem, lattice_func, ...)` - Evaluates the lattice at element centroids and returns an element-based density array of shape `(num_elements,)`
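The strut test can be sketched as a point-to-segment distance check: a point counts as solid if its distance to any strut axis is at most `radius`. This plain NumPy sketch is written under that assumption (`lattice_indicator` is a hypothetical name); the FEAX version is JAX-based and may treat the solid/void transition differently.

```python
import numpy as np

def lattice_indicator(points, nodes, edges, radius):
    """Return 1.0 where a point lies within `radius` of any strut, else 0.0.

    points: (N, 3), nodes: (M, 3), edges: (K, 2) integer node indices.
    """
    a = nodes[edges[:, 0]]                       # (K, 3) strut start points
    b = nodes[edges[:, 1]]                       # (K, 3) strut end points
    ab = b - a
    ap = points[:, None, :] - a[None, :, :]      # (N, K, 3)
    # Projection parameter along each strut, clamped to the segment
    t = np.clip(np.sum(ap * ab, axis=-1) / np.sum(ab * ab, axis=-1), 0.0, 1.0)
    closest = a[None] + t[..., None] * ab[None]                    # (N, K, 3)
    dist = np.linalg.norm(points[:, None, :] - closest, axis=-1)   # (N, K)
    return (dist.min(axis=1) <= radius).astype(float)

# BCC graph from this tutorial
corners = np.array([[i, j, k] for i in [0, 1] for j in [0, 1] for k in [0, 1]], dtype=float)
nodes = np.vstack([corners, [[0.5, 0.5, 0.5]]])
edges = np.array([[i, 8] for i in range(8)])

pts = np.array([[0.5, 0.5, 0.5],     # center node: solid
                [0.25, 0.25, 0.25],  # on the strut from (0, 0, 0): solid
                [0.5, 0.5, 0.0]])    # face center, away from all struts: void
print(lattice_indicator(pts, nodes, edges, radius=0.1))  # [1. 1. 0.]
```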
Advantages:
- Clean node-edge representation
- Differentiable through JAX
- Efficient vectorized evaluation
- Works with arbitrary lattice topologies
Periodic Boundary Conditions
Use flat.pbc.periodic_bc_3D() for full 3D periodicity:
```python
pairings = flat.pbc.periodic_bc_3D(unitcell, vec=3, dim=3)
P = flat.pbc.prolongation_matrix(pairings, mesh, vec=3)
```
What periodic_bc_3D() does:
- Creates pairings for all 3 face pairs (x, y, z directions)
- Creates pairings for all 12 edge pairs
- Creates pairings for all 7 corner pairs (origin excluded)
- Total: 25 geometric pairings × 3 components = 75 periodic constraints
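A face pairing can be pictured as matching every node on the upper face (for example x = 1) with the node on the opposite face that has the same remaining coordinates. The NumPy sketch below shows this for one face pair on a small structured grid; `pair_faces` is a hypothetical helper, a conforming (matching) mesh is assumed, and the edge and corner handling that `periodic_bc_3D()` also performs is omitted.

```python
import numpy as np

def pair_faces(points, axis, lb, ub, tol=1e-8):
    """Return (master, slave) node index pairs across the two faces normal to `axis`.

    Slave nodes lie on x[axis] == ub; each is matched to the node on x[axis] == lb
    with the same coordinates in the remaining directions.
    """
    other = [d for d in range(points.shape[1]) if d != axis]
    masters = np.where(np.abs(points[:, axis] - lb) < tol)[0]
    slaves = np.where(np.abs(points[:, axis] - ub) < tol)[0]
    pairs = []
    for s in slaves:
        # Largest coordinate mismatch between slave s and every master candidate
        mismatch = np.abs(points[masters][:, other] - points[s, other]).max(axis=1)
        best = int(np.argmin(mismatch))
        if mismatch[best] < tol:
            pairs.append((int(masters[best]), int(s)))
    return pairs

# 3 x 3 x 3 grid of nodes on the unit cube
g = np.linspace(0.0, 1.0, 3)
points = np.array([[x, y, z] for x in g for y in g for z in g])
pairs = pair_faces(points, axis=0, lb=0.0, ub=1.0)
print(len(pairs))   # 9 nodes on the x = 1 face
print(pairs[:3])
```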
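The prolongation matrix $\mathbf{P}$ maps the reduced (independent) DOFs to the full DOF vector:

$$\mathbf{u}_{\text{full}} = \mathbf{P}\,\mathbf{u}_{\text{reduced}}$$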
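For intuition, consider a toy 1D bar with 5 nodes where the last node is periodic with the first. There are 4 independent DOFs, so P is 5×4 and copies the first reduced DOF to both ends; a common way to use P is to reduce the stiffness to PᵀKP. This is a hand-built dense NumPy example, not the matrix `flat.pbc.prolongation_matrix` returns.

```python
import numpy as np

n_full = 5          # nodes 0..4 on a 1D bar; node 4 is periodic with node 0
n_red = n_full - 1  # independent DOFs

# P maps reduced DOFs to full DOFs: u_full = P @ u_red
P = np.zeros((n_full, n_red))
P[np.arange(n_red), np.arange(n_red)] = 1.0  # nodes 0..3 keep their own DOF
P[4, 0] = 1.0                                # node 4 copies node 0

# Stiffness of 4 linear bar elements with unit stiffness
K = np.zeros((n_full, n_full))
for e in range(n_full - 1):
    K[e:e + 2, e:e + 2] += np.array([[1.0, -1.0], [-1.0, 1.0]])

K_red = P.T @ K @ P                # reduced (periodic) stiffness: a closed ring
u_red = np.array([0.0, 0.1, 0.3, 0.2])
u_full = P @ u_red
print(u_full)                      # last entry equals the first
print(K_red)
```

`K_red` is singular here because a rigid translation is still unconstrained; the constant vector lies in its null space.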
Internal Variables with Density
Use element-based variables for density-dependent properties:
```python
bc_config = fe.DCboundary.DirichletBCConfig([])
bc = bc_config.create_bc(problem)

# Density-based Young's modulus (rho is already per-cell from create_lattice_density_field)
E_field = E_base * rho
nu_field = fe.internal_vars.InternalVars.create_cell_var(problem, nu)
internal_vars = fe.internal_vars.InternalVars(volume_vars=(E_field, nu_field), surface_vars=())
```
Why cell-based variables?
- The density field from `flat.graph` is element-based
- More efficient than quad-point-based variables for homogenization
- Natural for topology optimization
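To make the size difference concrete, the NumPy sketch below compares a per-cell field with an equivalent quadrature-point field and computes the solid volume fraction. It assumes a 10×10×10 HEX8 mesh (unit cube with `mesh_size = 0.1`), 2×2×2 Gauss quadrature (8 points per element) and equal element volumes; the patterned `rho` is only a placeholder for the real density field.

```python
import numpy as np

num_cells = 10 ** 3   # assumed: unit cube meshed with mesh_size = 0.1 (10 HEX8 per side)
num_quads = 8         # 2x2x2 Gauss points per HEX8 element

# Placeholder per-cell density (1.0 = solid, 0.01 = void), standing in for `rho`
rho = np.where(np.arange(num_cells) % 4 == 0, 1.0, 0.01)
E_field = 210e9 * rho                                     # per-cell modulus, shape (1000,)

# Equivalent quadrature-point field: each value repeated at the element's 8 Gauss points
E_quad = np.repeat(E_field[:, None], num_quads, axis=1)   # shape (1000, 8)

print(E_field.shape, E_quad.shape)
# Solid volume fraction, assuming all elements have equal volume
print(np.mean(rho == 1.0))
```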
!!! note
    `E_field` is computed directly as `E_base * rho` because `rho` is already a per-cell array returned by `create_lattice_density_field`. The `create_cell_var` helper is only used for uniform scalar values such as `nu`.
Homogenization Solver
Use `flat.solver.create_homogenization_solver()` to compute $\mathbb{C}^{\text{hom}}$:
```python
solver_options = fe.IterativeSolverOptions(
    solver="cg", tol=1e-10, atol=1e-10, maxiter=10000, verbose=True
)

compute_C_hom = flat.solver.create_homogenization_solver(
    problem, bc, P, mesh, solver_options=solver_options, dim=3
)

result = compute_C_hom(internal_vars)
C_hom = result.C_hom
```
How it works:
For 3D, the solver:
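- Applies 6 unit macroscopic strain cases $\bar{\boldsymbol{\varepsilon}}^{(k)}$, $k = 1, \dots, 6$ (one per Voigt component)
- Solves each case with periodic BCs, splitting the displacement into an affine part and a periodic fluctuation: $\mathbf{u} = \bar{\boldsymbol{\varepsilon}}^{(k)} \mathbf{x} + \tilde{\mathbf{u}}$
- Computes the volume-averaged stress $\langle \boldsymbol{\sigma}^{(k)} \rangle = \frac{1}{|V|} \int_V \boldsymbol{\sigma}^{(k)} \, dV$
- Assembles the stiffness matrix column by column, $C^{\text{hom}}_{ik} = \langle \sigma^{(k)}_i \rangle$ (6×6 in Voigt notation)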
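The assembly step can be sketched on its own: each unit Voigt strain produces a volume-averaged stress that fills one column of the stiffness matrix. To stay self-contained, the NumPy code below skips the periodic fluctuation solve and lets every element respond affinely (σₑ = Cₑ ε̄), which yields only the Voigt upper bound rather than the homogenized stiffness the FEAX solver computes. The four-element "mesh", its moduli and volumes are made-up placeholders.

```python
import numpy as np

def isotropic_voigt_stiffness(E, nu):
    # 6x6 isotropic stiffness, Voigt order (11, 22, 33, 23, 13, 12), engineering shear
    lam = E * nu / ((1 + nu) * (1 - 2 * nu))
    mu = E / (2 * (1 + nu))
    C = np.zeros((6, 6))
    C[:3, :3] = lam
    C[np.arange(3), np.arange(3)] += 2 * mu
    C[np.arange(3, 6), np.arange(3, 6)] = mu
    return C

# Toy "mesh": 4 elements with per-element modulus (solid / void) and volume
E_e = 210e9 * np.array([1.0, 0.01, 1.0, 0.01])
V_e = np.array([0.25, 0.25, 0.3, 0.2])
C_e = np.stack([isotropic_voigt_stiffness(E, 0.3) for E in E_e])   # (4, 6, 6)

C_hom = np.zeros((6, 6))
for k in range(6):
    eps_bar = np.eye(6)[k]                    # unit macroscopic strain case k
    sigma_e = C_e @ eps_bar                   # (4, 6) affine-only element stresses
    sigma_avg = (V_e[:, None] * sigma_e).sum(axis=0) / V_e.sum()  # volume average
    C_hom[:, k] = sigma_avg                   # column k of the stiffness matrix

print(C_hom[0, 0] / 1e9, C_hom[3, 3] / 1e9)  # GPa
```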
Key properties:
- Fully differentiable with respect to `internal_vars` (useful for topology optimization)
- Uses the affine displacement method for efficiency
- Automatically handles periodic constraints via the prolongation matrix $\mathbf{P}$
JIT Compilation Benchmark
The homogenization solver can be compiled with `jax.jit`. The first compiled call pays a one-time compilation cost; subsequent calls reuse the compiled version and run faster:
```python
import jax
import time

# Without JIT
t0 = time.perf_counter()
result = compute_C_hom(internal_vars)
jax.block_until_ready(result)
t_no_jit = time.perf_counter() - t0

# With JIT (1st call = compile + run)
compute_C_hom_jit = jax.jit(compute_C_hom)
t0 = time.perf_counter()
result = compute_C_hom_jit(internal_vars)
jax.block_until_ready(result)
t_jit_compile = time.perf_counter() - t0

# With JIT (2nd call = cached compiled version)
t0 = time.perf_counter()
result = compute_C_hom_jit(internal_vars)
jax.block_until_ready(result)
t_jit_cached = time.perf_counter() - t0

print(f"No JIT:            {t_no_jit:.3f} s")
print(f"JIT (compile+run): {t_jit_compile:.3f} s")
print(f"JIT (cached):      {t_jit_cached:.3f} s")

C_hom = result.C_hom
```
!!! tip
    After the first JIT-compiled call (which includes compilation overhead), subsequent calls use the cached compiled version and run significantly faster.
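Single wall-clock measurements are noisy. A somewhat more robust pattern is to warm up once (absorbing any JIT compilation) and report the median of several cached calls. The helper below is a plain-Python sketch; `benchmark` is a hypothetical name, and for JAX functions you would pass `sync=jax.block_until_ready` so that asynchronous dispatch does not hide the runtime.

```python
import statistics
import time

def benchmark(fn, *args, repeats=5, sync=lambda x: x):
    """Return (first_call_seconds, median_cached_seconds) for fn(*args).

    The first call includes any compilation; `sync` should block until the
    result is ready (e.g. jax.block_until_ready for JAX outputs).
    """
    t0 = time.perf_counter()
    sync(fn(*args))
    first = time.perf_counter() - t0

    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        sync(fn(*args))
        times.append(time.perf_counter() - t0)
    return first, statistics.median(times)

# Usage with a plain Python function; with JAX it could be, for example:
#   benchmark(compute_C_hom_jit, internal_vars, sync=jax.block_until_ready)
first, cached = benchmark(sum, range(100_000))
print(f"first call: {first:.4f} s, cached median: {cached:.4f} s")
```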
Extract Engineering Constants
For materials with cubic symmetry, the Voigt stiffness matrix has three independent constants $C_{11}$, $C_{12}$ and $C_{44}$:

```python
C11 = C_hom[0, 0]
C12 = C_hom[0, 1]
C44 = C_hom[3, 3]

# Effective engineering constants (assuming cubic symmetry; E_eff is along the cube axes)
E_eff = (C11 - C12) * (C11 + 2*C12) / (C11 + C12)
nu_eff = C12 / (C11 + C12)
G_eff = C44

print(f"Effective Young's modulus: {E_eff/1e9:.2f} GPa")
print(f"Effective Poisson's ratio: {nu_eff:.3f}")
print(f"Effective shear modulus: {G_eff/1e9:.2f} GPa")
print(f"Relative stiffness (E_eff/E_base): {E_eff/E_base:.3f}")
```
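A quick sanity check of these formulas is to feed in the stiffness matrix of an isotropic solid, which should return the original E, ν and G = E / (2(1 + ν)). The NumPy sketch below does this; the Voigt order (11, 22, 33, 23, 13, 12) with engineering shear strains is an assumption about how `C_hom` is laid out.

```python
import numpy as np

E, nu = 210e9, 0.3
lam = E * nu / ((1 + nu) * (1 - 2 * nu))
mu = E / (2 * (1 + nu))

# Isotropic stiffness in Voigt notation (engineering shear strains)
C_hom = np.zeros((6, 6))
C_hom[:3, :3] = lam
C_hom[np.arange(3), np.arange(3)] += 2 * mu
C_hom[np.arange(3, 6), np.arange(3, 6)] = mu

C11, C12, C44 = C_hom[0, 0], C_hom[0, 1], C_hom[3, 3]
E_eff = (C11 - C12) * (C11 + 2 * C12) / (C11 + C12)
nu_eff = C12 / (C11 + C12)
G_eff = C44
print(E_eff / 1e9, nu_eff, G_eff / 1e9)   # about 210 GPa, 0.3, 80.77 GPa
```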
Visualization
Save Lattice Structure
```python
import os

# __file__ is defined when this code is run as a script
output_dir = os.path.join(os.path.dirname(__file__), "data", "vtk")
os.makedirs(output_dir, exist_ok=True)

lattice_file = os.path.join(output_dir, "bcc_lattice_structure.vtu")
fe.utils.save_sol(
    mesh=mesh,
    sol_file=lattice_file,
    cell_infos=[("density", rho)]
)
```
Visualize Stiffness Sphere
Use flat.utils.visualize_stiffness_sphere() for directional stiffness:
```python
sphere_file = os.path.join(output_dir, "bcc_stiffness_sphere.vtk")
flat.utils.visualize_stiffness_sphere(
    C_hom,
    output_file=sphere_file,
)
```
The stiffness sphere shows the directional Young's modulus, i.e. the stiffness under uniaxial loading along each direction.
Interpretation:
- Sphere radius = directional stiffness
- Perfectly spherical = isotropic material
- Elongated = anisotropic (stiffer in certain directions)
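In the standard formulation, the directional Young's modulus is $E(\mathbf{d}) = 1 / (\mathbf{d}\otimes\mathbf{d} : \mathbb{S} : \mathbf{d}\otimes\mathbf{d})$ with compliance $\mathbb{S} = (\mathbb{C}^{\text{hom}})^{-1}$. The NumPy sketch below evaluates it from a 6×6 Voigt matrix, assuming the order (11, 22, 33, 23, 13, 12) with engineering shear strains; whether `visualize_stiffness_sphere` uses exactly this convention is an assumption. It also prints the Zener anisotropy ratio A = 2C₄₄/(C₁₁ - C₁₂), which equals 1 for an isotropic (spherical) response. The cubic constants are illustrative numbers, not results of this tutorial.

```python
import numpy as np

def directional_youngs_modulus(C, d):
    """E(d) = 1 / (n^T S n) for a 6x6 Voigt stiffness C (order 11, 22, 33, 23, 13, 12,
    engineering shear strains) and a loading direction d."""
    d = np.asarray(d, dtype=float)
    d = d / np.linalg.norm(d)
    # Voigt form of d (x) d, i.e. the stress state of unit uniaxial stress along d
    n = np.array([d[0]**2, d[1]**2, d[2]**2, d[1]*d[2], d[0]*d[2], d[0]*d[1]])
    S = np.linalg.inv(C)   # compliance
    return 1.0 / (n @ S @ n)

def cubic_stiffness(C11, C12, C44):
    C = np.zeros((6, 6))
    C[:3, :3] = C12
    C[np.arange(3), np.arange(3)] = C11
    C[np.arange(3, 6), np.arange(3, 6)] = C44
    return C

# Illustrative cubic constants in GPa
C = cubic_stiffness(200.0, 100.0, 100.0)
for d in ([1, 0, 0], [1, 1, 0], [1, 1, 1]):
    print(d, round(directional_youngs_modulus(C, d), 2))
print("Zener ratio:", 2 * C[3, 3] / (C[0, 0] - C[0, 1]))  # 1.0 would mean isotropic
```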
Complete Code
```python
import os
import time

import jax
import jax.numpy as np
import feax as fe
import feax.flat as flat

# Material properties
E_base = 210e9  # Pa (steel)
nu = 0.3
mesh_size = 0.1


class LinearElasticity(fe.problem.Problem):
    def get_tensor_map(self):
        def stress(u_grad, E, nu_val):
            mu = E / (2.0 * (1.0 + nu_val))
            lmbda = E * nu_val / ((1 + nu_val) * (1 - 2 * nu_val))
            epsilon = 0.5 * (u_grad + u_grad.T)
            return lmbda * np.trace(epsilon) * np.eye(self.dim) + 2 * mu * epsilon
        return stress


class BCCUnitCell(flat.unitcell.UnitCell):
    """BCC lattice unit cell."""

    def mesh_build(self, mesh_size):
        return fe.mesh.box_mesh(size=1.0, mesh_size=mesh_size, element_type='HEX8')


# Create unit cell
unitcell = BCCUnitCell(mesh_size=mesh_size)
mesh = unitcell.mesh

# Define BCC lattice structure: 8 corners connected to 1 center node
corners = np.array([[i, j, k] for i in [0, 1] for j in [0, 1] for k in [0, 1]], dtype=np.float32)
center = np.array([[0.5, 0.5, 0.5]], dtype=np.float32)
nodes = np.vstack([corners, center])
edges = np.array([[i, 8] for i in range(8)])

# Create problem and density field
problem = LinearElasticity(mesh=mesh, vec=3, dim=3, ele_type='HEX8', location_fns=[])
lattice_func = flat.graph.create_lattice_function(nodes, edges, radius=0.1)
rho = flat.graph.create_lattice_density_field(problem, lattice_func, density_solid=1.0, density_void=0.01)

# Periodic boundary conditions
pairings = flat.pbc.periodic_bc_3D(unitcell, vec=3, dim=3)
P = flat.pbc.prolongation_matrix(pairings, mesh, vec=3)

# Boundary conditions and internal variables
bc = fe.DCboundary.DirichletBCConfig([]).create_bc(problem)
E_field = E_base * rho  # rho is already per-cell from create_lattice_density_field
nu_field = fe.internal_vars.InternalVars.create_cell_var(problem, nu)
internal_vars = fe.internal_vars.InternalVars(volume_vars=(E_field, nu_field), surface_vars=())

# Homogenization
solver_options = fe.IterativeSolverOptions(solver="cg", tol=1e-10, atol=1e-10, maxiter=10000, verbose=True)
compute_C_hom = flat.solver.create_homogenization_solver(
    problem, bc, P, mesh, solver_options=solver_options, dim=3
)

# Benchmark: without JIT
t0 = time.perf_counter()
result = compute_C_hom(internal_vars)
jax.block_until_ready(result)
t_no_jit = time.perf_counter() - t0

# Benchmark: with JIT (1st call = compile + run)
compute_C_hom_jit = jax.jit(compute_C_hom)
t0 = time.perf_counter()
result = compute_C_hom_jit(internal_vars)
jax.block_until_ready(result)
t_jit_compile = time.perf_counter() - t0

# Benchmark: with JIT (2nd call = cached)
t0 = time.perf_counter()
result = compute_C_hom_jit(internal_vars)
jax.block_until_ready(result)
t_jit_cached = time.perf_counter() - t0
print(f"Timing: no JIT {t_no_jit:.3f} s, JIT compile+run {t_jit_compile:.3f} s, JIT cached {t_jit_cached:.3f} s")

C_hom = result.C_hom

# Extract properties (cubic symmetry assumed)
C11, C12, C44 = C_hom[0, 0], C_hom[0, 1], C_hom[3, 3]
E_eff = (C11 - C12) * (C11 + 2*C12) / (C11 + C12)
nu_eff = C12 / (C11 + C12)
G_eff = C44
print(f"Effective Young's modulus: {E_eff/1e9:.2f} GPa")
print(f"Effective Poisson's ratio: {nu_eff:.3f}")
print(f"Effective shear modulus: {G_eff/1e9:.2f} GPa")
print(f"Relative stiffness (E_eff/E_base): {E_eff/E_base:.3f}")

# Save results (__file__ is defined when run as a script)
output_dir = os.path.join(os.path.dirname(__file__), "data", "vtk")
os.makedirs(output_dir, exist_ok=True)
fe.utils.save_sol(mesh=mesh, sol_file=os.path.join(output_dir, "bcc_lattice_structure.vtu"),
                  cell_infos=[("density", rho)])
flat.utils.visualize_stiffness_sphere(C_hom, output_file=os.path.join(output_dir, "bcc_stiffness_sphere.vtk"))
```
Advanced Topics
Custom Lattice Topologies
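Any lattice topology can be defined as a node-edge graph. For example, an octet truss uses the 8 cube corners plus the 6 face centers as nodes, with struts from each face center to the 4 corners of its face and to the 4 neighboring face centers (36 struts in total):

```python
import itertools

import jax.numpy as np
import feax.flat as flat

# Octet truss lattice: 8 corners + 6 face centers (FCC node positions)
corners = [[i, j, k] for i in [0, 1] for j in [0, 1] for k in [0, 1]]
face_centers = [
    [0.0, 0.5, 0.5], [1.0, 0.5, 0.5],  # x = 0 and x = 1 faces
    [0.5, 0.0, 0.5], [0.5, 1.0, 0.5],  # y = 0 and y = 1 faces
    [0.5, 0.5, 0.0], [0.5, 0.5, 1.0],  # z = 0 and z = 1 faces
]
nodes = np.array(corners + face_centers, dtype=np.float32)

# Octet struts are the nearest-neighbor links of length sqrt(2)/2:
# face center -> its 4 face corners (24) and face center -> adjacent face centers (12)
edges = np.array([
    [a, b] for a, b in itertools.combinations(range(len(nodes)), 2)
    if abs(float(np.linalg.norm(nodes[a] - nodes[b])) - 0.5 ** 0.5) < 1e-5
])

lattice_func = flat.graph.create_lattice_function(nodes, edges, radius=0.1)
```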
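For a rough check of a density field, the solid volume fraction of a strut lattice can be estimated from the graph alone by summing cylinder volumes πr²L over all edges. This ignores overlaps at the nodes and any strut material that pokes outside the cell, so it is only an estimate; `strut_volume_fraction` is a hypothetical helper, shown here for the BCC graph used earlier in this tutorial.

```python
import numpy as np

def strut_volume_fraction(nodes, edges, radius, cell_volume=1.0):
    """Sum of strut cylinder volumes pi r^2 L divided by the cell volume.

    Overlaps at nodes and material outside the cell are not corrected for.
    """
    lengths = np.linalg.norm(nodes[edges[:, 1]] - nodes[edges[:, 0]], axis=1)
    return float(np.pi * radius**2 * lengths.sum() / cell_volume)

corners = np.array([[i, j, k] for i in [0, 1] for j in [0, 1] for k in [0, 1]], dtype=float)
nodes = np.vstack([corners, [[0.5, 0.5, 0.5]]])
edges = np.array([[i, 8] for i in range(8)])
print(f"BCC, r = 0.1: {strut_volume_fraction(nodes, edges, 0.1):.3f}")  # about 0.218
```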
2D Homogenization
For 2D problems (plane stress/strain):
```python
mesh = fe.mesh.rectangle_mesh(Nx=32, Ny=32, domain_x=1.0, domain_y=1.0)
unitcell = MyUnitCell2D()  # Implement mesh_build() for 2D

# 2D periodic BCs (only x, y directions)

compute_C_hom = flat.solver.create_homogenization_solver(
    problem, bc, P, mesh, solver_options=solver_options, dim=2
)
# Returns a 3×3 stiffness matrix in Voigt notation (ε11, ε22, γ12)
```
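When validating a 2D run, a fully solid cell should reproduce the isotropic in-plane stiffness in the (ε11, ε22, γ12) ordering. The NumPy sketch below builds both the plane-stress and plane-strain reference matrices; which one applies depends on the stress function used in the 2D problem, so the choice is an assumption left to the reader.

```python
import numpy as np

def plane_stress_stiffness(E, nu):
    """3x3 isotropic plane-stress stiffness, order (eps11, eps22, gamma12)."""
    return E / (1 - nu**2) * np.array([[1.0, nu, 0.0],
                                       [nu, 1.0, 0.0],
                                       [0.0, 0.0, (1 - nu) / 2]])

def plane_strain_stiffness(E, nu):
    """3x3 isotropic plane-strain stiffness, order (eps11, eps22, gamma12)."""
    c = E / ((1 + nu) * (1 - 2 * nu))
    return c * np.array([[1 - nu, nu, 0.0],
                         [nu, 1 - nu, 0.0],
                         [0.0, 0.0, (1 - 2 * nu) / 2]])

E, nu = 210e9, 0.3
print(plane_stress_stiffness(E, nu) / 1e9)   # GPa
print(plane_strain_stiffness(E, nu) / 1e9)   # GPa
```

Both matrices share the same shear term, the shear modulus G = E / (2(1 + ν)); they differ only in the normal block.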
Summary
Key concepts:
- `flat.unitcell.UnitCell` - Abstract base for unit cell definition
- `flat.graph` - Node-edge graph → density field
- `flat.pbc.periodic_bc_3D()` - Automatic 3D periodic constraints
- `flat.solver.create_homogenization_solver()` - Computes $\mathbb{C}^{\text{hom}}$
- `flat.utils.visualize_stiffness_sphere()` - Directional stiffness visualization
Workflow:
- Define a `UnitCell` subclass with `mesh_build()`
- Create the lattice structure using `flat.graph`
- Apply periodic BCs with `flat.pbc.periodic_bc_3D()`
- Compute the homogenized stiffness with `flat.solver`
- Visualize results with `flat.utils`
Further Reading
- Periodic Boundary Conditions - Detailed PBC tutorial
- `examples/advance/lattice_homogenization.py` - Complete working example
- API: flat.graph - Graph-based structure generation
- API: flat.solver - Homogenization solvers