Discovering Causal Relationships through Gradient Descent
Learning DAGs from Data
Imagine you’ve collected data on student study hours, sleep patterns, and exam scores. You notice strong correlations between these variables. However, we know correlation doesn’t imply causation. Does more studying cause better exam performance? Does sleep quality affect study efficiency? Or do other factors drive these relationships?
Unlike standard machine learning that focuses on prediction, causal inference seeks to understand the underlying mechanisms: what causes what.
Traditionally, establishing causality required controlled experiments where we manipulate variables and observe outcomes. But what if experiments are impossible, unethical, or too expensive? Can we infer causality from observational data alone?
Existing approaches fall broadly into constraint-based methods (like the PC algorithm), which test conditional independence, and score-based methods, which search for graph structures that optimize a scoring criterion. Both face a fundamental hurdle: learning a Directed Acyclic Graph (DAG) is a combinatorial optimization problem whose search space grows super-exponentially with the number of variables.
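To put "super-exponential" in numbers, the count of labelled DAGs on d nodes follows Robinson's recurrence, $a(n) = \sum_{k=1}^{n} (-1)^{k+1}\binom{n}{k} 2^{k(n-k)} a(n-k)$ with $a(0) = 1$. Here is a short Python sketch; count_dags is a helper written for this illustration, not part of any library used in this post:

```python
from math import comb


def count_dags(d: int) -> int:
    """Number of labelled DAGs on d nodes, via Robinson's recurrence."""
    a = [1]  # a[0] = 1: the empty graph
    for n in range(1, d + 1):
        # Inclusion-exclusion over the k nodes chosen as sources (no incoming edges)
        a.append(sum(
            (-1) ** (k + 1) * comb(n, k) * 2 ** (k * (n - k)) * a[n - k]
            for k in range(1, n + 1)
        ))
    return a[d]


for d in range(1, 11):
    print(f"d={d:>2}: {count_dags(d):,} DAGs")
```

With the 5 nodes used in this notebook there are already 29,281 candidate DAGs. At 10 nodes the count passes 10^18, far too many to score one by one.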
The Challenge
Learning DAGs from Data
At its core, the problem is to recover a Directed Acyclic Graph (DAG) that describes the data-generating process. Each node represents a variable, and each directed edge represents a direct causal relationship. Edges are weighted to represent the strength of that relationship.
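As a minimal illustration, a weighted DAG can be stored as a d × d matrix. The sketch below encodes the hypothetical study example from the introduction. The variable names and weights are made up for illustration, and it uses the convention that B[i, j] is the direct effect of variable i on variable j:

```python
import numpy as np

# Hypothetical variables and made-up weights, for illustration only
variables = ["sleep_hours", "study_hours", "exam_score"]

# B[i, j] != 0 means variables[i] directly causes variables[j]
B = np.array([
    [0.0, 0.4, 0.3],  # sleep -> study, sleep -> exam score
    [0.0, 0.0, 0.8],  # study -> exam score
    [0.0, 0.0, 0.0],  # exam score has no outgoing edges
])


def edge_list(B, names):
    """Return a (cause, effect, weight) triple for every non-zero entry of B."""
    rows, cols = np.nonzero(B)
    return [(names[i], names[j], float(B[i, j])) for i, j in zip(rows, cols)]


for cause, effect, weight in edge_list(B, variables):
    print(f"{cause} -> {effect} (weight {weight:+.1f})")
```

Because every edge here points from a lower to a higher index, the matrix is strictly upper triangular. Any DAG can be brought into this form by ordering its nodes topologically.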
Traditional Approaches and Their Limitations
Previous methods for causal discovery generally fall into two categories:
- Constraint-Based Methods (like the PC and FCI algorithms):
  - Test conditional independence between variables
  - Build a graph consistent with these independence relationships
  - Advantages: Require fewer assumptions about the data distribution
  - Limitations: Sensitive to errors in independence tests, especially with limited data
- Score-Based Methods (like GES and various Bayesian approaches):
  - Define a score function that measures how well a graph explains the data
  - Search through the space of possible graphs to maximize this score
  - Advantages: Can incorporate prior knowledge and handle uncertainty
  - Limitations: The search space grows super-exponentially with the number of variables
As the NOTEARS paper discusses, both classes of approaches struggle with the same thing: enforcing acyclicity. The space of all directed graphs is huge, but only a small fraction of those graphs are acyclic.
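To see how small that fraction is, and what a discrete acyclicity check looks like, the sketch below tests every directed graph on 4 nodes with Kahn's topological-sort algorithm. A combinatorial search has to run a check like this on each candidate graph. Both is_acyclic and count_acyclic are helpers written for this illustration:

```python
from itertools import product


def is_acyclic(adj):
    """Kahn's algorithm: a graph is a DAG iff repeatedly removing
    nodes with in-degree zero eventually removes every node."""
    d = len(adj)
    indegree = [sum(adj[i][j] for i in range(d)) for j in range(d)]
    queue = [j for j in range(d) if indegree[j] == 0]
    removed = 0
    while queue:
        i = queue.pop()
        removed += 1
        for j in range(d):
            if adj[i][j]:  # edge i -> j
                indegree[j] -= 1
                if indegree[j] == 0:
                    queue.append(j)
    return removed == d


def count_acyclic(d):
    """Enumerate all 2^(d(d-1)) directed graphs without self-loops on d nodes."""
    off_diagonal = [(i, j) for i in range(d) for j in range(d) if i != j]
    n_acyclic = 0
    for bits in product([0, 1], repeat=len(off_diagonal)):
        adj = [[0] * d for _ in range(d)]
        for (i, j), b in zip(off_diagonal, bits):
            adj[i][j] = b
        n_acyclic += is_acyclic(adj)
    return n_acyclic, 2 ** len(off_diagonal)


n_acyclic, n_total = count_acyclic(4)
print(f"{n_acyclic} of {n_total} directed graphs on 4 nodes are acyclic")
```

Only 543 of the 4,096 graphs (about 13%) pass, and the fraction keeps shrinking as nodes are added. Brute-force enumeration like this is only feasible for tiny graphs.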
```python
import numpy as np
import networkx as nx
from utils import timing


def generate_random_dag(d, s0):
    """Generate a random weighted DAG with d nodes.

    Each of the d*(d-1)/2 node pairs gets an edge with probability s0/(d-1),
    so every node has s0 expected neighbours (about s0*d/2 edges in total).
    """
    # Random strictly upper triangular matrix: edges only go from lower to
    # higher index, which guarantees acyclicity
    A = np.zeros((d, d))
    for i in range(d):
        for j in range(i):
            if np.random.rand() < s0 / (d - 1):
                # Weight magnitude in [0.5, 2.0] with a random sign
                A[j, i] = np.random.uniform(0.5, 2.0) * np.random.choice([-1, 1])
    # Randomly permute the node labels; relabelling preserves acyclicity
    P = np.random.permutation(np.eye(d))
    A = P.T @ A @ P
    return A


def generate_sem_data(W, n, noise_scale=1):
    """Generate n samples from the linear SEM X = XW + Z, Z ~ N(0, noise_scale^2)."""
    d = W.shape[0]
    I = np.eye(d)
    Z = np.random.normal(0, noise_scale, size=(n, d))
    # Ground truth in this notebook's convention: W_true[i, j] != 0 means j → i
    W_true = W.T
    # Closed form of X = XW + Z: X = Z (I - W)^(-1)
    X = Z @ np.linalg.inv(I - W_true.T)
    return X, W_true


np.random.seed(42)
nodes, edges = 5, 2  # `edges` is passed as s0, the expected neighbours per node
samples = 1000
A = generate_random_dag(nodes, edges)
X, W_true = generate_sem_data(A, samples)
A
```
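The closed form in generate_sem_data follows from rearranging $X = XW + Z$ into $X(I - W) = Z$, so $X = Z(I - W)^{-1}$. For a DAG, W is nilpotent, so $I - W$ is always invertible. The quick numpy check below confirms that the generated samples satisfy the structural equations exactly. It uses a hand-written 3-node chain rather than the random graph above, and the variable names are chosen so they don't overwrite the notebook's X:

```python
import numpy as np

rng = np.random.default_rng(0)

# Chain 0 -> 1 -> 2 with W_chain[i, j] = direct effect of i on j (so X = XW + Z)
W_chain = np.array([
    [0.0, 1.5, 0.0],
    [0.0, 0.0, -2.0],
    [0.0, 0.0, 0.0],
])
Z_chain = rng.normal(size=(500, 3))

# Closed-form solution of X = XW + Z
X_chain = Z_chain @ np.linalg.inv(np.eye(3) - W_chain)

# Each column should equal its parents' weighted sum plus its own noise
print(np.allclose(X_chain, X_chain @ W_chain + Z_chain))  # structural equations hold
print(np.allclose(X_chain[:, 0], Z_chain[:, 0]))          # node 0 has no parents
print(np.allclose(X_chain[:, 2], -2.0 * X_chain[:, 1] + Z_chain[:, 2]))
```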
```python
from utils import visualize_dag

visualize_dag(W_true)
```
The DAG visualised above represents the causal relationships. The nodes are the variables we're exploring, and each arrow points from a variable to a variable it directly causes. The edge weights represent the strength of the causal effect: a positive weight means that increasing the cause increases the effect, while a negative weight means that increasing the cause decreases the effect.
Relationships in this toy example:
- X3 has a negative effect (-1.93) on X1
- X3 has a positive effect (1.95) on X0
- X0 influences X2 with a negative effect (-1.99)
- X1 has a small negative effect (-0.65) on X2
- X4 appears isolated (no edges connecting to it)
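In plain terms, this means:
- X3 is a “common cause” of X0 and X1, and it also influences X2 indirectly through both of them
- X2 is a collider: its two parents, X0 and X1, are not directly connected (a v-structure X0 → X2 ← X1)
- X4 is independent of the other variables: it has no causal relationship with any of them in this graph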
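The “common cause” reading can be made quantitative. In a linear SEM, the total effect of one variable on another is the sum, over all directed paths, of the products of the edge weights along each path. This equals the corresponding entry of $(I - B)^{-1} - I$. The sketch below rebuilds the toy graph from the rounded weights listed above and computes the total effect of X3 on X2. It uses B[i, j] for the effect of X_i on X_j, which is the transpose of the W_true convention:

```python
import numpy as np

d = 5
# B[i, j] = direct effect of X_i on X_j, using the rounded weights listed above
B = np.zeros((d, d))
B[3, 1] = -1.93  # X3 -> X1
B[3, 0] = 1.95   # X3 -> X0
B[0, 2] = -1.99  # X0 -> X2
B[1, 2] = -0.65  # X1 -> X2

# Total effects: sum over all directed paths = (I - B)^(-1) - I
total = np.linalg.inv(np.eye(d) - B) - np.eye(d)

# Path by path: X3 -> X0 -> X2 plus X3 -> X1 -> X2
by_paths = 1.95 * -1.99 + -1.93 * -0.65
print(f"total effect of X3 on X2: {total[3, 2]:.4f} (paths: {by_paths:.4f})")
print(f"total effect of X4 on the others: {np.abs(total[4]).sum():.1f}")
```

The two paths have opposite signs, so they partly cancel. The strong negative effect through X0 is dampened by the positive effect through X1.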
This scenario is a good test case for the algorithm because it:
- has no cycles
- has weighted edges representing the strength of causal effects
- has a mix of relationships (some variables with multiple parents or children, one isolated)
Benchmark
Before diving into the implementation, it’s valuable to understand how NOTEARS compares to traditional causal discovery approaches. This comparison helps illustrate why gradient-based methods represent such a significant advancement in the field.
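```python
from causallearn.search.ScoreBased.GES import ges
from causallearn.search.ConstraintBased.PC import pc
from utils import plot_dag_comparison


@timing
def run_ges(X):
    # GES is score-based: a greedy forward/backward search over
    # equivalence classes that maximises the BIC score
    record = ges(X, score_func="local_score_BIC")
    G = record['G'].graph
    # causal-learn endpoint matrix: G[j, i] == 1 and G[i, j] == -1 encode i → j;
    # G[i, j] == G[j, i] == -1 encodes an undirected edge i - j
    return np.asarray(G, dtype=float)


@timing
def run_pc(X):
    # PC is constraint-based: it uses conditional independence tests
    # (Fisher-z by default) at significance level alpha, not a score function
    record = pc(X, alpha=0.05)
    G = record.G.graph
    # Same endpoint encoding as for GES above
    return np.asarray(G, dtype=float)


W_est_ges, time_taken_ges = run_ges(X)
W_est_pc, time_taken_pc = run_pc(X)
plot_dag_comparison(W_true, W_est_ges)
plot_dag_comparison(W_true, W_est_pc)
```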
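PC and GES return an equivalence class (a CPDAG) rather than a single DAG, so their output matrix mixes directed and undirected edges. The sketch below assumes the endpoint encoding described in the causal-learn documentation: G[j, i] == 1 with G[i, j] == -1 for i → j, and -1 in both directions for an undirected edge. decode_endpoints is a hypothetical helper written for this illustration, and it runs on a hand-written matrix rather than real PC output:

```python
import numpy as np


def decode_endpoints(G):
    """Split a causal-learn style endpoint matrix into directed and undirected edges.
    Bidirected entries (1 at both ends) are ignored in this sketch."""
    directed, undirected = [], []
    d = G.shape[0]
    for i in range(d):
        for j in range(i + 1, d):
            if G[j, i] == 1 and G[i, j] == -1:
                directed.append((i, j))    # i -> j
            elif G[i, j] == 1 and G[j, i] == -1:
                directed.append((j, i))    # j -> i
            elif G[i, j] == -1 and G[j, i] == -1:
                undirected.append((i, j))  # i - j, direction not identified
    return directed, undirected


# Hand-written example: 0 -> 2 <- 1 (a v-structure) plus an undirected edge 0 - 3
G_demo = np.zeros((4, 4))
G_demo[2, 0], G_demo[0, 2] = 1, -1
G_demo[2, 1], G_demo[1, 2] = 1, -1
G_demo[0, 3] = G_demo[3, 0] = -1

directed, undirected = decode_endpoints(G_demo)
print("directed:", directed)      # [(0, 2), (1, 2)]
print("undirected:", undirected)  # [(0, 3)]
```

The v-structure 0 → 2 ← 1 is exactly the kind of pattern PC can orient. The undirected edge marks a place where the data alone could not settle the direction.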
What PC Got Right
- Correct edge detection: PC identified several true causal relationships:
  - The connection between nodes 3 and 0 (though the direction is uncertain)
  - The relationships involving node 2 as a hub/sink
  - Some of the connections to node 1
- Sparsity: PC kept the estimated graph reasonably sparse instead of producing an overly dense one
Limitations
- Missing edges
  - The algorithm missed the edge 3 → 1 (weight -1.93)
- Edge direction
  - This is a fundamental limitation of methods like PC and GES: without further assumptions, observational data only identify the graph up to its Markov equivalence class
  - PC can only orient edges when there are clear v-structures (X → Z ← Y) or other identifiable patterns
  - Many edges remain unoriented or may be incorrectly oriented due to insufficient statistical power
- Strength of causal relationships
  - PC performs conditional independence tests that return binary results (independent/dependent)
  - It builds the graph by deciding whether edges exist or not, without estimating effect strengths
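To make the “binary result” concrete, here is a minimal numpy sketch of a Fisher-z conditional independence test for linear-Gaussian data. Fisher-z is the default test PC uses in causal-learn. fisher_z_test is a hypothetical helper, not causal-learn's implementation, and it computes the partial correlation from regression residuals:

```python
import math

import numpy as np


def fisher_z_test(data, i, j, cond=(), alpha=0.05):
    """Test X_i independent of X_j given X_cond; return (partial_corr, p_value, independent)."""
    n = data.shape[0]
    x, y = data[:, i], data[:, j]
    # Partial correlation = correlation of residuals after regressing on the conditioning set
    Zc = np.column_stack([np.ones(n)] + [data[:, k] for k in cond])
    rx = x - Zc @ np.linalg.lstsq(Zc, x, rcond=None)[0]
    ry = y - Zc @ np.linalg.lstsq(Zc, y, rcond=None)[0]
    r = float(np.corrcoef(rx, ry)[0, 1])
    # Fisher z-transform: approximately standard normal under independence
    z = 0.5 * math.log((1 + r) / (1 - r)) * math.sqrt(n - len(cond) - 3)
    p_value = math.erfc(abs(z) / math.sqrt(2))  # two-sided p-value
    return r, p_value, p_value > alpha


# Chain X0 -> X1 -> X2: X0 and X2 are dependent, but independent given X1
rng = np.random.default_rng(1)
n = 5000
x0 = rng.normal(size=n)
x1 = 0.8 * x0 + rng.normal(size=n)
x2 = 0.8 * x1 + rng.normal(size=n)
data = np.column_stack([x0, x1, x2])

print(fisher_z_test(data, 0, 2))             # dependent: tiny p-value
print(fisher_z_test(data, 0, 2, cond=(1,)))  # partial correlation close to zero
```

The final boolean is all PC keeps from each test. The size of the correlation, and so the strength of the effect, is discarded.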
Why these differences?
These differences happen for a number of reasons:
- Statistical testing limitations: PC relies on conditional independence tests, which can fail with:
  - Limited sample sizes
  - Weak relationships (small effect sizes)
  - Non-linear relationships (even mild ones)
- Multiple testing: PC performs many statistical tests, increasing the chance of false positives
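The multiple-testing point is simple arithmetic. If m independent tests are each run at level α on truly independent pairs, the chance of at least one false rejection is $1 - (1 - \alpha)^m$. PC's tests are not independent of each other, so treat this as a rough illustration rather than PC's exact error rate. The simulation relies on p-values being uniformly distributed under the null:

```python
import numpy as np

alpha = 0.05
rng = np.random.default_rng(0)

for m in [1, 10, 50, 100]:
    analytic = 1 - (1 - alpha) ** m
    # Simulate 20,000 batches of m null tests: p-values are Uniform(0, 1) under H0
    p_values = rng.uniform(size=(20_000, m))
    simulated = np.mean((p_values < alpha).any(axis=1))
    print(f"m={m:>3}: P(at least one false positive) = {analytic:.3f} "
          f"(simulated {simulated:.3f})")
```

With d = 5 variables there are already 10 pairs to test before any conditioning sets are considered, and 10 tests at α = 0.05 give roughly a 40% chance of at least one false rejection.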
Before explaining and implementing the NOTEARS algorithm, let's run a few scenarios that highlight the problems with the PC algorithm further. I'm going to import the evaluate_reconstruction function from my utils module, which computes several important metrics and plots.
Performance Metrics
From the confusion matrix we can get
- TP - True Positive: Correctly identified edges that actually exist
- FP - False Positive: Incorrectly identified edges that don’t actually exist
- FN - False Negative: Missed edges that actually exist
- TN - True Negative: Correctly identified absence of edges
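From these counts we can calculate:
- Precision: the fraction of predicted edges that are true edges, TP / (TP + FP)
- Recall: the fraction of true edges that were recovered, TP / (TP + FN)
- F1 score: the harmonic mean of precision and recall, 2 · Precision · Recall / (Precision + Recall)
We can also compute the Structural Hamming Distance (SHD), a standard distance between graphs based on their adjacency matrices. It is the number of edge additions, deletions and reversals needed to turn the estimated graph into the true graph.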
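These definitions can be computed directly from two adjacency matrices. The sketch below is a stand-alone illustration, not the evaluate_reconstruction function from the utils module, whose exact conventions aren't shown here. For precision and recall it counts a reversed edge as one false positive plus one false negative. For SHD it counts the same reversal as a single error:

```python
import numpy as np


def edge_metrics(W_true, W_est):
    """Precision, recall, F1 and SHD for two (weighted) adjacency matrices."""
    B_true = W_true != 0
    B_est = W_est != 0
    tp = int(np.sum(B_true & B_est))
    fp = int(np.sum(~B_true & B_est))
    fn = int(np.sum(B_true & ~B_est))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # SHD: compare each unordered pair {i, j}; a missing, extra or reversed
    # edge counts as one error
    d = B_true.shape[0]
    shd = sum(
        (B_true[i, j], B_true[j, i]) != (B_est[i, j], B_est[j, i])
        for i in range(d) for j in range(i + 1, d)
    )
    return {"tp": tp, "fp": fp, "fn": fn, "precision": precision,
            "recall": recall, "f1": f1, "shd": shd}


# True graph: 0 -> 1 -> 2. Estimate: 0 -> 1 correct, 1 -> 2 reversed, extra 0 -> 2
true_demo = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
est_demo = np.array([[0, 1, 1], [0, 0, 0], [0, 1, 0]])
print(edge_metrics(true_demo, est_demo))
```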
Performance Plots
Left Plot: True DAG Structure (Binary Adjacency Matrix)
- Dark blue squares: Indicate true causal relationships (edges exist)
- White squares: No causal relationship exists
- Reading: If there’s a dark square at position (i,j), it means variable j causes variable i
Middle Plot: Estimated DAG Structure
- Shows what the PC algorithm discovered
- Dark blue squares: Edges the PC algorithm detected
- White squares: No relationship detected by PC
Right Plot: Edge Recovery Analysis
This is a confusion matrix showing the algorithm’s performance:
- Blue (TP - True Positive)
- Green (FP - False Positive)
- Red (FN - False Negative)
- White (TN - True Negative)
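```python
from utils import evaluate_reconstruction


def analyze_sample_size_effects():
    sample_sizes = [100, 500, 1000, 2000]
    # X only holds 1000 rows, so slicing X[:2000] would silently reuse the
    # same 1000 rows; draw a larger sample from the same SEM instead
    X_all, _ = generate_sem_data(A, max(sample_sizes))
    results = []
    for n in sample_sizes:
        X_subset = X_all[:n, :]  # Use the first n samples
        W_pc, _ = run_pc(X_subset)
        metrics = evaluate_reconstruction(W_true, W_pc)
        results.append({
            'n_samples': n,
            'f1': metrics['f1']
        })
    return results


analyze_sample_size_effects()
```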
The PC algorithm fared moderately well on this graph structure across the different sample sizes, with F1 scores between 0.5 and 0.6. However, it missed important edges: failing to detect 3 → 1 could be crucial in a real application. There are also several false discoveries: the green squares show PC adding edges that don't exist, which could be misleading.
NOTEARS
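NOTEARS (Non-combinatorial Optimization via Trace Exponential and Augmented lagRangian for Structure learning) introduced a remarkably elegant solution. Rather than searching over discrete graph structures, it reformulates the acyclicity constraint as a smooth, differentiable function of the weight matrix W, so that W is a DAG exactly when
$$h(W) = \operatorname{tr}\left(e^{W \circ W}\right) - d = 0$$
where:
- tr(·) is the trace operator (sum of diagonal elements)
- $e^{W \circ W}$ is the matrix exponential of the Hadamard (element-wise) product $W \circ W$
- d is the number of variables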
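Why does $h(W) = 0$ characterise DAGs? Expanding the exponential gives $\operatorname{tr}(e^{A}) = d + \sum_{k \ge 1} \operatorname{tr}(A^k)/k!$ with $A = W \circ W \ge 0$. The term $\operatorname{tr}(A^k)$ sums the products of squared edge weights along all closed walks of length k, so h(W) is positive exactly when the graph contains a cycle. The numpy-only sketch below evaluates h through the identity $\operatorname{tr}(e^{A}) = \sum_i e^{\lambda_i}$ over the eigenvalues of A. This is a shortcut chosen for illustration; the JAX implementation below computes the matrix exponential directly. h_acyclic_np is a hypothetical helper:

```python
import numpy as np


def h_acyclic_np(W):
    """NOTEARS acyclicity measure h(W) = tr(exp(W * W)) - d, via eigenvalues."""
    d = W.shape[0]
    eigenvalues = np.linalg.eigvals(W * W)  # W * W is the Hadamard product
    # tr(exp(A)) equals the sum of exp(eigenvalue) over the eigenvalues of A
    return float(np.sum(np.exp(eigenvalues)).real) - d


# DAG: 0 -> 1 -> 2 (W[i, j] = weight of edge i -> j)
W_dag = np.array([[0.0, 1.5, 0.0],
                  [0.0, 0.0, -2.0],
                  [0.0, 0.0, 0.0]])

# Same graph plus the edge 2 -> 0, which closes the cycle 0 -> 1 -> 2 -> 0
W_cyc = W_dag.copy()
W_cyc[2, 0] = 0.5

print(f"h(DAG)    = {h_acyclic_np(W_dag):.2e}")
print(f"h(cyclic) = {h_acyclic_np(W_cyc):.4f}")
```

Because h grows smoothly with the weight on the cycle-closing edge, a gradient-based optimiser can push that weight towards zero instead of having to make a discrete decision.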
```python
import jax
from jax import numpy as jnp
from jax.scipy.linalg import expm
import scipy.optimize as sopt
from functools import partial
from jax import jit, value_and_grad


@jit
def h_acyclic(W):
    """
    Calculate the acyclicity constraint value.
    :param W: the (d, d) weight matrix W
    :return: h(W) = tr(exp(W * W)) - d, which is zero iff W is a DAG
    """
    d = W.shape[0]
    E = expm(W * W)  # matrix exponential of the Hadamard product
    return jnp.trace(E) - d


@jit
def l2_loss(X, W):
    """
    Calculate the least-squares loss between data X and the linear model X @ W.
    :param X: the (n, d) data matrix
    :param W: the (d, d) weight matrix
    :return: 0.5 / n * ||X - X @ W||_F^2
    """
    M = X @ W
    R = X - M
    return 0.5 / X.shape[0] * jnp.sum(R ** 2)


@partial(jit, static_argnums=(1,))
def _adj(w, d):
    """
    Convert doubled variables [w_pos, w_neg] to weight matrix W = w_pos - w_neg.
    NOTEARS uses a trick to handle both positive and negative edge weights while
    maintaining non-negativity constraints during optimization. Each edge weight
    W[i,j] is represented as w_pos[i,j] - w_neg[i,j] where both w_pos and w_neg
    are constrained to be non-negative.
    Parameters
    ----------
    w : jax.numpy.ndarray
        Flattened weight vector of length 2*d^2, containing [w_pos.flatten(), w_neg.flatten()]
    d : int
        Number of variables (nodes) in the DAG. Must be static for JIT compilation.
    Returns
    -------
    jax.numpy.ndarray
        Weight matrix W of shape (d, d) where W[i,j] represents the causal effect
        of variable i on variable j (the model is X ≈ X @ W)
    """
    d_squared = d * d
    w_pos = w[:d_squared].reshape((d, d))
    w_neg = w[d_squared:].reshape((d, d))
    return w_pos - w_neg
```
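The doubled-variable trick used by _adj can be checked by hand. Any W splits into a positive part and a negative part, both non-negative, and at that split the sum of all entries equals the L1 norm of W. Any other non-negative split of the same W has a larger sum, so penalising the plain sum of w pushes the optimiser towards this split. That is how a smooth sum can stand in for the non-smooth L1 penalty. Here is a numpy sketch with a made-up 2 × 2 matrix:

```python
import numpy as np

W_demo = np.array([[0.0, -1.2],
                   [0.7, 0.0]])

# Minimal split: positive part and (negated) negative part, both non-negative
w_pos = np.maximum(W_demo, 0)
w_neg = np.maximum(-W_demo, 0)
w = np.concatenate([w_pos.ravel(), w_neg.ravel()])  # same layout _adj expects

d = W_demo.shape[0]
W_back = w[: d * d].reshape(d, d) - w[d * d:].reshape(d, d)

print(np.allclose(W_back, W_demo))                # the split is exact
print(np.isclose(w.sum(), np.abs(W_demo).sum()))  # sum(w) == ||W||_1

# Another non-negative split of the same W has a larger sum
w_other = np.concatenate([(w_pos + 0.3).ravel(), (w_neg + 0.3).ravel()])
print(w_other.sum() > w.sum())
```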
Our Loss Function - Optimising with constraints
Imagine you’re trying to find the lowest point in a mountainous landscape (in this case minimizing the error from a function), but you’re not allowed to step outside a specific region (satisfying constraints). In NOTEARS, we want to:
- Minimize: How poorly our DAG explains the data (loss function)
- Subject to: The graph must be acyclic (no cycles allowed)
The challenge is that the “no cycles” constraint is complex and non-linear, making standard optimization difficult.
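The Augmented Lagrangian method is a clever way to turn a constrained problem into a series of unconstrained problems. This is the intuition behind it:
Step 1: Turn the constraints into penalties
Rather than imposing the hard constraint $h(W) = 0$ directly, we add penalty terms to the objective. The optimiser can then move freely, but it pays a price whenever $h(W) > 0$.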
Step 2: Make the penalty adaptive
With a fixed penalty, the constraint might never be satisfied completely. Imagine you're teaching a dog new tricks while it has to stay in the garden. If the penalty for leaving the garden were just one treat, the dog might decide that leaving is worth it. The loss function has two knobs to control this behaviour:
- Lagrange multiplier $\alpha$: similar to a memory parameter that keeps track of violations. In our example, this is like remembering how many times the dog left the garden
- Penalty parameter $\rho$: like a “strictness knob” that makes violations increasingly expensive. Every time the dog leaves the garden, it costs him more treats
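The loss function then looks something like:
$$\mathcal{L}(W; \alpha, \rho) = \frac{1}{2n}\lVert X - XW \rVert_F^2 + \lambda_1 \lVert W \rVert_1 + \alpha\, h(W) + \frac{\rho}{2}\, h(W)^2$$
Note:
- $\alpha\, h(W)$ (Lagrange multiplier term): accumulates “debt” from past constraint violations
- $\lambda_1 \lVert W \rVert_1$ (L1 regularization): encourages many elements of $W$ to be exactly zero
- $\frac{\rho}{2} h(W)^2$ (quadratic constraint penalty): makes current violations increasingly expensive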
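The same outer loop can be watched on a made-up one-dimensional toy problem whose inner minimisation has a closed form: minimise $f(x) = \tfrac{1}{2}(x - 3)^2$ subject to $h(x) = x = 0$. The sketch mirrors the update rules used later in notears_linear. It multiplies ρ by 10 whenever the violation fails to drop to a quarter of its previous value, then updates α ← α + ρh. The problem itself and the helper name toy_augmented_lagrangian are illustrative only:

```python
import math


def toy_augmented_lagrangian(h_tol=1e-10, rho_max=1e16, max_iter=100):
    """Minimise f(x) = 0.5 * (x - 3)**2 subject to h(x) = x = 0."""
    x, alpha, rho, h = 0.0, 0.0, 1.0, math.inf
    history = []
    for _ in range(max_iter):
        # Inner problem: minimise f(x) + alpha * x + 0.5 * rho * x**2.
        # Setting the derivative (x - 3) + alpha + rho * x to zero gives x_new.
        while True:
            x_new = (3.0 - alpha) / (1.0 + rho)
            if abs(x_new) > 0.25 * h and rho < rho_max:
                rho *= 10.0  # not enough progress: make violations pricier
            else:
                break
        x, h = x_new, abs(x_new)
        alpha += rho * x  # dual ascent: the multiplier remembers the violation
        history.append((x, alpha, rho))
        if h <= h_tol or rho >= rho_max:
            break
    return x, alpha, rho, history


x_opt, alpha_opt, rho_opt, history = toy_augmented_lagrangian()
for k, (x_k, alpha_k, rho_k) in enumerate(history):
    print(f"iter {k}: x={x_k:.3e}, alpha={alpha_k:.6f}, rho={rho_k:.0e}")
```

x converges to the feasible point 0, and α converges to 3, the exact Lagrange multiplier: at x = 0, f'(0) + α = −3 + α = 0. Because α absorbs the accumulated “debt”, ρ only needs to reach 10 in this toy problem.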
The following visualisation shows how the adaptive penalty drives the constraint violation towards zero as the iterations proceed.
```python
from utils import create_augmented_lagrangian_visualization

create_augmented_lagrangian_visualization()
```
```python
@partial(jit, static_argnums=(2,))
def augmented_lagrangian(w, X, d, alpha, rho, lambda1, loss_fun=l2_loss):
    """
    Compute the augmented Lagrangian objective. This transforms the constrained NOTEARS problem into an unconstrained
    optimization by incorporating the acyclicity constraint as adaptive penalties.
    :param w: the flattened doubled weight vector [w_pos, w_neg] of length 2*d^2
    :param X: the input data matrix
    :param d: the number of variables
    :param alpha: the Lagrange multiplier
    :param rho: the penalty parameter
    :param lambda1: the L1 regularization strength
    :param loss_fun: the data-fit loss (default: l2_loss)
    :return: data_loss + quadratic_penalty + linear_penalty + sparsity_penalty
    """
    W = _adj(w, d)
    loss = loss_fun(X, W)
    h = h_acyclic(W)
    # Since w >= 0, sum(w) = sum(w_pos + w_neg) acts as a smooth L1 penalty on W
    return loss + 0.5 * rho * h * h + alpha * h + lambda1 * jnp.sum(w)


# Create a compiled function returning (objective, gradient w.r.t. w)
aug_lagrangian_with_grad = value_and_grad(augmented_lagrangian, argnums=0)
aug_lagrangian_with_grad = partial(jit, static_argnums=(2,))(aug_lagrangian_with_grad)
```
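value_and_grad derives the gradient automatically, but for the acyclicity term there is also a simple closed form. By the chain rule, with $\nabla_A \operatorname{tr}(e^A) = (e^A)^\top$, the gradient is $\nabla h(W) = (e^{W \circ W})^\top \circ 2W$. The numpy sketch below checks that formula against central finite differences. It needs its own matrix exponential, so expm_taylor is a hand-rolled scaling-and-squaring Taylor approximation written only for this illustration. In practice scipy.linalg.expm or jax.scipy.linalg.expm would be the natural choice:

```python
import numpy as np


def expm_taylor(A, terms=20):
    """Matrix exponential via scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(A, 1)
    # Scale A so its 1-norm is at most 0.5, where the series converges quickly
    s = int(np.ceil(np.log2(norm / 0.5))) if norm > 0.5 else 0
    A_scaled = A / 2.0 ** s
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, terms):
        term = term @ A_scaled / k
        result = result + term
    # Undo the scaling: exp(A) = exp(A / 2^s)^(2^s)
    for _ in range(s):
        result = result @ result
    return result


def h_np(W):
    return np.trace(expm_taylor(W * W)) - W.shape[0]


def grad_h_np(W):
    # d/dA tr(exp(A)) = exp(A)^T, and d(W * W)/dW = 2W element-wise
    return expm_taylor(W * W).T * 2 * W


rng = np.random.default_rng(0)
W_rand = rng.normal(scale=0.5, size=(4, 4))  # dense, so it contains cycles

# Central finite differences, one entry at a time
eps = 1e-6
numeric = np.zeros_like(W_rand)
for i in range(4):
    for j in range(4):
        E = np.zeros_like(W_rand)
        E[i, j] = eps
        numeric[i, j] = (h_np(W_rand + E) - h_np(W_rand - E)) / (2 * eps)

print(np.max(np.abs(numeric - grad_h_np(W_rand))))  # close to zero
```

One consequence of the formula: for a DAG, $(e^{W \circ W})^\top$ is zero wherever W has an edge, because a walk from j back to i would close a cycle. So $\nabla h(W)$ vanishes at every feasible point.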
```python
def create_dag_bounds(d):
    """
    Create bounds for the optimization variables in the doubled-variable formulation.
    In NOTEARS, each weight W[i,j] is represented as w_pos[i,j] - w_neg[i,j] where
    both w_pos and w_neg are non-negative. This function creates bounds for the
    flattened vector [w_pos.flatten(), w_neg.flatten()].
    Parameters
    ----------
    d : int
        Number of variables (nodes) in the DAG
    Returns
    -------
    list of tuple
        List of (min, max) bounds for each optimization variable.
        Length is 2*d^2. Diagonal elements are bounded to (0,0) to prevent
        self-loops, off-diagonal elements are bounded to (0, None).
    """
    bounds = []
    for k in range(2):  # For w_pos and w_neg
        for i in range(d):
            for j in range(d):
                if i == j:
                    bounds.append((0, 0))  # No self-loops
                else:
                    bounds.append((0, None))  # Non-negative weights
    return bounds
```
```python
def create_scipy_wrapper(w, X, d, alpha, rho, lambda1):
    """
    Create a wrapper function for scipy optimizer that converts between JAX and NumPy.
    SciPy's optimizers expect NumPy arrays and return NumPy arrays, while our
    JAX implementation uses JAX arrays. This wrapper handles the conversion.
    Parameters
    ----------
    w : jax.numpy.ndarray
        Current weight vector (used for function signature, not actual computation)
    X : jax.numpy.ndarray
        Data matrix of shape (n_samples, n_variables)
    d : int
        Number of variables (nodes) in the DAG
    alpha : float
        Lagrange multiplier for the acyclicity constraint
    rho : float
        Penalty parameter for the augmented Lagrangian
    lambda1 : float
        L1 regularization strength
    Returns
    -------
    callable
        Function that takes NumPy array and returns (objective_value, gradient)
        suitable for SciPy's minimize function with jac=True
    """
    def scipy_obj_and_grad(w_np):
        w_jax = jnp.array(w_np)
        obj, grad = aug_lagrangian_with_grad(w_jax, X, d, alpha, rho, lambda1)
        # L-BFGS-B works in float64, while JAX defaults to float32
        return float(obj), np.asarray(grad, dtype=np.float64)
    return scipy_obj_and_grad
```
```python
def optimize_weights(w_est, X, d, alpha, rho, lambda1, bounds):
    """
    Solve one augmented Lagrangian subproblem with L-BFGS-B.
    For fixed alpha and rho, this function wraps SciPy's L-BFGS-B optimizer to
    minimise the augmented Lagrangian over the doubled weight vector; the box
    bounds enforce non-negativity and a zero diagonal.
    Parameters
    ----------
    w_est : jax.numpy.ndarray
        Current estimate of the weight vector (length 2*d^2)
    X : jax.numpy.ndarray
        Data matrix of shape (n_samples, n_variables)
    d : int
        Number of variables (nodes) in the DAG
    alpha : float
        Lagrange multiplier for the acyclicity constraint
    rho : float
        Penalty parameter for the augmented Lagrangian
    lambda1 : float
        L1 regularization strength
    bounds : list of tuple
        Bounds for each optimization variable from create_dag_bounds()
    Returns
    -------
    jax.numpy.ndarray
        Optimized weight vector of length 2*d^2
    Notes
    -----
    Uses L-BFGS-B, a quasi-Newton method that handles box constraints
    efficiently. It is well-suited to the smooth subproblems arising in
    the augmented Lagrangian formulation.
    """
    w_np = np.array(w_est)
    # Create the objective function wrapper
    obj_func = create_scipy_wrapper(w_est, X, d, alpha, rho, lambda1)
    # Run L-BFGS-B optimization, warm-started from the current estimate
    result = sopt.minimize(
        obj_func, w_np, method='L-BFGS-B',
        jac=True, bounds=bounds
    )
    # Return JAX array
    return jnp.array(result.x)
```
```python
def threshold_weights(W, threshold):
    """
    Apply thresholding to remove small edges from the estimated DAG.
    Sets edge weights with absolute value below the threshold to exactly zero.
    This is a common post-processing step in causal discovery to produce
    sparse graphs and remove weak, potentially spurious connections.
    Parameters
    ----------
    W : jax.numpy.ndarray
        Weight matrix of shape (d, d) representing the DAG
    threshold : float
        Minimum absolute weight value to retain. Edges with |weight| < threshold
        are set to zero
    Returns
    -------
    jax.numpy.ndarray
        Thresholded weight matrix with the same shape as W
    """
    return jnp.where(jnp.abs(W) < threshold, 0, W)
```
```python
@timing
def notears_linear(X, lambda1=0.1, max_iter=100, h_tol=1e-8, rho_max=1e+16, w_threshold=0.3):
    """
    Learn a DAG from data using the NOTEARS algorithm.
    Args:
        X (ndarray): [n, d] data matrix
        lambda1 (float): L1 regularization parameter
        max_iter (int): Maximum number of dual ascent steps
        h_tol (float): Exit if |h(W)| <= h_tol
        rho_max (float): Exit if rho >= rho_max
        w_threshold (float): Remove edges with |weight| < threshold
    Returns:
        W_est (ndarray): [d, d] estimated DAG, where W_est[i, j] != 0 means j → i
            (the same convention as W_true)
    """
    # Setup
    n, d = X.shape
    X = X - jnp.mean(X, axis=0)  # Center the data
    # Initialize parameters
    w_est = jnp.zeros(2 * d * d)  # [w_pos, w_neg]
    rho, alpha, h = 1.0, 0.0, np.inf
    # Create bounds for optimization
    bounds = create_dag_bounds(d)
    # Augmented Lagrangian optimization
    for i in range(max_iter):
        # Inner optimization: increase rho until h shrinks enough
        w_new = None
        inner_converged = False
        while rho < rho_max and not inner_converged:
            # Optimize weights for the current alpha and rho
            w_new = optimize_weights(w_est, X, d, alpha, rho, lambda1, bounds)
            # Check acyclicity
            W_new = _adj(w_new, d)
            h_new = float(h_acyclic(W_new))
            # Require h to drop to at most a quarter of its previous value
            if h_new > 0.25 * h:
                rho *= 10
            else:
                inner_converged = True
        # Update estimates
        w_est = w_new
        h = float(h_acyclic(_adj(w_est, d)))
        # Dual ascent step on the Lagrange multiplier
        alpha += rho * h
        # Log progress
        print(f"Iteration {i}: h={h:.6e}, rho={rho:.2e}")
        # Check convergence
        if h <= h_tol or rho >= rho_max:
            break
    # Final processing
    W_est = _adj(w_est, d)
    # Thresholding
    W_est = threshold_weights(W_est, w_threshold)
    print(f"Final h(W): {float(h_acyclic(W_est)):.6e}")
    # The model X ≈ X @ W means W_est[i, j] != 0 encodes i → j; transpose so that
    # W_est[i, j] != 0 means j → i, matching the convention of W_true
    W_est = W_est.T
    return W_est
```
```python
lambda1 = 0.05
loss_type = 'l2'  # informational only: notears_linear implements the L2 loss
max_iter = 100
h_tol = 1e-8
rho_max = 1e+16
w_threshold = 0.3
estimated_A, time_taken_jax = notears_linear(X, lambda1, max_iter, h_tol, rho_max, w_threshold)
plot_dag_comparison(W_true, estimated_A)
evaluate_reconstruction(W_true, estimated_A)
```