Matthew Drago

Discovering Causal Relationships through Gradient Descent

Jupyter Notebook

Learning DAGs from Data

causal-relationships machine-learning

Imagine you’ve collected data on student study hours, sleep patterns, and exam scores. You notice strong correlations between these variables. However, we know correlation doesn’t imply causation. Does more studying cause better exam performance? Does sleep quality affect study efficiency? Or do other factors drive these relationships?

Unlike standard machine learning that focuses on prediction, causal inference seeks to understand the underlying mechanisms: what causes what.

Traditionally, establishing causality required controlled experiments where we manipulate variables and observe outcomes. But what if experiments are impossible, unethical, or too expensive? Can we infer causality from observational data alone?

Approaches range from constraint-based methods (like the PC algorithm), which test conditional independence, to score-based methods, which search for graph structures that optimize certain criteria. Both approaches face a fundamental hurdle: learning a Directed Acyclic Graph (DAG) is a combinatorial optimization problem with a super-exponential search space, i.e. there are $2^{n(n-1)}$ possible directed edge configurations for $n$ variables, and the number of DAGs among them still grows super-exponentially with $n$.
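To get a sense of how quickly this space blows up, here is a small illustration of mine (not from the paper) that counts labelled DAGs using Robinson's recurrence:

from math import comb

def count_dags(n):
    """Number of labelled DAGs on n nodes, via Robinson's recurrence."""
    a = [1]  # a[0] = 1 (the empty graph)
    for m in range(1, n + 1):
        a.append(sum((-1) ** (k - 1) * comb(m, k) * 2 ** (k * (m - k)) * a[m - k]
                     for k in range(1, m + 1)))
    return a[n]

for n in range(1, 6):
    print(n, count_dags(n))  # 1, 3, 25, 543, 29281 — already huge for a handful of variables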

The Challenge

Learning DAGs from Data

At its core, the problem is to recover a Directed Acyclic Graph (DAG) from data produced by some data-generating process. Each node represents a variable, and each directed edge represents a relationship from cause to effect. Edges are weighted to represent the strength of the relationship.
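As a toy illustration (the variable names are just the study/sleep/exam example from the introduction, not real data), a weighted adjacency matrix and its graph might look like this:

import numpy as np
import networkx as nx

# Rows/columns: 0 = sleep quality, 1 = study hours, 2 = exam score (hypothetical)
# Here W[i, j] != 0 means i -> j with the given strength
W = np.array([
    [0.0, 0.4, 0.5],   # sleep -> study hours (0.4), sleep -> exam score (0.5)
    [0.0, 0.0, 0.9],   # study hours -> exam score (0.9)
    [0.0, 0.0, 0.0],   # exam score causes nothing
])
G = nx.from_numpy_array(W, create_using=nx.DiGraph)
print(nx.is_directed_acyclic_graph(G))  # True: this graph has no cycles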

Traditional Approaches and Their Limitations

Previous methods for causal discovery generally fall into two categories:

  1. Constraint-Based Methods (like the PC and FCI algorithms):
  • Test conditional independence between variables
  • Build a graph consistent with these independence relationships
  • Advantages: Require fewer assumptions about the data distribution
  • Limitations: Sensitive to errors in independence tests, especially with limited data
  2. Score-Based Methods (like GES and various Bayesian approaches):
  • Define a score function that measures how well a graph explains the data
  • Search through the space of possible graphs to maximize this score
  • Advantages: Can incorporate prior knowledge and handle uncertainty
  • Limitations: The search space grows super-exponentially with the number of variables

Both of these classes of approaches, as discussed in the paper, struggle with the same thing: enforcing DAG acyclicity. The space of all possible directed graphs is enormous, but only a fraction of them are acyclic.
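A quick brute-force check (a sketch of mine, feasible only for very small d since the number of graphs doubles with every possible edge) makes this concrete:

import itertools
import networkx as nx

def acyclic_fraction(d):
    """Fraction of all directed graphs on d nodes (no self-loops) that are acyclic."""
    pairs = [(i, j) for i in range(d) for j in range(d) if i != j]
    n_graphs, n_dags = 0, 0
    for mask in itertools.product([0, 1], repeat=len(pairs)):
        G = nx.DiGraph()
        G.add_nodes_from(range(d))
        G.add_edges_from(edge for edge, keep in zip(pairs, mask) if keep)
        n_graphs += 1
        n_dags += nx.is_directed_acyclic_graph(G)
    return n_dags / n_graphs

print(acyclic_fraction(3))  # 25/64   ≈ 0.39
print(acyclic_fraction(4))  # 543/4096 ≈ 0.13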

import numpy as np
import networkx as nx

from utils import timing


def generate_random_dag(d, s0):
    """Generate a random DAG with d nodes and s0 expected edges."""
    # Fill the strictly upper triangular part with random weights (guarantees acyclicity)
    A = np.zeros((d, d))
    for i in range(d):
        for j in range(i):
            if np.random.rand() < s0/(d-1):
                A[j, i] = np.random.uniform(0.5, 2.0) * np.random.choice([-1, 1])

    # Randomly permute the nodes to get a random DAG
    P = np.random.permutation(np.eye(d))
    A = P.T @ A @ P

    return A

def generate_sem_data(W, n, noise_scale=1):
    """Generate data from a linear SEM X = XW + Z."""
    d = W.shape[0]
    I = np.eye(d)
    Z = np.random.normal(0, noise_scale, size=(n, d))

    # For NOTEARS, we need W where W[i,j] means j→i
    W_true = W.T

    # Solve X = XW + Z in matrix form: X = Z (I - W)^(-1)
    X = Z @ np.linalg.inv(I - W_true.T)

    return X, W_true
np.random.seed(42)
nodes, edges = 5, 2
samples = 1000
A = generate_random_dag(nodes, edges)
X, W_true = generate_sem_data(A, samples)
A
from utils import visualize_dag

visualize_dag(W_true)

Output

The following DAG represents causal relationships: the nodes are the variables we’re exploring, and each arrow points from the variable that directly causes to the variable it affects. The edge weights represent the strength of the causal effect: a positive weight means an increase in the cause increases the effect, while a negative weight means an increase in the cause decreases the effect.

Relationships in this toy example:

  • X3 has a negative effect (-1.93) on X1
  • X3 has a positive effect (1.95) on X0
  • X0 influences X2 with a negative effect (-1.99)
  • X1 has a small negative effect (-0.65) on X2
  • X4 appears isolated (no edges connecting to it)

In layman’s terms, this could mean:

  • X3 could be considered a “common cause” of X0 and X1 (and, indirectly, of X2)
  • X4 is independent of the other variables (it either has no relationship to them, or its relationships were not captured in this system)

This scenario is ideal for the algorithm because:

  • It has no cycles
  • Its weighted edges represent the strength of the causal effects
  • It contains a mix of relationships (some variables with multiple parents/children, some isolated)

Benchmark

Before diving into the implementation, it’s valuable to understand how NOTEARS compares to traditional causal discovery approaches. This comparison helps illustrate why gradient-based methods represent such a significant advancement in the field.

from causallearn.search.ScoreBased.GES import ges
from causallearn.search.ConstraintBased.PC import pc

from utils import plot_dag_comparison

@timing
def run_ges(X):
    cg = ges(X, score_func="local_score_BIC")
    G = cg['G'].graph
    # Convert to adjacency matrix
    return nx.to_numpy_array(nx.DiGraph(G))

@timing
def run_pc(X):
    # PC is constraint-based: it takes a significance level for its
    # conditional-independence tests rather than a score function
    record = pc(X, alpha=0.05)
    G = record.G.graph
    # Convert to adjacency matrix
    return nx.to_numpy_array(nx.DiGraph(G))

W_est_ges, time_taken_ges = run_ges(X)
W_est_pc, time_taken_pc = run_pc(X)
plot_dag_comparison(W_true, W_est_ges)

Output

plot_dag_comparison(W_true, W_est_pc)

Output

What PC Got Right

  1. Correct Edge Detection: PC successfully identified several true causal relationships:
  • The connection between nodes 3 and 0 (though the direction is uncertain)
  • The relationship involving node 2 as a hub/sink
  • Some of the connections to node 1
  2. Sparsity: PC maintained a reasonable level of sparsity, avoiding the creation of an overly dense graph

Limitations

  1. Missing Edges
  • The algorithm missed the edge from 3 → 1 (weight of -1.93)
  2. Edge Direction
  • This is a fundamental limitation of methods like PC and GES
  • PC can only orient edges when there are clear v-structures (X→Z←Y) or other identifiable patterns
  • Many edges remain unoriented or may be incorrectly oriented due to insufficient statistical power
  3. Strength of Causal Relationship
  • PC performs conditional independence tests that return binary results (independent/dependent)
  • It builds the graph by deciding whether edges exist or not, without estimating effect strengths

Why these differences?

These differences happen for a number of reasons:

  1. Statistical Testing Limitations:
  • PC relies on conditional independence tests, which can fail with:
    • Limited sample sizes
    • Weak relationships (small effect sizes)
    • Non-linear relationships (even mild ones)
  2. Multiple Testing:
  • PC performs many statistical tests, increasing the chance of false positives (an example of such a test is sketched below)
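To make “conditional independence test” concrete, here is a minimal sketch of the kind of test PC relies on: a partial-correlation (Fisher z) test, assuming roughly Gaussian data. This is my own simplified version, not the exact implementation causal-learn uses.

import numpy as np
from scipy import stats

def fisher_z_test(X, i, j, cond_set=()):
    """p-value for the hypothesis X_i ⊥ X_j | X_cond_set (small p suggests dependence)."""
    n = X.shape[0]
    cols = [i, j, *cond_set]
    corr = np.corrcoef(X[:, cols], rowvar=False)
    prec = np.linalg.inv(corr)                               # precision (inverse correlation) matrix
    r = -prec[0, 1] / np.sqrt(prec[0, 0] * prec[1, 1])       # partial correlation of X_i, X_j given the rest
    z = 0.5 * np.log((1 + r) / (1 - r))                      # Fisher's z transform
    stat = np.sqrt(n - len(cond_set) - 3) * abs(z)
    return 2 * (1 - stats.norm.cdf(stat))                    # two-sided p-value

# e.g. test whether X0 and X2 are independent given X1 and X3, on the simulated data
print(fisher_z_test(np.asarray(X), 0, 2, cond_set=(1, 3)))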

Before explaining and implementing the NOTEARS algorithm, let’s run a few scenarios to highlight the problems with the PC algorithm further. I’m going to import the evaluate_reconstruction function from my utils module, which computes several useful metrics and plots.

Performance Metrics

From the confusion matrix we can get

  • TP - True Positive: Correctly identified edges that actually exist
  • FP - False Positive: Incorrectly identified edges that don’t actually exist
  • FN - False Negative: Missed edges that actually exist
  • TN - True Negative: Correctly identified absence of edges

From these results we can calculate the:

  • Precision: the fraction of predicted edges that are actually present in the true graph
  • Recall: the fraction of true edges that were recovered
  • F1 score: the harmonic mean of precision and recall

We can also compute the Structural Hamming Distance (SHD), a standard distance for comparing graphs via their adjacency matrices.
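As a rough sketch of what such an evaluation does (evaluate_reconstruction in utils is the version actually used here, and full SHD implementations also treat reversed edges specially), the metrics can be computed along these lines:

import numpy as np

def edge_metrics(W_true, W_est):
    """Precision, recall, F1 and a simplified SHD from two weighted adjacency matrices."""
    A_true = (np.asarray(W_true) != 0).astype(int)
    A_est = (np.asarray(W_est) != 0).astype(int)
    off = ~np.eye(A_true.shape[0], dtype=bool)        # ignore the diagonal (no self-loops)
    tp = int(np.sum((A_est == 1) & (A_true == 1) & off))
    fp = int(np.sum((A_est == 1) & (A_true == 0) & off))
    fn = int(np.sum((A_est == 0) & (A_true == 1) & off))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    shd = int(np.sum((A_true != A_est) & off))         # simplified SHD: mismatched entries
    return {"precision": precision, "recall": recall, "f1": f1, "shd": shd}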

Performance Plots

Left Plot: True DAG Structure (Binary Adjacency Matrix)
  • Dark blue squares: Indicate true causal relationships (edges exist)
  • White squares: No causal relationship exists
  • Reading: If there’s a dark square at position (i,j), it means variable j causes variable i
Middle Plot: Estimated DAG Structure
  • Shows what the PC algorithm discovered
  • Dark blue squares: Edges the PC algorithm detected
  • White squares: No relationship detected by PC
Right Plot: Edge Recovery Analysis

This is a confusion matrix showing the algorithm’s performance:

  • Blue (TP - True Positive)
  • Green (FP - False Positive)
  • Red (FN - False Negative)
  • White (TN - True Negative)
from utils import evaluate_reconstruction

def analyze_sample_size_effects():
    sample_sizes = [100, 500, 1000, 2000]
    results = []

    for n in sample_sizes:
        X_subset = X[:n, :]  # Use first n samples
        W_pc, _ = run_pc(X_subset)
        metrics = evaluate_reconstruction(W_true, W_pc)
        results.append({
            'n_samples': n,
            'f1': metrics['f1']
        })

    return results
analyze_sample_size_effects()

Output

Output

Output

Output

The PC algorithm fared moderately well on this graph structure across varying numbers of data points; the F1 scores ranged between 0.5 and 0.6. However, it missed important edges: the failure to detect 3→1 could be crucial in real applications. There are also a number of false discoveries; the green squares show PC adding edges that don’t exist, which could be misleading.

NOTEARS

NOTEARS (introduced in “DAGs with NO TEARS: Continuous Optimization for Structure Learning”, Zheng et al., 2018) provides a remarkably elegant solution to reduce the search space: reformulate the acyclicity constraint as a differentiable function of the weight matrix W:

$$h(W) = \operatorname{tr}\left(e^{W \circ W}\right) - d = 0$$

where:

  • tr(·) is the trace operator (sum of diagonal elements)
  • $e^{W \circ W}$ is the matrix exponential of the Hadamard (element-wise) product $W \circ W$
  • d is the number of variables
import jax
from jax import numpy as jnp
import scipy.optimize as sopt
from functools import partial
from jax import jit, value_and_grad


@jit
def h_acyclic(W):
    """
    Calculate the acyclicity constraint value.
    :param W: the [d, d] weight matrix
    :return: tr(exp(W ∘ W)) - d, which is zero exactly when W encodes a DAG
    """
    d = W.shape[0]
    E = jax.scipy.linalg.expm(W * W)
    return jnp.trace(E) - d

@jit
def l2_loss(X, W):
    """
    Calculate L2 loss between data X and model X@W.
    :param X: the input matrix
    :param W: the weight matrix
    :return: the least-squares reconstruction loss between X and X @ W
    """
    M = X @ W
    R = X - M
    return 0.5 / X.shape[0] * jnp.sum(R ** 2)

@partial(jit, static_argnums=(1,))
def _adj(w, d):
    """
    Convert doubled variables [w_pos, w_neg] to weight matrix W = w_pos - w_neg.

    NOTEARS uses a trick to handle both positive and negative edge weights while
    maintaining non-negativity constraints during optimization. Each edge weight
    W[i,j] is represented as w_pos[i,j] - w_neg[i,j] where both w_pos and w_neg
    are constrained to be non-negative.

    Parameters
    ----------
    w : jax.numpy.ndarray
        Flattened weight vector of length 2*d^2, containing [w_pos.flatten(), w_neg.flatten()]
    d : int
        Number of variables (nodes) in the DAG. Must be static for JIT compilation.

    Returns
    -------
    jax.numpy.ndarray
        Weight matrix W of shape (d, d) where W[i,j] represents the causal effect
        of variable j on variable i
    """
    d_squared = d * d
    w_pos = w[:d_squared].reshape((d, d))
    w_neg = w[d_squared:].reshape((d, d))
    return w_pos - w_neg

Our Loss Function - Optimising with constraints

Imagine you’re trying to find the lowest point in a mountainous landscape (in this case minimizing the error from a function), but you’re not allowed to step outside a specific region (satisfying constraints). In NOTEARS, we want to:

  • Minimize: How poorly our DAG explains the data (loss function)
  • Subject to: The graph must be acyclic (no cycles allowed)

The challenge is that the “no cycles” constraint is complex and non-linear, making standard optimization difficult.
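Written out for the linear case (matching the least-squares loss l2_loss defined above), the constrained problem NOTEARS solves is:

$$\min_{W \in \mathbb{R}^{d \times d}} \; \frac{1}{2n}\,\lVert X - XW \rVert_F^2 \;+\; \lambda_1 \lVert W \rVert_1 \quad \text{subject to} \quad h(W) = \operatorname{tr}\!\left(e^{W \circ W}\right) - d = 0$$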

The Augmented Lagrangian method is a clever way to turn a constrained problem into a series of unconstrained problems. This is the intuition behind it:

Step 1: Turn the constraint into a penalty

Rather than enforcing the hard constraint $h(W) = 0$ directly, we move it into the objective as a penalty term:

$$\min_W \; F(W) + \frac{\rho}{2}\, h(W)^2$$

where $F(W)$ is the data-fitting loss and $\rho$ controls how expensive a violation of the constraint is.

Step 2: Make the penalty adaptive

With a fixed penalty, the constraint might never be satisfied completely. Imagine you’re teaching a dog to stay in the garden whilst teaching him new tricks. If the penalty for leaving the garden were losing one treat, the dog might decide it’s worth leaving the garden for one treat. The loss function has two mechanisms to control this behaviour:

  • Lagrange multiplier $\alpha$: similar to a memory parameter that keeps track of violations
    • In our example, this is like remembering how many times the dog left the garden
  • Penalty parameter $\rho$: like a “strictness knob” that makes violations increasingly expensive
    • Every time the dog leaves the garden, it costs him more treats

The loss function will now look something like:

$$L(W, \alpha, \rho) = F(W) + \lambda_1 \lVert W \rVert_1 + \alpha\, h(W) + \frac{\rho}{2}\, h(W)^2$$

Note:

  • $\alpha\, h(W)$ (Lagrange multiplier term): accumulates “debt” from past constraint violations
  • $\lambda_1 \lVert W \rVert_1$ (L1 regularization): encourages many elements of W to be exactly zero
  • $\frac{\rho}{2}\, h(W)^2$ (quadratic constraint penalty): makes current violations increasingly expensive

The following visualisation shows how the adaptive penalty drives the constraint toward satisfaction as the iterations progress.

from utils import create_augmented_lagrangian_visualization

create_augmented_lagrangian_visualization()

Output

@partial(jit, static_argnums=(2,))
def augmented_lagrangian(w, X, d, alpha, rho, lambda1, loss_fun=l2_loss):
    """
    Compute the augmented Lagrangian objective. This transforms the constrained NOTEARS problem into an unconstrained
    optimization by incorporating the acyclicity constraint as adaptive penalties.
    :param w: the doubled weight vector [w_pos, w_neg] of length 2*d^2
    :param X: the input data matrix
    :param d: the number of variables
    :param alpha: the Lagrange multiplier
    :param rho: the penalty parameter
    :param lambda1: the L1 regularization strength
    :return: data_loss + quadratic_penalty + linear_penalty + sparsity_penalty
    """
    W = _adj(w, d)
    loss = loss_fun(X, W)
    h = h_acyclic(W)
    return loss + 0.5 * rho * h * h + alpha * h + lambda1 * jnp.sum(w)

# Create compiled value_and_grad function
aug_lagrangian_with_grad = value_and_grad(augmented_lagrangian, argnums=0)
aug_lagrangian_with_grad = partial(jit, static_argnums=(2,))(aug_lagrangian_with_grad)
def create_dag_bounds(d):
    """
    Create bounds for the optimization variables in the doubled-variable formulation.

    In NOTEARS, each weight W[i,j] is represented as w_pos[i,j] - w_neg[i,j] where
    both w_pos and w_neg are non-negative. This function creates bounds for the
    flattened vector [w_pos.flatten(), w_neg.flatten()].

    Parameters
    ----------
    d : int
        Number of variables (nodes) in the DAG

    Returns
    -------
    list of tuple
        List of (min, max) bounds for each optimization variable.
        Length is 2*d^2. Diagonal elements are bounded to (0,0) to prevent
        self-loops, off-diagonal elements are bounded to (0, None).
    """
    bounds = []
    for k in range(2):  # For w_pos and w_neg
        for i in range(d):
            for j in range(d):
                if i == j:
                    bounds.append((0, 0))  # No self-loops
                else:
                    bounds.append((0, None))  # Non-negative weights
    return bounds

def create_scipy_wrapper(w, X, d, alpha, rho, lambda1):
    """
    Create a wrapper function for scipy optimizer that converts between JAX and NumPy.

    SciPy's optimizers expect NumPy arrays and return NumPy arrays, while our
    JAX implementation uses JAX arrays. This wrapper handles the conversion.

    Parameters
    ----------
    w : jax.numpy.ndarray
        Current weight vector (used for function signature, not actual computation)
    X : jax.numpy.ndarray
        Data matrix of shape (n_samples, n_variables)
    d : int
        Number of variables (nodes) in the DAG
    alpha : float
        Lagrange multiplier for the acyclicity constraint
    rho : float
        Penalty parameter for the augmented Lagrangian
    lambda1 : float
        L1 regularization strength

    Returns
    -------
    callable
        Function that takes NumPy array and returns (objective_value, gradient)
        suitable for SciPy's minimize function with jac=True
    """
    def scipy_obj_and_grad(w_np):
        w_jax = jnp.array(w_np)
        obj, grad = aug_lagrangian_with_grad(w_jax, X, d, alpha, rho, lambda1)
        return float(obj), np.array(grad)
    return scipy_obj_and_grad

def optimize_weights(w_est, X, d, alpha, rho, lambda1, bounds):
    """
    Perform one step of L-BFGS-B optimization for the augmented Lagrangian.

    This function wraps SciPy's L-BFGS-B optimizer to solve the constrained
    optimization problem in NOTEARS using the augmented Lagrangian method.

    Parameters
    ----------
    w_est : jax.numpy.ndarray
        Current estimate of the weight vector (length 2*d^2)
    X : jax.numpy.ndarray
        Data matrix of shape (n_samples, n_variables)
    d : int
        Number of variables (nodes) in the DAG
    alpha : float
        Lagrange multiplier for the acyclicity constraint
    rho : float
        Penalty parameter for the augmented Lagrangian
    lambda1 : float
        L1 regularization strength
    bounds : list of tuple
        Bounds for each optimization variable from create_dag_bounds()

    Returns
    -------
    jax.numpy.ndarray
        Optimized weight vector of length 2*d^2

    Notes
    -----
    Uses L-BFGS-B which is a quasi-Newton method that handles box constraints
    efficiently. The method is well-suited for the smooth optimization problems
    arising in the augmented Lagrangian formulation.
    """

    w_np = np.array(w_est)

    # Create the objective function wrapper
    obj_func = create_scipy_wrapper(w_est, X, d, alpha, rho, lambda1)

    # Run L-BFGS-B optimization
    result = sopt.minimize(
        obj_func, w_np, method='L-BFGS-B',
        jac=True, bounds=bounds
    )

    # Return JAX array
    return jnp.array(result.x)

def threshold_weights(W, threshold):
    """
    Apply thresholding to remove small edges from the estimated DAG.

    Sets edge weights with absolute value below the threshold to exactly zero.
    This is a common post-processing step in causal discovery to produce
    sparse graphs and remove weak, potentially spurious connections.

    Parameters
    ----------
    W : jax.numpy.ndarray
        Weight matrix of shape (d, d) representing the DAG
    threshold : float
        Minimum absolute weight value to retain. Edges with |weight| < threshold
        are set to zero

    Returns
    -------
    jax.numpy.ndarray
        Thresholded weight matrix with the same shape as W

    """
    return jnp.where(jnp.abs(W) < threshold, 0, W)
@timing
def notears_linear(X, lambda1=0.1, max_iter=100, h_tol=1e-8, rho_max=1e+16, w_threshold=0.3):
    """
    Learn a DAG from data using the NOTEARS algorithm.

    Args:
        X (ndarray): [n, d] data matrix
        lambda1 (float): L1 regularization parameter
        max_iter (int): Maximum number of dual ascent steps
        h_tol (float): Exit if |h(W)| <= h_tol
        rho_max (float): Exit if rho >= rho_max
        w_threshold (float): Remove edges with |weight| < threshold

    Returns:
        W_est (ndarray): [d, d] estimated DAG
    """
    # Setup
    n, d = X.shape
    X = X - jnp.mean(X, axis=0)  # Center the data

    # Initialize parameters
    w_est = jnp.zeros(2 * d * d)  # [w_pos, w_neg]
    rho, alpha, h = 1.0, 0.0, jnp.inf

    # Create bounds for optimization
    bounds = create_dag_bounds(d)

    # Augmented Lagrangian optimization
    for i in range(max_iter):
        # Inner optimization
        w_new = None
        inner_converged = False

        while rho < rho_max and not inner_converged:
            # Optimize weights
            w_new = optimize_weights(w_est, X, d, alpha, rho, lambda1, bounds)

            # Check acyclicity
            W_new = _adj(w_new, d)
            h_new = h_acyclic(W_new)

            # Update penalty parameter if needed
            if h_new > 0.25 * h:
                rho *= 10
            else:
                inner_converged = True

        # Update estimates
        w_est = w_new
        h = h_acyclic(_adj(w_est, d))
        alpha += rho * h

        # Log progress
        print(f"Iteration {i}: h={h:.6e}, rho={rho:.2e}")

        # Check convergence
        if h <= h_tol or rho >= rho_max:
            break

    # Final processing
    W_est = _adj(w_est, d)

    # Thresholding
    W_est = jnp.where(jnp.abs(W_est) < w_threshold, 0, W_est)

    print(f"Final h(W): {h_acyclic(W_est):.6e}")

    # Transpose to get standard adjacency matrix where W[i,j]≠0 means i→j
    W_est = W_est.T

    return W_est


lambda1=0.05
max_iter=100
h_tol=1e-8
rho_max=1e+16
w_threshold=0.3
estimated_A, time_taken_jax = notears_linear(X, lambda1, max_iter, h_tol, rho_max, w_threshold)
plot_dag_comparison(W_true, estimated_A)

Output

evaluate_reconstruction(W_true, estimated_A)

Output

Interactive Version

Run on Binder

Binder allows you to run this notebook in an interactive environment.