#66
# 1. Conceptual target
A JAX-based neural mass model in VBI should:
* Represent **local dynamics via ODE/SDE**
* Support **vectorized simulation (batch / SBI)**
* Be **backend-consistent (same API as numba/C++ models) as much as possible**
At the modeling level:
* Neural mass models = **low-dimensional ODE systems per node + coupling**
* Whole-brain = **local dynamics + connectivity + delays + noise**
That's the final goal; start with the simplest case.
Start by exploring `neural_mass.py` from [vbjax](https://github.com/ins-amu/vbjax/blob/main/vbjax/neural_mass.py).
# 2. Design philosophy
## 2.1 Functional core (JAX-native)
Avoid class-heavy logic internally. Use:
* pure functions
* immutable parameters
* explicit state passing
## 2.2 Two-layer abstraction
### Layer 1: Model definition
* equations
* parameters
* initial state
### Layer 2: Simulation engine
* integrator (Euler / Heun), looped with `lax.scan`
* coupling
* noise
* batching
👉 vbjax already follows this separation (neural_mass + simulation utilities).
---
# 3. Recommended VBI-JAX model API
You want **one consistent interface across backends** (your earlier idea is correct).
### Minimal interface
```python
class JaxModel:
    def __init__(self, params):
        self.params = params

    def rhs(self, x, t, inputs):
        """dx/dt = f(x, params, inputs)"""

    def noise(self, key, shape):
        """optional stochastic term"""

    def initial_state(self, key):
        """x0"""

    def get_parameters(self):
        return self.params
```
---
### Strong recommendation (based on your goals)
Unify ALL models (numba / C++ / JAX) to:
```python
model(x, t, params, inputs) -> dxdt
```
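A minimal sketch of why the unified signature pays off, assuming a hypothetical toy linear-decay model `decay_rhs` with a plain dict of parameters (not a VBI model): because the function is pure, the same callable can be JIT-compiled and vectorized over parameters without modification.

```python
import jax
import jax.numpy as jnp

def decay_rhs(x, t, params, inputs):
    # Hypothetical toy model: dx/dt = -x / tau + inputs
    return -x / params["tau"] + inputs

# A pure function can be JIT-compiled directly
rhs_jit = jax.jit(decay_rhs)

x = jnp.ones((4, 1))          # (n_nodes, n_state)
inputs = jnp.zeros((4, 1))
dxdt = rhs_jit(x, 0.0, {"tau": 2.0}, inputs)

# ...and vectorized over a batch of parameter values
taus = jnp.array([1.0, 2.0, 4.0])
batched = jax.vmap(lambda tau: decay_rhs(x, 0.0, {"tau": tau}, inputs))(taus)
```

The same pattern is what makes the function usable from an SBI loop: the simulator only ever sees `(x, t, params, inputs)`.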
This lets you:
* JIT everything
* plug into SBI
* auto-generate code later
---
# 4. File structure (for vbi)
Suggested layout inside the `vbi` package:
```
vbi/models/
    jax/
        base.py
        integrators.py
        coupling.py
        noise.py
        mpr.py
        wongwang.py
```
---
# 5. Implementation blueprint (step-by-step)
## Step 1: Define parameters as a structured PyTree
Use a `NamedTuple` (or a dataclass registered as a PyTree):
```python
from typing import NamedTuple

class MPRParams(NamedTuple):
    tau: float
    g: float
    sigma: float
```
✔ Works with JAX transformations
✔ Clean for SBI
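To illustrate the two checkmarks: a `NamedTuple` is a PyTree, so `jit` traces through it, and a batch of parameter sets can be stored as one `NamedTuple` of arrays and mapped over with `vmap`. The `scaled_drive` function is a hypothetical stand-in for a simulation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class MPRParams(NamedTuple):
    tau: float
    g: float
    sigma: float

def scaled_drive(p: MPRParams) -> jnp.ndarray:
    # Hypothetical stand-in for a simulation that depends on the parameters
    return p.g / p.tau

p = MPRParams(tau=1.0, g=0.5, sigma=0.01)
print(jax.tree_util.tree_leaves(p))   # [1.0, 0.5, 0.01]

# A batch of parameter sets: one NamedTuple whose leaves share a leading axis
batch = MPRParams(tau=jnp.array([1.0, 2.0]),
                  g=jnp.array([0.5, 0.5]),
                  sigma=jnp.array([0.01, 0.02]))
out = jax.vmap(scaled_drive)(batch)   # shape (2,)
```

This "struct of arrays" layout is also what an SBI prior sampler naturally produces, so no conversion is needed between sampling and simulation.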
---
## Step 2: Define the RHS (core model)
```python
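import jax.numpy as jnp

def mpr_rhs(x, t, params, coupling):
    # x: (n_nodes, n_state)
    # coupling: input from network
    v = x[..., 0]
    w = x[..., 1]
    dv = jnp.zeros_like(v)  # placeholder: model-specific equation for dv/dt
    dw = jnp.zeros_like(w)  # placeholder: model-specific equation for dw/dt
    return jnp.stack([dv, dw], axis=-1)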
```
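For reference, a filled-in version assuming the standard Montbrió-Pazó-Roxin firing-rate equations, with state columns firing rate `r` and mean membrane potential `v` (rather than the `v`/`w` placeholders in the template above) and the network input added to the `v` equation, scaled by a global coupling `G`. `MPRFullParams`, its field names, and its default values are illustrative assumptions, not VBI's or vbjax's API.

```python
from typing import NamedTuple

import jax.numpy as jnp

class MPRFullParams(NamedTuple):
    # Hypothetical parameter set; values are illustrative only
    tau: float = 1.0
    delta: float = 1.0
    eta: float = -5.0
    J: float = 15.0
    I_ext: float = 0.0
    G: float = 0.5

def mpr_rhs(x, t, params, coupling):
    # x: (n_nodes, 2) with columns [r, v]; coupling: (n_nodes,)
    p = params
    r, v = x[..., 0], x[..., 1]
    dr = (p.delta / (jnp.pi * p.tau) + 2.0 * r * v) / p.tau
    dv = (v ** 2 + p.eta + p.J * p.tau * r
          - (jnp.pi * p.tau * r) ** 2 + p.I_ext + p.G * coupling) / p.tau
    return jnp.stack([dr, dv], axis=-1)

x = jnp.zeros((3, 2))
dx = mpr_rhs(x, 0.0, MPRFullParams(), jnp.zeros(3))
```

At `r = v = 0` with no input, the sketch gives `dr/dt = Δ/(π τ²)` and `dv/dt = η/τ`, which is a quick sanity check against the equations.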
---
## Step 3: Add coupling (IMPORTANT)
Design it separately:
```python
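def diffusive_coupling(x, weights):
    # c_i = sum_j w_ij * (x_j - x_i); x: (n_nodes, n_state), weights: (n_nodes, n_nodes)
    return weights @ x - weights.sum(axis=1, keepdims=True) * x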
```
Then inside simulation:
```python
inputs = coupling_fn(x, SC)
dx = rhs(x, t, params, inputs)
```
👉 This mirrors TVB logic:
* local dynamics + weighted inputs ([PMC](https://pmc.ncbi.nlm.nih.gov/articles/PMC11674928/))
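A small check of two common coupling styles on a hypothetical 3-node connectome `SC`: linear coupling is just `SC @ x`, while diffusive (difference) coupling subtracts each node's own state weighted by its in-strength, so it vanishes when all nodes are in the same state.

```python
import jax.numpy as jnp

def linear_coupling(x, weights):
    # c_i = sum_j w_ij * x_j
    return weights @ x

def diffusive_coupling(x, weights):
    # c_i = sum_j w_ij * (x_j - x_i)
    return weights @ x - weights.sum(axis=1, keepdims=True) * x

SC = jnp.array([[0.0, 1.0, 0.5],
                [1.0, 0.0, 0.2],
                [0.5, 0.2, 0.0]])
x_sync = jnp.ones((3, 1))                  # all nodes identical
x_het = jnp.array([[0.0], [1.0], [2.0]])   # heterogeneous state

c_lin = linear_coupling(x_het, SC)            # [[2.0], [0.4], [0.2]]
c_dif_sync = diffusive_coupling(x_sync, SC)   # zeros
c_dif_het = diffusive_coupling(x_het, SC)     # [[2.0], [-0.8], [-1.2]]
```

Keeping both as interchangeable `coupling_fn(x, SC)` functions means the simulation loop never needs to know which one is in use.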
---
## Step 4: Integrator (JAX style)
Use `lax.scan` (as you already do):
```python
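import jax.numpy as jnp
from jax import lax, random

def simulate(rhs, coupling_fn, x0, params, dt, sigma, keys):
    def step(x, t_key):
        t, key = t_key
        inputs = coupling_fn(x)
        dx = rhs(x, t, params, inputs)
        noise = sigma * jnp.sqrt(dt) * random.normal(key, x.shape)  # Euler-Maruyama: noise ~ sqrt(dt)
        x_new = x + dt * dx + noise
        return x_new, x_new
    ts = dt * jnp.arange(keys.shape[0])
    _, traj = lax.scan(step, x0, (ts, keys))
    return traj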
```
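Section 2.2 also lists Heun. A minimal stochastic Heun sketch (predictor-corrector with additive noise), assuming the same `rhs(x, t, params, inputs)` and `coupling_fn(x)` conventions; the toy `decay` model is hypothetical and only used to check the scheme against `exp(-t)`.

```python
import jax.numpy as jnp
from jax import lax, random

def heun_simulate(rhs, coupling_fn, x0, params, dt, sigma, keys):
    ts = dt * jnp.arange(keys.shape[0])

    def step(x, t_key):
        t, key = t_key
        dW = sigma * jnp.sqrt(dt) * random.normal(key, x.shape)
        # Predictor: Euler step
        d1 = rhs(x, t, params, coupling_fn(x))
        x_pred = x + dt * d1 + dW
        # Corrector: average the slopes at x and at the prediction
        d2 = rhs(x_pred, t + dt, params, coupling_fn(x_pred))
        x_new = x + 0.5 * dt * (d1 + d2) + dW
        return x_new, x_new

    _, traj = lax.scan(step, x0, (ts, keys))
    return traj

def decay(x, t, params, inputs):
    # Hypothetical toy model: dx/dt = -x / tau + inputs
    return -x / params["tau"] + inputs

keys = random.split(random.PRNGKey(0), 100)
traj = heun_simulate(decay, lambda x: 0.0 * x, jnp.ones((3, 1)),
                     {"tau": 1.0}, 0.01, 0.0, keys)   # deterministic run (sigma=0)
```

Heun costs two `rhs` evaluations per step but is second-order in the deterministic part, which often allows a larger `dt` than Euler for the same accuracy.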
---
## Step 5: Vectorization (CRITICAL for SBI)
Two levels:
### A. Multi-node
Already handled by the state shape `(n_nodes, n_state)`.
### B. Multi-simulation
```python
from jax import vmap
trajs = vmap(run_simulation)(batch_params)  # leaves of batch_params share a leading batch axis
```
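A runnable sketch of level B, assuming a hypothetical `run_simulation(tau, key)` built on a toy Ornstein-Uhlenbeck-style SDE: vmapping over both a parameter batch and a batch of PRNG keys gives one independent simulation per parameter set.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

N_NODES, N_STEPS, DT = 4, 200, 0.01

def run_simulation(tau, key):
    # Hypothetical toy SDE: dx = -x / tau dt + 0.1 dW, Euler-Maruyama
    keys = random.split(key, N_STEPS)

    def step(x, k):
        x = x - DT * x / tau + 0.1 * jnp.sqrt(DT) * random.normal(k, x.shape)
        return x, x

    _, traj = lax.scan(step, jnp.ones(N_NODES), keys)
    return traj                                   # (N_STEPS, N_NODES)

taus = jnp.linspace(0.5, 2.0, 8)                  # 8 parameter sets
sim_keys = random.split(random.PRNGKey(42), 8)    # one key per simulation
trajs = jax.jit(jax.vmap(run_simulation))(taus, sim_keys)   # (8, N_STEPS, N_NODES)
```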
---
## Step 6: Noise design
Support:
* shared noise (same across simulations, different across nodes)
* independent noise
* deterministic mode (no noise)
```python
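import jax.numpy as jnp
from jax import random

def noise(key, shape, same_noise=False, deterministic=False):
    # shape is assumed to be (n_sims, n_nodes)
    if deterministic:
        return jnp.zeros(shape)  # no stochastic term
    if same_noise:
        # one draw per node, broadcast across all simulations
        z = random.normal(key, (shape[-1],))
        return jnp.broadcast_to(z, shape)
    return random.normal(key, shape)  # independent per simulation and node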
```
---
## Step 7: BOLD / observation model
Keep separate:
```python
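def bold_transform(neural_activity, params):
    # Map neural activity to a BOLD time series (e.g. via a hemodynamic model).
    # Placeholder: implemented separately from the neural dynamics.
    raise NotImplementedError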
```
👉 Do NOT mix with neural dynamics
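One common choice for `bold_transform` is the Balloon-Windkessel hemodynamic model. The sketch below assumes the classic formulation with the widely used default constants from Friston et al. (2000) and a plain forward-Euler step, and it takes the integration step `dt` in place of the generic `params` argument; it is an illustration, not VBI's BOLD implementation. In practice the neural signal is often integrated at a fine step and the BOLD output is then downsampled to the scanner TR.

```python
import jax.numpy as jnp
from jax import lax

# Assumed Balloon-Windkessel constants (classic Friston et al. 2000 values)
KAPPA, GAMMA, TAU, ALPHA, RHO, V0 = 0.65, 0.41, 0.98, 0.32, 0.34, 0.02
K1, K2, K3 = 7.0 * RHO, 2.0, 2.0 * RHO - 0.2

def bold_transform(neural_activity, dt):
    """neural_activity: (n_steps, n_nodes) -> BOLD signal (n_steps, n_nodes)."""
    n_nodes = neural_activity.shape[1]
    # Hemodynamic state (s, f, v, q) starts at rest: s = 0, f = v = q = 1
    init = (jnp.zeros(n_nodes), jnp.ones(n_nodes),
            jnp.ones(n_nodes), jnp.ones(n_nodes))

    def step(state, u):
        s, f, v, q = state
        ds = u - KAPPA * s - GAMMA * (f - 1.0)        # vasodilatory signal
        df = s                                        # blood inflow
        dv = (f - v ** (1.0 / ALPHA)) / TAU           # blood volume
        dq = (f * (1.0 - (1.0 - RHO) ** (1.0 / f)) / RHO
              - v ** (1.0 / ALPHA) * q / v) / TAU     # deoxyhemoglobin
        s, f, v, q = s + dt * ds, f + dt * df, v + dt * dv, q + dt * dq
        y = V0 * (K1 * (1.0 - q) + K2 * (1.0 - q / v) + K3 * (1.0 - v))
        return (s, f, v, q), y

    _, bold = lax.scan(step, init, neural_activity)
    return bold

rest = bold_transform(jnp.zeros((500, 2)), dt=0.01)     # stays at baseline
active = bold_transform(jnp.ones((2000, 2)), dt=0.01)   # sustained input
```

Because it only consumes the neural time series, this function can be swapped or skipped without touching the neural mass code.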
---
# 6. vbjax-inspired patterns (what to reuse)
From vbjax neural_mass design:
### Key ideas you should copy:
1. **Separation of dynamics and integration**
2. **Stateless model functions**
3. **Explicit parameter passing**
4. **Composable pipeline**
---
# 7. Standard template (your reusable skeleton)
Use this as your base file:
```python
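import jax.numpy as jnp
from jax import lax, random

class JaxNeuralMass:
    def __init__(self, params, dt=0.01, sigma=0.0):
        self.params = params  # immutable PyTree (e.g. a NamedTuple)
        self.dt = dt
        self.sigma = sigma

    def rhs(self, x, t, inputs):
        # subclasses implement dx/dt = f(x, params, inputs)
        raise NotImplementedError

    def noise(self, key, shape):
        # additive Gaussian noise for one Euler-Maruyama step
        return self.sigma * jnp.sqrt(self.dt) * random.normal(key, shape)

    def step(self, x, key, inputs):
        dx = self.rhs(x, 0.0, inputs)  # time is not tracked in this template
        return x + self.dt * dx + self.noise(key, x.shape)

    def run(self, x0, keys, coupling):
        # coupling: function mapping the full state to network input
        def body(x, key):
            inputs = coupling(x)
            x_new = self.step(x, key, inputs)
            return x_new, x_new
        _, traj = lax.scan(body, x0, keys)
        return traj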
```
---
# 8. Advanced features (you WILL need)
## 8.1 Delays (TVB-style)
* buffer-based (ring buffer)
* indexed access
```python
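# delayed[i, j] = x_j(t - d_ij) from a (horizon, n_nodes) ring buffer
delayed = buffer[(t - delay_idx) % horizon, jnp.arange(n_nodes)]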
```
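A minimal sketch of TVB-style delayed coupling with a ring buffer inside `lax.scan`, assuming integer per-edge delays in steps (`delay_idx`, shape `(n_nodes, n_nodes)`), a constant initial history equal to `x0`, and a toy linear-decay node model. `simulate_delayed` and its arguments are hypothetical names.

```python
import jax.numpy as jnp
from jax import lax

def simulate_delayed(x0, weights, delay_idx, dt, tau, n_steps):
    n_nodes = x0.shape[0]
    horizon = int(delay_idx.max()) + 1          # delay_idx must be concrete here
    buffer = jnp.tile(x0, (horizon, 1))         # constant initial history
    cols = jnp.arange(n_nodes)

    def step(carry, t):
        buffer, x = carry
        # delayed[i, j] = x_j(t - d_ij), read from the ring buffer
        delayed = buffer[(t - delay_idx) % horizon, cols]
        inputs = jnp.sum(weights * delayed, axis=1)
        x_new = x + dt * (-x / tau + inputs)    # toy linear-decay node model
        buffer = buffer.at[(t + 1) % horizon].set(x_new)
        return (buffer, x_new), x_new

    _, traj = lax.scan(step, (buffer, x0), jnp.arange(n_steps))
    return traj

# Node 1 receives node 0 with a 5-step delay
traj = simulate_delayed(jnp.array([1.0, 0.0]),
                        jnp.array([[0.0, 0.0], [1.0, 0.0]]),
                        jnp.array([[0, 0], [5, 0]]), 0.1, 1.0, 20)
```

The buffer holds exactly `max_delay + 1` past states, and the slot for step `t + 1` overwrites the oldest entry, which is no longer reachable by any delay.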
---
## 8.2 Sparse connectivity (important for scale)
Use CSR:
```python
jnp.take(buffer, indices) * weights
```
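Completing the gather above into a full sparse matrix-vector product: this sketch assumes standard CSR arrays (`indptr`, `indices`, `data`) and uses `jax.ops.segment_sum` to add up each row's contributions. `csr_row_ids` and `sparse_coupling` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp
import numpy as np

def csr_row_ids(indptr):
    # Expand CSR row pointers into one row index per stored entry (done once, on host)
    counts = np.diff(indptr)
    return jnp.asarray(np.repeat(np.arange(len(counts)), counts))

def sparse_coupling(x, indices, data, row_ids, n_nodes):
    # c_i = sum over stored j of w_ij * x_j
    contrib = jnp.take(x, indices) * data
    return jax.ops.segment_sum(contrib, row_ids, num_segments=n_nodes)

dense = np.array([[0.0, 1.0, 0.5],
                  [1.0, 0.0, 0.0],
                  [0.5, 0.0, 0.0]])
# CSR arrays for `dense`, written out by hand
indptr = np.array([0, 2, 3, 4])
indices = jnp.array([1, 2, 0, 0])
data = jnp.array([1.0, 0.5, 1.0, 0.5])

x = jnp.array([1.0, 2.0, 3.0])
c = sparse_coupling(x, indices, data, csr_row_ids(indptr), 3)   # equals dense @ x
```

Passing `num_segments` explicitly keeps the output shape static, which `jit` needs.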
---
## 8.3 Chunked vmap
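For large SBI batches that don't fit in memory at once:
```python
import jax

run_batch = jax.jit(jax.vmap(run))  # compile once, reuse for every chunk
results = [run_batch(chunk) for chunk in batches]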
```
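A slightly fuller sketch, assuming a hypothetical `run(tau)` that returns a fixed-shape array and a parameter batch held as a single array: the last chunk is padded to the common size so the jitted function is compiled only once, and the padded results are dropped at the end.

```python
import jax
import jax.numpy as jnp

def run(tau):
    # Hypothetical stand-in for one simulation: returns a fixed-shape summary
    t = jnp.linspace(0.0, 1.0, 50)
    return jnp.exp(-t / tau)

def run_chunked(run_fn, params, chunk_size):
    run_batch = jax.jit(jax.vmap(run_fn))      # compiled once per chunk shape
    n = params.shape[0]
    pad = (-n) % chunk_size                    # pad so every chunk has equal size
    padded = jnp.concatenate([params, jnp.repeat(params[-1:], pad, axis=0)])
    outs = [run_batch(padded[i:i + chunk_size])
            for i in range(0, n + pad, chunk_size)]
    return jnp.concatenate(outs)[:n]           # drop the padded results

taus = jnp.linspace(0.5, 2.0, 10)
res = run_chunked(run, taus, chunk_size=4)     # shape (10, 50)
```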
---
# 9. Testing & validation (must-have)
You should standardize:
### Unit tests
* shape consistency
* deterministic runs (fixed key)
### Scientific tests
* reproduce known dynamics
* parameter sensitivity
### Cross-backend tests
* JAX vs numba vs C++
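The unit tests above might look like this in pytest style. `toy_simulate` is a hypothetical stand-in for a real model run; the checks cover shape consistency and bit-for-bit reproducibility with a fixed key.

```python
import jax.numpy as jnp
from jax import lax, random

def toy_simulate(key, n_nodes=3, n_steps=50, dt=0.01, sigma=0.1):
    # Hypothetical stand-in: Euler-Maruyama on dx = -x dt + sigma dW
    def step(x, k):
        x = x - dt * x + sigma * jnp.sqrt(dt) * random.normal(k, x.shape)
        return x, x
    _, traj = lax.scan(step, jnp.zeros(n_nodes), random.split(key, n_steps))
    return traj

def test_shape():
    assert toy_simulate(random.PRNGKey(0)).shape == (50, 3)

def test_fixed_key_is_deterministic():
    a = toy_simulate(random.PRNGKey(0))
    b = toy_simulate(random.PRNGKey(0))
    assert jnp.array_equal(a, b)

def test_different_keys_differ():
    a = toy_simulate(random.PRNGKey(0))
    b = toy_simulate(random.PRNGKey(1))
    assert not jnp.array_equal(a, b)
```

Cross-backend tests can reuse the same structure, comparing deterministic runs (`sigma=0`) within a tolerance, since the backends will not share random streams.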
---
# 10. Common pitfalls (based on your current work)
### ❌ Mixing model + integrator
→ breaks composability
### ❌ Hidden mutable state in class
→ mutating attributes inside jitted code leaks tracers and breaks JIT / vmap
### ❌ non-reproducible RNG
→ always derive per-step keys with `jax.random.fold_in(key, step)`
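A minimal sketch of the `fold_in` pattern, assuming one base key per simulation: each step derives its own key from the base key and the step index, so reruns with the same base key are identical and any step's noise can be regenerated without replaying the earlier steps. `noisy_walk` is a hypothetical example.

```python
import jax.numpy as jnp
from jax import lax, random

def noisy_walk(base_key, n_steps=100, n_nodes=3):
    def step(x, i):
        key_i = random.fold_in(base_key, i)   # per-step key from the step index
        x = x + 0.1 * random.normal(key_i, (n_nodes,))
        return x, x
    _, traj = lax.scan(step, jnp.zeros(n_nodes), jnp.arange(n_steps))
    return traj

traj_a = noisy_walk(random.PRNGKey(0))
traj_b = noisy_walk(random.PRNGKey(0))
# The increment at step 10 can be recomputed directly from its key
inc_10 = 0.1 * random.normal(random.fold_in(random.PRNGKey(0), 10), (3,))
```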
---
# 11. Suggested roadmap for you
### Phase 1
* implement **MPR JAX model (minimal)**
* match vbjax behavior
### Phase 2
* unify API across backends
### Phase 3
* add:
  * delays
  * BOLD
  * SBI-ready batch pipeline
### Phase 4
* build the **code auto-generator**
---
# 12. Optional: future-proof design
Design models as:
```python
model_spec = {
    "state_vars": [...],
    "parameters": {...},
    "equations": ...,
}
```
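To show how such a spec could feed the planned auto-generator, here is a hypothetical sketch in which `equations` holds one Python callable per state variable and `rhs_from_spec` assembles a JIT-compiled `rhs` with the unified `(x, t, params, inputs)` signature. The spec layout and function name are assumptions (a real generator might parse equation strings instead); the FitzHugh-Nagumo model is used only as a familiar two-variable example.

```python
import jax
import jax.numpy as jnp

# Hypothetical spec: one callable per state variable, taking (state, params, inputs)
model_spec = {
    "state_vars": ["v", "w"],
    "parameters": {"a": 0.7, "b": 0.8, "eps": 0.08, "I": 0.5},
    "equations": {
        "v": lambda s, p, c: s["v"] - s["v"] ** 3 / 3 - s["w"] + p["I"] + c,
        "w": lambda s, p, c: p["eps"] * (s["v"] + p["a"] - p["b"] * s["w"]),
    },
}

def rhs_from_spec(spec):
    names = spec["state_vars"]

    def rhs(x, t, params, inputs):
        # x: (n_nodes, n_state); columns ordered as spec["state_vars"]
        state = {name: x[..., i] for i, name in enumerate(names)}
        derivs = [spec["equations"][name](state, params, inputs) for name in names]
        return jnp.stack(derivs, axis=-1)

    return jax.jit(rhs)

rhs = rhs_from_spec(model_spec)
dx = rhs(jnp.zeros((4, 2)), 0.0, model_spec["parameters"], jnp.zeros(4))
```

Since the generated `rhs` has the same signature as hand-written models, it plugs into the same integrators, coupling functions, and vmap/SBI pipeline.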