Installation
This guide provides detailed instructions for installing FEAX and its dependencies.
JAX
FEAX requires JAX as its core dependency. JAX provides automatic differentiation and JIT compilation capabilities.
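As a quick illustration of those two capabilities, the sketch below differentiates and JIT-compiles a toy quadratic energy using plain JAX. The names `energy`, `K`, and `f` are hypothetical examples and are not part of the FEAX API.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy "stiffness" matrix and load vector (not from FEAX).
K = jnp.array([[2.0, -1.0], [-1.0, 2.0]])
f = jnp.array([1.0, 0.0])

def energy(u):
    # Quadratic energy: 0.5 * u^T K u - f^T u
    return 0.5 * u @ K @ u - f @ u

# jax.grad gives the residual K u - f (K is symmetric); jax.jit compiles it.
residual = jax.jit(jax.grad(energy))

u = jnp.array([1.0, 1.0])
print(residual(u))  # prints [0. 1.]
```

The same pattern (write a scalar energy, let JAX derive and compile its gradient) is what makes JAX a natural foundation for a differentiable finite element code.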
CPU-only Installation
For CPU-only usage (Linux/macOS/Windows):
pip install -U jax
This is sufficient for learning and small-scale problems.
NVIDIA GPU (CUDA 12)
For NVIDIA GPU acceleration with CUDA 12:
pip install -U "jax[cuda12]"
Requirements: NVIDIA driver ≥525
NVIDIA GPU (CUDA 13)
For CUDA 13:
pip install -U "jax[cuda13]"
For more details, see the official JAX installation guide.
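After installing any of the variants above, a short check with standard JAX calls shows which backend was picked up. `jax.default_backend()` returns a platform string such as "cpu" or "gpu"; the helper name `report_backend` is illustrative only.

```python
import jax

def report_backend():
    """Return the default JAX backend and the platforms of its visible devices."""
    backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    platforms = sorted({d.platform for d in jax.devices()})
    return backend, platforms

backend, platforms = report_backend()
print(f"default backend: {backend}, device platforms: {platforms}")
if backend != "gpu":
    print("No GPU backend found; check the jax[cuda12]/jax[cuda13] install and the NVIDIA driver.")
```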
PyPI
Install FEAX
Once JAX is installed, install FEAX using pip:
pip install feax
This will automatically install all required dependencies including:
- numpy, scipy: Numerical computing
- meshio: Mesh I/O operations
- gmsh: Mesh generation
- fenics-basix: Finite element basis functions
- matplotlib: Visualization
- pandas: Data handling
- equinox: Neural networks / pytree utilities
- lineax: Linear solvers
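To confirm that pip pulled in the dependencies above, the sketch below looks each one up with `importlib.util.find_spec` without actually importing it. The mapping from distribution names to import names is an assumption based on the list above (for example, fenics-basix is imported as `basix`).

```python
import importlib.util

# Distribution name -> import name (assumed from the dependency list above).
DEPENDENCIES = {
    "numpy": "numpy",
    "scipy": "scipy",
    "meshio": "meshio",
    "gmsh": "gmsh",
    "fenics-basix": "basix",
    "matplotlib": "matplotlib",
    "pandas": "pandas",
    "equinox": "equinox",
    "lineax": "lineax",
}

def check_imports(modules=DEPENDENCIES):
    """Map each distribution name to True if its top-level module can be found."""
    return {dist: importlib.util.find_spec(mod) is not None for dist, mod in modules.items()}

for dist, ok in check_imports().items():
    print(f"{dist:14s} {'ok' if ok else 'MISSING'}")
```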
Install from Source
To get the latest development version:
pip install git+https://github.com/Naruki-Ichihara/feax.git@main
For development with editable install:
git clone https://github.com/Naruki-Ichihara/feax.git
cd feax
pip install -e .
Optional Extras
FEAX provides optional dependency groups via pyproject.toml:
| Extra | Contents | Usage |
|---|---|---|
| .[cuda12] | JAX (cuda12) + cuDSS for CUDA 12 | GPU acceleration (CUDA 12) |
| .[cuda13] | cuDSS + cuBLAS + cuDNN for CUDA 13 | GPU acceleration (CUDA 13, without JAX) |
| .[jax] | jax[cuda13] | JAX for CUDA 13 (use with cuda13) |
| .[dev] | pytest, black, ruff, mypy | Development and testing |
For GPU use outside Docker, combine the cuda13 and jax extras:
pip install "feax[cuda13,jax]"
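To see which extras the installed FEAX distribution actually declares (useful if an older release differs from the table), the standard-library `importlib.metadata` module can read the `Provides-Extra` fields from the package metadata. The helper `declared_extras` is illustrative.

```python
from importlib import metadata

def declared_extras(dist="feax"):
    """Return the sorted extras a distribution declares, or None if it is not installed."""
    try:
        md = metadata.metadata(dist)
    except metadata.PackageNotFoundError:
        return None
    return sorted(md.get_all("Provides-Extra") or [])

# Lists the extra names from the installed package metadata, if feax is installed.
print(declared_extras())
```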
cuDSS Direct Solver
When using the cuDSS direct solver, the spineax package is also required:
pip install --no-build-isolation git+https://github.com/johnviljoen/spineax.git
Note: This is pre-installed in the Docker image.
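Because cuDSS is an NVIDIA GPU solver, a quick pre-flight check can confirm both that spineax is importable and that JAX sees a GPU. This is a sketch under two assumptions: the package's import name is `spineax`, and a GPU backend is required for the cuDSS path. The helper `cudss_ready` is hypothetical.

```python
import importlib.util

import jax

def cudss_ready():
    """Report whether spineax is importable and a GPU backend is active.

    Assumption: both are prerequisites for the cuDSS direct solver.
    """
    has_spineax = importlib.util.find_spec("spineax") is not None  # assumed import name
    has_gpu = jax.default_backend() == "gpu"
    return {"spineax": has_spineax, "gpu": has_gpu, "ready": has_spineax and has_gpu}

print(cudss_ready())
```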
Docker
FEAX provides a Dockerfile based on NVIDIA's JAX image (nvcr.io/nvidia/jax:25.10-py3), which includes JAX pre-compiled for CUDA 13. JAX is not reinstalled during the build.
Build Arguments
| Argument | Default | Description |
|---|---|---|
| INSTALL_DOCS | false | Install Node.js 20 + pydoc-markdown + Docusaurus dependencies |
Build Docker Image
Standard build:
git clone https://github.com/Naruki-Ichihara/feax.git
cd feax
docker build -t feax:latest .
Build with Docusaurus support (for docs development):
docker build --build-arg INSTALL_DOCS=true -t feax:latest .
Run Container
docker run --gpus all -it feax:latest
Docker Compose (Recommended)
For development with persistent volumes and GPU support:
cd feax
docker-compose up -d
docker exec -it feax bash
The docker-compose configuration includes:
- GPU support (deploy.resources.reservations.devices)
- Volume mounting for development (./:/workspace)
- WSL2 display support (for GUI applications like gmsh)
- Shared memory configuration (shm_size: 4gb)
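Inside the running container, the size of the shared-memory mount shows whether the shm_size setting took effect (roughly 4 GB is expected). This sketch assumes a Linux container where shared memory is mounted at /dev/shm; the helper name is illustrative.

```python
import shutil

def shm_size_gb(path="/dev/shm"):
    """Return the total size of the shared-memory mount in GiB, or None if the path is absent."""
    try:
        total = shutil.disk_usage(path).total
    except FileNotFoundError:
        return None
    return total / 1024**3

size = shm_size_gb()
print(f"/dev/shm: {size:.1f} GB" if size is not None else "/dev/shm not found")
```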
Docs Development
To develop the documentation site locally, Node.js 20+ and pydoc-markdown are required.
Install Dependencies
# Node.js 20
curl -fsSL https://deb.nodesource.com/setup_20.x | bash -
apt-get install -y nodejs
# pydoc-markdown (API doc generator)
pip install pydoc-markdown
# Docusaurus + npm packages
cd docs
npm install
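If npm install fails, it is worth confirming that the Node.js on PATH meets the 20+ requirement. This Python sketch shells out to `node --version`; the helper names are illustrative.

```python
import re
import shutil
import subprocess

def parse_major(version_string):
    """Parse strings like 'v20.11.1' into the integer 20 (None if unparseable)."""
    match = re.match(r"v?(\d+)\.", version_string.strip())
    return int(match.group(1)) if match else None

def node_major_version():
    """Return the major version of the `node` executable on PATH, or None if absent."""
    node = shutil.which("node")
    if node is None:
        return None
    out = subprocess.run([node, "--version"], capture_output=True, text=True, check=True).stdout
    return parse_major(out)

major = node_major_version()
print("Node.js OK" if major is not None and major >= 20 else f"Node.js 20+ required (found: {major})")
```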
Start Dev Server
Use the provided script to generate API docs and start the server in one step:
./docs/dev.sh
Or manually:
cd docs
npm run api:generate # Generate API reference from Python source
npm run start # Start dev server at http://localhost:3000
Build for Production
cd docs
npm run build
Colab
You can use FEAX in Google Colab notebooks with free GPU/TPU access.
Basic Setup
In a Colab notebook cell, first install system dependencies for gmsh:
!apt update
!apt install -y libglu1 libxcursor-dev libxft2 libxinerama1 libfltk1.3-dev libfreetype6-dev libgl1-mesa-dev libocct-foundation-dev libocct-data-exchange-dev
!pip install feax
With GPU Support
Colab provides CUDA-enabled GPUs once a GPU runtime is selected (Runtime > Change runtime type). Install system dependencies and FEAX:
!apt update
!apt install -y libglu1 libxcursor-dev libxft2 libxinerama1 libfltk1.3-dev libfreetype6-dev libgl1-mesa-dev libocct-foundation-dev libocct-data-exchange-dev
!pip install feax
# Verify GPU is available
import jax
print(jax.devices()) # Should show GPU
Verification
After installation, verify your setup:
import jax
import feax
print(f"JAX version: {jax.__version__}")
print(f"FEAX version: {feax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"JAX 64-bit enabled: {jax.config.jax_enable_x64}")
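Printing versions does not exercise the compiler. A small JIT-compiled computation, sketched below with standard JAX calls, confirms that code actually runs on the selected device; the function name `smoke_test` is illustrative. With 64-bit mode enabled the result dtype is float64, otherwise float32.

```python
import jax
import jax.numpy as jnp

@jax.jit
def smoke_test(x):
    # A tiny compiled computation: sum of squares.
    return jnp.sum(x ** 2)

x = jnp.arange(4.0)
result = smoke_test(x).block_until_ready()  # wait for the device to finish
print(f"result: {float(result)} on {x.devices()}, dtype: {result.dtype}")
```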
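Expected output (version numbers depend on your installation; the device list shows CpuDevice on CPU-only setups and CudaDevice on NVIDIA GPUs):
JAX version: 0.x.x
FEAX version: 0.1.0
JAX devices: [CpuDevice(id=0)] or [CudaDevice(id=0)]
JAX 64-bit enabled: True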
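JAX leaves 64-bit mode off by default. If the last line prints False in your environment and your problem needs double precision, the flag can be enabled explicitly with JAX's standard configuration API, ideally at the top of the script before any arrays are created.

```python
import jax
import jax.numpy as jnp

# Enable double precision; set this at startup, before creating arrays.
jax.config.update("jax_enable_x64", True)

print(jax.config.jax_enable_x64)  # True
print(jnp.zeros(1).dtype)         # float64
```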