Demo 2: Neural Network Interpretability

We created a very small neural network with a GELU-GELU-Linear architecture (49 → 3 → 3 → 10) for MNIST digit classification, with a Softmax post-processing step to obtain class probabilities. The network is small enough that we can build its full polytope representation explicitly and analyze it with linear programming.
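
A minimal sketch of such a model in PyTorch, assuming only the layer sizes above (the class name TinyMLP and all other details are ours, not the demo's code):

```python
import torch
import torch.nn as nn

# Minimal sketch of the 49 -> 3 -> 3 -> 10 GELU-GELU-Linear network
# (illustrative only; class and attribute names are not the demo's code).
class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(49, 3),   # a1 = W1 x0 + b1
            nn.GELU(),          # z1 = GELU(a1)
            nn.Linear(3, 3),    # a2 = W2 z1 + b2
            nn.GELU(),          # z2 = GELU(a2)
            nn.Linear(3, 10),   # a3 = W3 z2 + b3 (logits)
        )

    def forward(self, x):
        return self.net(x)      # logits; the polytope reasons on these

model = TinyMLP()
x = torch.rand(1, 49)                         # 7x7 image, flattened, scaled to [0, 1]
probs = torch.softmax(model(x), dim=-1)       # Softmax only as post-processing
```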

Mathematical Formulation

Affine Transformations

aℓ = Wℓ · zℓ-1 + bℓ

Where:
• Wℓ is the weight matrix for layer ℓ
• bℓ is the bias vector for layer ℓ
• zℓ-1 is the output of the previous layer

GELU Activation Function

GELU(x) = x · Φ(x)

Where Φ(x) is the CDF of the standard normal distribution.

Approximation used:
GELU(x) ≈ 0.5x(1 + tanh(√(2/π) · (x + 0.044715x³)))
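
As a quick sanity check, the exact GELU and the tanh approximation can be compared numerically (a small sketch using SciPy's normal CDF; not part of the demo code):

```python
import numpy as np
from scipy.stats import norm

def gelu_exact(x):
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF
    return x * norm.cdf(x)

def gelu_tanh(x):
    # 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3)))
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

xs = np.linspace(-5, 5, 1001)
print(np.max(np.abs(gelu_exact(xs) - gelu_tanh(xs))))  # gap stays well below 1e-2
```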

Polytope Encoding of GELU

For each neuron with pre-activation a and post-activation z:
• Lower envelope: z ≥ αL · a + βL
• Upper envelope: z ≤ αU · a + βU

Where αL, βL, αU, βU are computed using:
• Interval bounds [L, U] for a, obtained via interval bound propagation (IBP)
• Tight linear envelopes that bound GELU over [L, U]

This replaces the nonlinear GELU with linear constraints!
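
The envelope coefficients can be obtained in several ways. The sketch below takes the chord slope over [L, U] and shifts the intercepts using a dense grid, which yields sound bounds up to grid resolution; the demo may compute tighter analytic envelopes, so treat this as illustrative only:

```python
import numpy as np
from scipy.stats import norm

def gelu(x):
    return x * norm.cdf(x)

def gelu_envelope(L, U, grid=2048):
    """Return (alpha_L, beta_L, alpha_U, beta_U) with
    alpha_L*a + beta_L <= GELU(a) <= alpha_U*a + beta_U for a in [L, U].
    Sketch only: chord slope, intercepts shifted via a dense grid (assumes U > L)."""
    xs = np.linspace(L, U, grid)
    slope = (gelu(U) - gelu(L)) / (U - L)     # common slope for both envelopes
    resid = gelu(xs) - slope * xs
    # Shifting the chord down/up by the extreme residuals gives lower/upper envelopes.
    return slope, resid.min(), slope, resid.max()

alpha_L, beta_L, alpha_U, beta_U = gelu_envelope(-1.5, 2.0)
```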

Softmax (Post-processing)

Softmax converts logits to probabilities:

pi = exp(a3[i]) / Σj exp(a3[j])

Where:
• a3[i] is the logit for class i
• pi is the probability for class i
• Σi pi = 1

Note: The polytope operates on logits (a3), not probabilities. Softmax is only used for final classification and visualization.
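
Softmax preserves the ordering of the logits (exp is increasing and all entries share the same denominator), so argmax(Softmax(a3)) = argmax(a3); this is why the polytope can stay on the logit side. A two-line check:

```python
import numpy as np

logits = np.array([1.2, -0.3, 4.0, 0.5])
probs = np.exp(logits - logits.max())          # subtract max for numerical stability
probs /= probs.sum()
assert np.argmax(probs) == np.argmax(logits)   # Softmax preserves the argmax
```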

Network Forward Pass

x0 = input (49-dimensional, 7×7 flattened)
a1 = W1 · x0 + b1 (shape: 3)
z1 = GELU(a1)
a2 = W2 · z1 + b2 (shape: 3)
z2 = GELU(a2)
a3 = W3 · z2 + b3 (shape: 10, output logits)
Prediction = argmax(a3) [or argmax(Softmax(a3))]
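
The same forward pass spelled out in NumPy; the weights below are random placeholders with the demo's shapes, standing in for the trained parameters:

```python
import numpy as np
from scipy.stats import norm

def gelu(x):
    return x * norm.cdf(x)

rng = np.random.default_rng(0)
# Placeholder parameters with the demo's shapes (the real ones are learned).
W1, b1 = rng.normal(size=(3, 49)), rng.normal(size=3)
W2, b2 = rng.normal(size=(3, 3)),  rng.normal(size=3)
W3, b3 = rng.normal(size=(10, 3)), rng.normal(size=10)

x0 = rng.uniform(0, 1, size=49)   # 7x7 input, flattened
a1 = W1 @ x0 + b1                 # shape (3,)
z1 = gelu(a1)
a2 = W2 @ z1 + b2                 # shape (3,)
z2 = gelu(a2)
a3 = W3 @ z2 + b3                 # shape (10,), output logits
pred = int(np.argmax(a3))         # same as argmax of Softmax(a3)
```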

What is the Polytope?

A polytope is a geometric region defined by linear inequalities. For neural network verification, we construct a polytope that over-approximates all possible network behaviors for inputs in a given region.

Variables in our polytope:

  1. Input: x₀ (49-dimensional flattened image)
  2. Pre-activations: a₁ (3), a₂ (3), a₃ (10, the output logits)
  3. Post-activations: z₁ (3), z₂ (3)

Constraints in our polytope:

  1. Input box: x₀[i] ∈ [x̂₀[i] - ε, x̂₀[i] + ε] ∩ [0, 1] for all i, where x̂₀ is the unperturbed input
  2. Affine relations: aℓ = Wℓ · zℓ-1 + bℓ (equality constraints, one per layer)
  3. GELU envelopes: Linear lower/upper bounds on z = GELU(a)

Why it's useful: any output the network can produce for an input in the ε-ball is guaranteed to lie in our polytope. This approach, inspired by Singh et al.'s DeepPoly, lets us analyze network behavior through linear programming. Beyond verification, the representation also supports interpretability: we can probe how individual neurons contribute to predictions by pairing the polytope with suitable objective functions and solving the resulting LPs.

Current demo uses ε = 0.01, giving 49 input-box constraints plus 6 GELU envelope constraints per hidden layer (2 per neuron × 3 neurons), i.e. 12 envelope constraints in total.
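
A sketch of how the whole pipeline (IBP bounds, GELU envelopes, and the LP over the polytope) could be assembled, here using cvxpy and a vectorized version of the grid-based envelopes above. The helper names, the grid-based envelope construction, and the use of cvxpy are our assumptions, not necessarily how the demo implements it:

```python
import numpy as np
import cvxpy as cp
from scipy.stats import norm

def gelu(x):
    return x * norm.cdf(x)

def affine_bounds(W, b, lo, hi):
    # Interval bound propagation (IBP) through a = W x + b.
    Wp, Wm = np.maximum(W, 0), np.minimum(W, 0)
    return Wp @ lo + Wm @ hi + b, Wp @ hi + Wm @ lo + b

def gelu_interval(lo, hi, grid=512):
    # Interval for z = GELU(a), a in [lo, hi]; sound up to grid resolution.
    ys = gelu(np.linspace(lo, hi, grid))       # shape (grid, n)
    return ys.min(axis=0), ys.max(axis=0)

def gelu_envelopes(lo, hi, grid=512):
    # Vectorized chord-based envelopes: aL*a + bL <= GELU(a) <= aU*a + bU on [lo, hi].
    xs = np.linspace(lo, hi, grid)
    slope = (gelu(hi) - gelu(lo)) / np.maximum(hi - lo, 1e-9)
    resid = gelu(xs) - slope * xs
    return slope, resid.min(axis=0), slope, resid.max(axis=0)

def certify(x_nom, true_cls, eps, params):
    """True if the LP proves class `true_cls` wins for every input in the eps-ball."""
    W1, b1, W2, b2, W3, b3 = params
    lo0, hi0 = np.clip(x_nom - eps, 0, 1), np.clip(x_nom + eps, 0, 1)

    # IBP bounds for the pre-activations, then per-neuron GELU envelopes.
    l1, u1 = affine_bounds(W1, b1, lo0, hi0)
    aL1, bL1, aU1, bU1 = gelu_envelopes(l1, u1)
    zl1, zu1 = gelu_interval(l1, u1)
    l2, u2 = affine_bounds(W2, b2, zl1, zu1)
    aL2, bL2, aU2, bU2 = gelu_envelopes(l2, u2)

    x0, a1, z1 = cp.Variable(49), cp.Variable(3), cp.Variable(3)
    a2, z2, a3 = cp.Variable(3), cp.Variable(3), cp.Variable(10)
    cons = [
        x0 >= lo0, x0 <= hi0,                                        # input box
        a1 == W1 @ x0 + b1, a2 == W2 @ z1 + b2, a3 == W3 @ z2 + b3,  # affine relations
        z1 >= cp.multiply(aL1, a1) + bL1, z1 <= cp.multiply(aU1, a1) + bU1,
        z2 >= cp.multiply(aL2, a2) + bL2, z2 <= cp.multiply(aU2, a2) + bU2,
    ]
    # Robust iff the true logit beats every other logit over the whole polytope.
    worst = [cp.Problem(cp.Minimize(a3[true_cls] - a3[j]), cons).solve()
             for j in range(10) if j != true_cls]
    return min(worst) > 0
```

Swapping the objective (e.g. maximizing a single hidden activation or a logit difference) over the same constraint set is what turns the polytope into an interpretability probe rather than just a verifier.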

Interactive Network Visualization

[Interactive visualization: node color encodes activation level (low → high); edge color indicates weight sign (positive vs. negative). Hover over elements for details; click a hidden neuron to see the pattern it detects for the currently selected digit.]

Mechanistic Interpretability

By stepping through the NN, we can understand how the network composes features to make predictions. Each hidden neuron learns interpretable patterns that combine to form digit classifiers. We formally verify these properties with the polytope.

Click on the hidden neurons in the visualization above to see what patterns they detect. The 6 hidden neurons (3 in each layer) learn distinct visual features:

Example: How the network recognizes Digit 0

Digit 0 ∝ (++ Frame) + (− Spine) + (− Belt)
• Frame (strong positive weight): must act like a container
• Spine (negative weight): must have an empty center
• Belt (negative weight): must have an empty middle

[Figure: mechanistic trace for digit 0]

The dashboard shows how Layer 1 neurons detect basic patterns (Frame, Spine, Belt), and Layer 2 neurons combine them with learned weights to produce the final digit 0 logit. The network learns that digit 0 should strongly activate the "Frame" detector while avoiding activation of "Spine" and "Belt" detectors (which would indicate filled regions).
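
A hedged sketch of the kind of probe behind such a trace: decomposing the digit-0 logit into per-neuron contributions. The weights and activations below are random placeholders with the demo's shapes; in the demo they come from the trained network and a specific digit-0 input:

```python
import numpy as np

rng = np.random.default_rng(0)
# Placeholder weights/activations with the demo's shapes (illustrative only).
W2, W3 = rng.normal(size=(3, 3)), rng.normal(size=(10, 3))
b3 = rng.normal(size=10)
z1, z2 = rng.uniform(0, 1, size=3), rng.uniform(0, 1, size=3)

# Contribution of each layer-2 neuron to the digit-0 logit:
#   a3[0] = sum_k W3[0, k] * z2[k] + b3[0]
contrib_l2 = W3[0] * z2
print("digit-0 logit terms:", contrib_l2, "+ bias", b3[0])

# One step further back: how the layer-1 detectors (Frame, Spine, Belt in the
# demo's labeling) feed each layer-2 pre-activation; row k = terms of a2[k].
contrib_l1 = W2 * z1
print("layer-1 -> layer-2 terms:\n", contrib_l1)
```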

Robustness Analysis

The plot below shows robustness rates (the percentage of test samples for which the LP yields the correct prediction for every input in the ε-ball) across different perturbation sizes, computed over 600 test samples:

[Figure: Robustness vs. Perturbation Size]

Key findings: The LP maintains high accuracy for small perturbations (ε ≤ 0.02), with sensitivity varying across digit classes: digit 1 remains highly robust even at ε = 0.02, while digit 4 degrades more quickly. More interestingly, the LP also appears to be a good MNIST classifier in its own right.
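
For completeness, a sketch of how such robustness rates could be tabulated, assuming an LP-based `certify(x, label, eps, params)` routine like the one sketched earlier (the function names and loop structure are ours, not the demo's code):

```python
import numpy as np
from collections import defaultdict

def robustness_rates(X, y, eps_values, params, certify):
    """Fraction of samples certified robust, per class and per epsilon.

    X: (N, 49) test images in [0, 1]; y: (N,) integer labels.
    `certify` is an LP-based verifier like the sketch above.
    """
    rates = {}
    for eps in eps_values:
        hits, totals = defaultdict(int), defaultdict(int)
        for x, label in zip(X, y):
            totals[label] += 1
            if certify(x, label, eps, params):
                hits[label] += 1
        rates[eps] = {c: hits[c] / totals[c] for c in totals}
    return rates

# e.g. rates = robustness_rates(X_test[:600], y_test[:600],
#                               [0.005, 0.01, 0.02], params, certify)
```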