
Model Based Public Policy: Bayesian Neural Nets and Trade

 

1 Motivation

Recent changes in U.S. trade policy have sparked debates about their effects on American prosperity, national security, and federal revenue. This work began as an attempt to formalize the relationship between trade policy and prosperity and continues as a journey in learning data science and applying modeling to public policy. For my first installment, I recount my experience building a Bayesian neural net to address the problem of selecting a trade policy that optimizes American prosperity.

2 Defining The Problem

Let’s begin by defining the problem we wish to solve. Problem definitions require identifying three things: 

  1. Given information. This is the data we start with before any computation. For our modeling problem, we will need, at a minimum, data relevant to the prosperity of a nation along with trade policy data. 

  2. Unknown. The unknown is the set of values you want to identify. For a modeling problem, the unknowns are:

    1. The templated function we wish to learn that defines the relationship between the variables in our data sets

    2. the values of the weights that render the templated function a specialized, concrete function characterized by our data.

  3. Constraint. The constraint is a binary function over the unknown and given values. It decides whether a given outcome is acceptable. In the context of modeling, a reasonable constraint is an acceptance criterion on the error between the outcomes the model produces and the real-world outcomes it was trained on. The function that measures this error is called the loss function.


The concrete problem definition is stated below.


Problem Definition

Given

  • Prosperity metrics data set P ∈ R^(N×T). P_t ∈ R^N is the prosperity data at time t

  • Foreign Policy data set F ∈ R^(M×T). F_t ∈ R^M is the foreign policy data at time t


Unknown

  • Template model function M_W : R^M → R^N. M_W is a templated function with parameter W. In order to execute M_W we must first choose a W* ∈ R^K to bind to the parameter W 

  • The value W* ∈ R^K


Constraint

(1/T) Σ_{t=1…T} loss(M_{W*}(F_t), P_t) < ε for some tolerance ε.


Let us consider the practicality of our problem. We are trying to learn a relationship between trade policy and prosperity but we have excluded intermediate economic mechanisms that may have more direct causal influence over prosperity outcomes.


Since tariff rates do not directly impact prosperity, we introduce an additional layer of economic indicators as mediating variables. Foreign policy changes, such as tariff adjustments, influence these indicators, which in turn affect overall prosperity. Explicitly modeling these intermediate variables helps prevent misattributing changes in prosperity to the wrong cause. Let’s redefine our problem with this intermediate layer in mind.


Problem Definition

Given

  • Prosperity metrics data set P ∈ R^(N×T). P_t ∈ R^N is the prosperity data at time t

  • Economic indicators data set I ∈ R^(Y×T). I_t ∈ R^Y is the economic indicators data at time t

  • Foreign Policy data set F ∈ R^(M×T). F_t ∈ R^M is the foreign policy data at time t


Unknown

  • Template model function M1_{W1} : R^M → R^Y. M1 is a templated function with parameter W1. In order to execute M1 we must first choose a W* ∈ R^{K1} to bind to the parameter W1 

  • The value W* ∈ R^{K1}

  • Template model function M2_{W2} : R^Y → R^N. M2 is a templated function with parameter W2. In order to execute M2 we must first choose a W** ∈ R^{K2} to bind to the parameter W2 

  • The value W** ∈ R^{K2}


Constraint

(1/T) Σ_{t=1…T} loss(M1_{W*}(F_t), I_t) < ε AND (1/T) Σ_{t=1…T} loss(M2_{W**}(I_t), P_t) < ε for some tolerance ε.
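
To make the constraint concrete, here is a minimal sketch of how it could be checked in code. The names M1, M2, and loss are illustrative stand-ins for the fitted functions and a per-step loss, and the data arrays are assumed to be stored with time as the leading axis; none of these names come from the implementation later in the post.

import jax.numpy as jnp

def constraints_satisfied(M1, M2, F, I, P, loss, epsilon):
    # F, I, P are assumed to have shapes (T, M), (T, Y), (T, N)
    T = F.shape[0]
    loss1 = jnp.mean(jnp.array([loss(M1(F[t]), I[t]) for t in range(T)]))  # policy -> indicators
    loss2 = jnp.mean(jnp.array([loss(M2(I[t]), P[t]) for t in range(T)]))  # indicators -> prosperity
    return (loss1 < epsilon) & (loss2 < epsilon)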


Now that we have an idea of what we need to create (the templated functions), what we need to compute (the weights) and how to measure the accuracy of the model, let us start to collect our data.


3 Defining the Model: Foreign Policy Effects on Economic Indicators

Let us progress to our next step: creating the templated function. In our final problem definition we had two templated functions to identify. The first related trade policy to economic indicators and the second related economic indicators to prosperity. We shall start with the first: relating trade policy to economic indicators. 


Let us simplify the modeling task by selecting a single economic indicator. We will model the relationship between trade policy and the price of goods before expanding to other economic indicators such as wages and employment, tax revenues, and investment and productivity. Our goal then is to express prices as a function of tariff policy, so that we can learn how markets adapt to tariff changes over time.

Intermediate Variables

To build a model capable of producing valid causal insights, we must identify and explicitly represent the key variables that mediate the relationship between observed quantities. If we fail to model these variables, their influence is absorbed implicitly into the model's learned parameters, which can lead to inaccurate predictions.
For example, imagine a training data set where high tariffs are paired with low prices and low import volume. Without modeling import volumes explicitly, the model might incorrectly learn that high tariffs cause prices to fall, when the true causal story might be that high tariffs increase prices if volume is high and have little effect if volume is low.
While it’s impossible to include every influencing factor, we can use economic reasoning to isolate the most important ones. In the model relating foreign policy to economic indicators, we will start by explicitly including volume. Because volume is dependent on price, we will model volumes as a function of the previous year’s prices.
Selecting volume as an intermediate variable gives us a starting point for building a model, but at this stage we cannot confirm with certainty that we’ve captured all relevant intermediate variables. Once we have a functioning model with our variables of choice, we can test for key variable omission using the following techniques:


  1. Economic Consistency Checks. After training the model we consider a real world economic situation outside the data set we trained the model on. For example, suppose we know that in 1970 tariffs on coffee skyrocketed and the cost of coffee shot up because it could not be domestically produced, yet strong demand for the bean kept import volumes high. We could suppose that this happened again this year, and ask our model to predict the price of coffee next year given the tax increase. If the model predicts that the price will fall, we have reason to believe that we are missing some governing variables in our model, specifically the ability of domestic producers to offer a less expensive option.

  2. Sensitivity analysis. We consider other variables to add to the model such as domestic production capacity and consumer demand. If the model’s predictions appear to be sensitive to the variable’s inclusion, then clearly the model was missing an important causal pathway.


Function Shape

We have identified the variables in our templated function but we have not identified the operators that join them into a prediction. Importantly, we will assume a nonlinear relationship between tariffs and volumes, and between volumes and prices, to reflect saturation effects (once volume reaches zero, further taxation will have no effect) and behavioral thresholds (sudden changes in volume when a large number of consumers are no longer willing to pay the higher price).

I have chosen to use a Bayesian Neural Network rather than a traditional neural network because I would like to quantify prediction uncertainty. For example, under certain combinations of import volume and trade policy, prices may become volatile, making them more difficult to predict. This is not a limitation of modeling but a reality of the physical world. A traditional model may opaquely overfit to this noisy data. A Bayesian neural net, however, will include uncertainty in its prediction, giving us insight into how much confidence to place in its predictions. 

The final model is specified below.


Figure 0

Price[t] = f_price(

    Price[t-1],         # previous price

    TariffIn[t],        # current import tariff

    TariffIn[t-1],      # previous import tariffs

    VolumeIn[t-1],      # previous import volume

    t,                  # time index 

) + ε[t]                # error term to account for economy wide disruptions


VolumeIn[t] = f_volIn(

    TariffIn[t],        # current import tariffs

    TariffIn[t-1],      # previous import tariffs

    Price[t-1],         # previous prices

    VolumeIn[t-1],      # previous year's import volume

    t                   # time index

) + ε[t]                # error term to account for economy wide disruptions



ε[t] represents a time-dependent error term that accounts for factors not explicitly modeled but that can introduce small inaccuracies in the model, for example supply chain disruptions, regulatory changes, etc. 


The latent variables we will learn are the weights of the functions f_price and f_volIn. Each weight is assumed to be sampled from a Normal distribution. In other words, f_price and f_volIn are Bayesian Neural Networks, or BNNs.

4 Model Implementation

The next step is to implement the model. First let us offer more details about Bayesian Neural Nets.

Definition

 A Bayesian Neural Network (BNN) extends a traditional neural network by treating its weights and biases as probability distributions conditioned on data, rather than fixed point estimates. 


In a traditional neural network, inference is performed by computing a linear mapping with fixed value weights followed by an activation function to handle nonlinearities. Concretely, each layer computes the following:


f(x) = ɸ(x W + b)

where 

  • x ∈ R^d is the input vector

  • W ∈ R^(d×h) is the weight matrix

  • b ∈ R^h is the bias vector

  • ɸ(·) is a nonlinear activation function (e.g. ReLU, tanh)
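
As a minimal sketch (in JAX, which the implementation later in this post is built on), a single deterministic layer of this kind might look like the following; the function name is illustrative only.

import jax.numpy as jnp

def dense_layer(x, W, b):
    # One traditional layer: a fixed-value linear map followed by a tanh activation
    return jnp.tanh(x @ W + b)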


In contrast, when we query the trained BNN we sample multiple weights from the posterior distribution over parameters conditioned on the data during training. The forward pass then becomes the following:

W^(s) ~ p(W | D),   b^(s) ~ p(b | D)                                 [1]


f^(s)(x) = ɸ(x W^(s) + b^(s))   where W^(s), b^(s) ~ p(· | D)    [2]


where


  • D is the training data

  • W^(s), b^(s) is a particular sample

  • W, b represents the full set of weight parameters in the network



Since each forward pass draws new samples W^(s), b^(s), the output f(x) becomes a random variable: each pass yields a different value f^(s)(x), drawn from the distribution of plausible outcomes. Repeating this process yields a predictive distribution over outputs. That is, a single query to a BNN requires multiple forward passes to build up the outcome distribution.
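
A minimal sketch of this repeated-sampling procedure is below; sample_weights is a hypothetical helper standing in for draws from the learned posterior.

import jax
import jax.numpy as jnp

def predictive_distribution(rng_key, x, sample_weights, num_samples=100):
    # Each forward pass uses a fresh draw of (W, b); the collection of outputs
    # forms the predictive distribution for input x.
    outputs = []
    for key in jax.random.split(rng_key, num_samples):
        W, b = sample_weights(key)           # W^(s), b^(s) ~ p(· | D)
        outputs.append(jnp.tanh(x @ W + b))  # f^(s)(x)
    outputs = jnp.stack(outputs)
    return outputs.mean(axis=0), outputs.std(axis=0)  # predicted value and its uncertainty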


By computing a distribution rather than a single point estimate for a prediction, a Bayesian model provides both the predicted value (mean of the distribution) and its uncertainty (the variance of the distribution). For example, if the relationship between prices and tariffs is volatile, a single predicted price may be misleading. In contrast, a full predictive distribution—potentially with wide or fat tails—captures the range of likely outcomes and their probabilities. This allows us to assess not just the expected price, but also how uncertain that estimate is, enabling more risk-aware decisions.


Architecture

A Bayesian neural net has a slightly different architecture than a traditional neural net during training. Instead of a single model function to optimize, the Bayesian neural net uses two. The first function, the model, specifies the generative process: how outputs f^(s) are generated from inputs x and latent parameters W, b. The model function corresponds to equation [2] in the previous section. The second function, called the guide, represents the posterior of the weights under the observed data. The posterior distribution is the distribution of a random variable given observations. It is defined as

p(θ | data) = p(data | θ) p(θ) / p(data)                  [3]


For example, in the context of our model we want the distribution of the network weights conditioned on the observed tariff policies, import volumes, and prices. 


Computing [3] directly is intractable so we approximate it using variational inference. Variational inference is a technique that transforms the problem of computing a complicated posterior distribution p(θ | data) into an optimization problem.


The optimization problem we solve instead of computing the posterior directly is the following:


Given: a simpler, parameterized family of distributions q(θ; ɸ)

Unknown: the parameter values ɸ of q

Constraint: the distance (typically the KL divergence) between p(θ | X) and q(θ; ɸ) is minimized


The guide function defines q(θ; ɸ ), the variational approximation to the true posterior, and corresponds to equation [1] in the previous section. To create the BNN then, we define our model and our guide functions and solve the above optimization problem to choose the parameters of our approximated posterior distribution (equation [1]). To illustrate how the guide and the model are used together to create a predictive function I have included the following pseudo code based on a model built atop the NumPyro framework.


# Define model architecture (layers, weights, priors, activation functions)

def model(X, Y=None):

    W ~ sample("W1", prior)                # sample latent weights from prior

    b ~ sample("b1", prior)

    # more layers here if needed


    Y_hat = tanh(XW + b)       # NN produces parameters of our distribution

    Y ~ likelihood(Y_hat)      # sample from resulting distribution



# Define the approximation of the posterior 

# This function defines q(W, b; θ) = Normal(μ, σ) where μ, σ are learned

def guide(X, Y=None):

    sample("W1", Normal(μ, σ))

    sample("b1", Normal(μ, σ))


for each step in training:

    # Sample parameters from the guide 

    # Internally, the inference engine traces the guide

    # to obtain a sample W_sample ~ q(W; θ)

    param_sample = guide(X, Y; θ)


    # Replay these samples into the model and compute log-probabilities

    # p(Y, W | X) → probability of observing Y and W given X

    log_prob = log_prob(model, X, Y, guide_trace=param_sample)



    # Compute ELBO (more information on this quantity is in a later section)

    elbo = log_prob - log_q(param_sample; θ)


    # Take gradient of ELBO with respect to guide parameters θ

    θ = θ + optimizer.step(gradient(elbo, θ))


What does it mean to “replay these samples into the model and compute the log probabilities?” This is getting into NumPyro internals, but it does help to understand how this model learns so I will include it. To make the explanation clear, I will include the actual code for the model and the guide.


Model:

def bnn_model(X, Y=None):

   D_in = X.shape[1]

   H = 16

   D_out = 1


   # Sample weights and biases from prior

   W1 = numpyro.sample("W1", dist.Normal(0.5, 0.1).expand([D_in, H]).to_event(2))

   b1 = numpyro.sample("b1", dist.Normal(0.5, 0.1).expand([H]).to_event(1))

   W2 = numpyro.sample("W2", dist.Normal(0.5, 0.1).expand([H, D_out]).to_event(2))

   b2 = numpyro.sample("b2", dist.Normal(0.5, 0.1).expand([D_out]).to_event(1))


   # Forward pass

   hidden = jnp.tanh(jnp.dot(X, W1) + b1)

   output = jnp.dot(hidden, W2) + b2

   output = output.squeeze()


   sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))

   with numpyro.plate("data", X.shape[0]):

       numpyro.sample("obs", dist.Normal(output, sigma), obs=Y)


Guide:

def bnn_guide(X, Y=None):

   D_in = X.shape[1]

   H = 16

   D_out = 1


   def normal_param(name, shape):

       loc = numpyro.param(f"{name}_loc", jnp.zeros(shape) + 0.5)

       scale = numpyro.param(f"{name}_scale", jnp.ones(shape) * 1e-2, constraint=dist.constraints.positive)

       return dist.Normal(loc, scale).to_event(len(shape))


   numpyro.sample("W1", normal_param("W1", (D_in, H)))

   numpyro.sample("b1", normal_param("b1", (H,)))

   numpyro.sample("W2", normal_param("W2", (H, D_out)))

   numpyro.sample("b2", normal_param("b2", (D_out,)))

   numpyro.sample("sigma", normal_param("sigma", ()))



Step 1: Trace the Guide

NumPyro traces the guide to collect samples for the weight values. This trace is represented as a dictionary that records what was sampled (for example "W1"), from what distribution (for example "Normal(loc, scale)"), and the result of the sampling ("value"). The trace has the following structure:


trace = {

  "W1": {"value": jnp.array(...), "fn": Normal(loc, scale), "type": "sample", ...},

  "b1": {"value": jnp.array(...), "fn": Normal(...), ...},

  ...

}
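
As a sketch of how such a trace can be produced with NumPyro's effect handlers (assuming X and Y are the training arrays):

import jax
from numpyro import handlers

rng_key = jax.random.PRNGKey(0)
seeded_guide = handlers.seed(bnn_guide, rng_key)            # supply randomness to the sample sites
guide_trace = handlers.trace(seeded_guide).get_trace(X, Y)  # record every sample site
print(guide_trace["W1"]["value"].shape)                     # the sampled weight matrix, e.g. (D_in, H)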


Step 2: Replay the model with the trace

Execute the model function, but replace computations in the model with results from the trace where applicable. For example, if in the model we have a sample from the prior Normal(0.5,0.1), written something like this:


W1 = numpyro.sample("W1", dist.Normal(0.5, 0.1))


Then during replay NumPyro sees that "W1" is in the trace, so instead of sampling from Normal(0.5, 0.1) it uses the value from trace["W1"]["value"].
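
A sketch of the replay step, reusing the guide trace from above:

import jax
from numpyro import handlers

# Substitute the guide's sampled values at the matching sample sites of the model,
# then trace the model to record each site's distribution and value.
replayed_model = handlers.replay(bnn_model, trace=guide_trace)
model_trace = handlers.trace(
    handlers.seed(replayed_model, jax.random.PRNGKey(1))
).get_trace(X, Y)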


The result of the model trace has something like the following form:

{

  "W1": {

    "type": "sample",

    "fn": dist.Normal(0, 1),

    "value": tensor([[ 0.5, -0.2], [0.1, 0.7]]),  # from guide

    "log_prob": tensor(-3.5), # log probability of the weight w1 under its prior distribution

    ...

  },

  "b1": {

    "type": "sample",

    "fn": dist.Normal(0, 1),

    "value": tensor([0.0, 0.1]),

    "log_prob": tensor(-1.2),

    ...

  },

  "Y": {

    "type": "sample",

    "fn": dist.Normal(loc=..., scale=...),

    "value": Y,  # observed!

    "log_prob": tensor(-10.8)

    ...

  }

}

Of particular importance is the log_prob value at the "obs" site. This value represents the log-likelihood of observing Y given the input X and the weights (which are fixed via replay).

This is the quantity we aim to maximize during training. A higher value indicates that, under the current model (with the sampled weights), the observed data is more likely — meaning the model better explains the observed outcomes.
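
As a sketch, the quantities that enter the training objective can be read off the two traces by evaluating each site's distribution at its recorded value (the sigma site is omitted for brevity):

import jax.numpy as jnp

def site_log_prob(site):
    # Log-probability of the recorded value under the site's distribution
    return jnp.sum(site["fn"].log_prob(site["value"]))

log_lik   = site_log_prob(model_trace["obs"])     # log p(Y | X, W)
log_prior = sum(site_log_prob(model_trace[name]) for name in ["W1", "b1", "W2", "b2"])
log_q     = sum(site_log_prob(guide_trace[name]) for name in ["W1", "b1", "W2", "b2"])
elbo_single_sample = log_lik + log_prior - log_q  # one-sample estimate of the ELBO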

Step 3: Compute Evidence Lower Bound (ELBO)

The evidence lower bound is the objective function we maximize during training. It is a lower bound on the log likelihood of the observed data Y conditioned on known inputs X, written log p(Y|X). We want to maximize log p(Y|X) but it is intractable to do so directly. Below we will work our way, step by step, to the expression that is feasible to maximize instead: the lower bound (ELBO).

Start with Bayes Theorem: 


p(W | Y) = p(Y, W)/p(Y)


p(Y) is the normalizing constant and is defined 


p(Y) = ∫ p(Y|W) d P(W)   // average p(Y|W) weighted by density of W

     = ∫ p(Y|W) p(W) dW  // in terms of parameter space


First we write this as the log evidence:


log p(Y) = log  ∫ p(Y|W) p(W) dW

log p(Y) = log ∫ p(Y, W) dW


Now introduce the approximation q(W; θ) and apply Jensen’s inequality:

log p(Y) = log ∫ q(W; θ) * (p(Y, W) / q(W; θ)) dW

         ≥ ∫ q(W; θ) log [p(Y, W) / q(W; θ)] dW [6]


and since


log [p(Y, W) / q(W; θ)] = log p(Y, W) - log q(W; θ)

                        = log p(Y | W) + log p(W) - log q(W; θ)


We can rewrite [6]:


log p(Y) ≥ ∫ q(W; θ) [log p(Y | W) + log p(W) - log q(W; θ)] dW [7]


The RHS of [7] is the evidence lower bound (ELBO). Note that the expectation E with respect to a parameterized distribution q(W; θ) of a function f(W) is defined as 


E_q(W;θ)[f(W)] = ∫ f(W) q(W; θ) dW


so the ELBO can be rewritten


log p(Y) ≥ E_q(W;θ)[log p(Y | W) + log p(W) - log q(W; θ)] [8]


where

  • θ is a vector of parameters of the variational distribution that approximates the true posterior over the model’s weights 

  • W are the weights of the neural network

  • q(W; θ ) is the variational distribution over weights

  • E_q(W; θ)[f(W)] is the expected value of f(W) under q(W; θ) 

  • Y is the observed data


The loss function is the negative ELBO. This is the function we aim to minimize by taking the derivative of the loss function with respect to the parameter vector θ and evaluating the derivative function at the current value of θ. 


Let’s consider how the loss is computed. Note that the integral in [8] is approximated in practice using numerical methods, Monte Carlo sampling in our case. That is, we compute the following:


loss(θ) = −(1/K) Σ_{i=1…K} [log p(Y | W[i]) + log p(W[i]) − log q(W[i]; θ)]

        = −(1/K) Σ_{i=1…K} [Σ_{j=1…N} log p(y[j] | x[j], W[i]) + log p(W[i]) − log q(W[i]; θ)]     [9]

To evaluate the ELBO, we use the following Monte Carlo approximation (a code sketch follows the list):

  1. Sample K weight vectors W[i] from the variational distribution q(W; θ).

  2. For each sample W[i]:

    • Bind W[i] to the model weights.

    • Compute the model output Ŷ[i] = f(X; W[i]).

    • Evaluate the log-likelihood log p(Y | W[i]) by assuming a fixed variance Gaussian likelihood: Y ~ N(Ŷ[i], σ²), with σ held constant.

    • Compute the log prior probability log p(W[i]).

    • Compute the log variational probability log q(W[i]; θ).

  3. Aggregate the terms:

    • Compute the average over K samples of log p(Y | W[i]) + log p(W[i]) − log q(W[i]; θ).

    • Negate the result to form the loss (since most optimizers minimize rather than maximize).
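
Below is a minimal sketch of this estimator. The helpers sample_q, log_likelihood, log_prior, and log_q are hypothetical stand-ins for the guide, the model likelihood, the prior, and the variational density; in practice NumPyro's ELBO implementations (such as the TraceMeanField_ELBO used later) perform these steps for us.

import jax
import jax.numpy as jnp

def elbo_loss(rng_key, K, X, Y, sample_q, log_likelihood, log_prior, log_q):
    # Monte Carlo estimate of the negative ELBO in equation [9]
    terms = []
    for key in jax.random.split(rng_key, K):
        W = sample_q(key)                 # step 1: W[i] ~ q(W; θ)
        ll = log_likelihood(X, Y, W)      # step 2: Σ_j log p(y[j] | x[j], W[i])
        terms.append(ll + log_prior(W) - log_q(W))
    return -jnp.mean(jnp.stack(terms))    # step 3: average over samples and negate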

Understanding this approximation is a prerequisite for interpreting the ELBO loss values and assessing whether a given value is reasonable in light of data noise and model uncertainty.

For example, I trained my model with K = 10 (sometimes referred to as the ‘particle count’) and 100 observations (N = 100). I found that it was difficult to reduce the loss beyond around 300. But what does this number 300 really mean?
For K = 10 particles and 100 observations, the mean loss contribution per observation is -300 / (10 × 100) = -0.3. This corresponds to a log-likelihood of approximately -0.3 per observation, which is consistent with a reasonably high likelihood under a Normal distribution (log p(y[i] | W) = -0.3 ⇒ p(y[i] | W) = e^(-0.3) ≈ 0.741). It is unreasonable to expect any fixed set of weights to perfectly explain all observations, given that the data-generating process is noisy. The goal is not to fit the noise, but to capture the underlying distribution of responses. Thus a log-likelihood of -0.3 per observation may be considered acceptable.
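
The training loop shown later in this post stops early via an is_loss_acceptable check. The original helper is not shown, but based on the reasoning above a plausible sketch is:

import jax.numpy as jnp

def is_loss_acceptable(loss_val, K, N, acc):
    # Convert the negative-ELBO loss into an average per-observation log-likelihood
    # and compare the implied likelihood against a target probability (e.g. acc = 0.75).
    per_obs_ll = -loss_val / (K * N)
    return jnp.exp(per_obs_ll) >= acc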

Training Risks 

We have looked at the model and guide definitions and discussed how to assess the loss. Next we will take a deeper look at the training loop. Recall from the pseudo code that the parameters of the variational distribution are updated using gradient descent. Gradient descent is a technique used in numerical analysis to solve optimization problems. It iteratively updates a candidate solution until that candidate is sufficiently close to the true minimum of the target function. 
Gradient descent produces an approximation of the minimum by iteratively updating a candidate x by the negative of its gradient, -▽f(x), scaled by a step size η. A smooth and stable surface with a single global minimum is thus ideal for this type of algorithm. 
In practice the optimization surface may have high curvature, non-convexity, or many local minima. This section identifies symptoms of each of these problems and how to address them. In the next section we will show the full code example that incorporates remedies to each of these problems. 


Problem 1: Parameter Value Gets “Stuck”. 

If you observe a fluctuating parameter value paired with a high loss that does not decrease monotonically, you may have an issue with high gradient variance. There are two types of gradient variance.


Variance Across Minibatches. 

Gradient variance with respect to minibatches refers to the variance of the set of gradients hypothetically collected during a single training step given a fixed parameter value θ and different input sets (the minibatches). In practice, a single minibatch is sampled at random and the gradient from this isolated set is used as an approximation of the gradient of the loss function over the entire data set. 
But if the variance of the minibatch gradients is high, then the gradient of any one minibatch is unlikely to be a sufficient estimate of the gradient of the loss under the entire dataset. As a result, updating the parameter value with the gradient from one subset may reduce the loss contributions of that subset but increase the loss contributions of another, resulting in a loss that fails to converge.
One solution is to increase the minibatch size or to ensure that each minibatch has the same distribution as its superset. In my toy problem I did not use minibatching, so this was not a problem for me. However, as I scale my problem and start to rely on real world data I will need to incorporate this technique.


Variance Across Training Steps (Across time). 

This type of gradient variance is with respect to the set of gradients collected over the entire course of the training. That is, a single gradient of this set was calculated under a unique parameter value at a unique time step. A low variance represents a fairly stable and smooth optimization landscape while a high variance suggests a more volatile optimization landscape, potentially with sharp curvature, frequent oscillations, or numerous local minima.


The problem with high variance is that it can cause updates in the wrong direction. Let’s consider two plots to illustrate the difference visually. Below are two images: the first is an optimization landscape with high gradient variance and the second is an optimization landscape with low gradient variance. 


[Figure: loss surface with high gradient variance (jagged, with many sharp local dips)]


Note this map of the loss function. Because of the jagged dips, the gradients of the loss function have high variance, which means it's common to find yourself in a location where updating the current parameter value actually moves you farther away from your goal instead of closer. Without detecting this drawback the training loop may get stuck in one of these dips.


[Figure: loss surface with low gradient variance (smooth and gradual)]

This loss function has a much smoother and more gradual shape, which indicates gradients that are closer together and therefore a lower variance. In this case it is not possible to get stuck in a fluctuating loss pattern if you continue to update the parameters by the negative gradient value.

One mitigation for this problem is scaling the data used to train the model. When inputs vary widely in magnitude, as they do in the case of import volumes and the price of a good, the loss surface can become steep in some directions and flat in others.
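
A minimal sketch of such scaling (simple standardization; the function name is illustrative):

import jax.numpy as jnp

def standardize(x):
    # Rescale each feature (column) to zero mean and unit variance so that no
    # single input, such as import volume, dominates the loss surface geometry.
    return (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)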

High variance amongst gradients across time steps is a particularly thorny problem when training Bayesian neural nets relative to traditional neural nets. In traditional neural nets, the gradient used to update the parameters is an estimate ĝ(t) computed from the batch, and can be written ĝ(t) = g(t) + 𝛜1(t), where g(t) is the full-data gradient and 𝛜1(t) is a noise term whose variance arises from the random minibatch selection.

In Bayesian neural nets, not only do we have the gradient variance from the minibatch selection 𝛜1(t), but we also have the variance 𝛜2(t) due to the parameter W being a random variable rather than a fixed point estimate. Note that in equation [9] we sample K parameter values from the variational distribution q(W; θ). If the variance of that distribution is high, we will not compute a gradient that is representative of the full parameter space. That is, in the Bayesian neural net training loop we compute ĝ(t) = g(t) + 𝛜1(t) + 𝛜2(t). Because of this extra error term, the risk of divergence when training a BNN is higher than in a traditional neural net.

 

Problem 2: NaN Parameters. 

It’s common to observe a point in the training loop where a parameter becomes NaN and stays there. The parameter usually becomes NaN because either the gradient of the loss function grows too large or we have produced an input that is undefined under our loss function. 


Exploding Parameter Norms

Recall the gradient descent update function:


θ_{t+1} = θ_t − η ∇L(θ_t)             [10]


If the second term overflows or becomes NaN, then the RHS of [10] is NaN. This situation is called “exploding parameter norms”. 
Exploding parameter norms can be diagnosed by inspecting the Hessian of the loss function. The Hessian of a function L quantifies the rate of change of its rate of change, i.e. the second derivatives. By inspecting the sign and magnitude of the eigenvalues of the Hessian we can predict whether the model will diverge during gradient descent. In particular, if any of the eigenvalues of the Hessian have large magnitudes, it indicates a sharply increasing or decreasing slope of the loss function in that direction, perhaps by a quantity not representable in the floating-point format used in the training loop, hence the NaN. If any of the eigenvalues have a negative sign, then there is no minimum of the loss function along that direction, and the parameter value in that dimension will diverge.


For example, consider a loss function that resembles a saddle. In particular, the loss function is L(x, y) = x² − y². In the x direction it has a valley, but in the y direction it is more like a hill. The shape is crudely illustrated below.


[Figure: saddle-shaped surface of L(x, y) = x² − y²]


The Hessian of the loss function is H = [[2, 0], [0, −2]]. The magnitudes of the eigenvalues are reasonable but the sign on the y-direction eigenvalue is negative. Writing out the gradient descent formula we can see that this will result in divergence of the y element in the parameter:


(x_{t+1}, y_{t+1}) = (x_t − 2η x_t,  y_t + 2η y_t)
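
As a quick illustration, the Hessian and its eigenvalues for this saddle can be checked numerically with JAX:

import jax
import jax.numpy as jnp

loss = lambda p: p[0] ** 2 - p[1] ** 2          # L(x, y) = x^2 - y^2
H = jax.hessian(loss)(jnp.array([1.0, 1.0]))    # [[2, 0], [0, -2]]
eigvals = jnp.linalg.eigvalsh(H)                # [-2., 2.]; the negative eigenvalue
                                                # signals divergence along y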


The Hessian can be used to guide choices like the step size, but computing it is expensive and does not guarantee the prevention of exploding gradients. We will instead use gradient clipping, the bounding of gradients to a maximum value, to prevent exploding gradients.
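
The training loop shown later calls a clip_gradients helper. Its implementation is not reproduced in this post, but a global-norm clipping sketch along these lines would do the job:

import jax
import jax.numpy as jnp

def clip_gradients(grads, max_norm):
    # Rescale the whole gradient pytree if its global norm exceeds max_norm
    leaves = jax.tree_util.tree_leaves(grads)
    total_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (total_norm + 1e-8))
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, total_norm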


Undefined Values

Our loss function [9] contains several log expressions. Recall that the natural log function converges to negative infinity as x converges to zero and is undefined at x = 0. The upshot is that if p(Y|W), p(W), or q(W; θ) is zero or close enough to zero to produce a result whose magnitude exceeds the max value representable in the chosen bitwidth, we would get NaN as the output of our loss function and our training loop would die. 

Input scaling as we suggested earlier can also combat NaN gradients because it can address underflow or overflow by moving extreme values into a more reasonable range. I also found that selecting appropriate initial values was really important. Otherwise the first proposed value can have a likelihood low enough to NaN out the log expression. Lower variance in the initial normal distributions for the variational parameters helped solve this.


Problem 3: Convergence is too Slow

Note the parameter η in our gradient descent equation [10]. The value of this parameter determines the size of our steps during training. Too small a value and training may take too long; too large a value and we may jump past the minimum. Intuitively, we want to take larger steps where the curvature of the loss function is shallow and smaller steps where it is steep. Thus using the quantities that describe the curvature, the entries of the Hessian, to adjust the step size is a highly effective way to keep the step size appropriate. 

Training Debugging

After developing an intuition for how to address problems that may arise during training, I began to train the model I defined earlier. During this testing phase I used synthetic data I generated. Synthetic data generation is a convenient technique for testing model convergence because I define the relationships between values and their uncertainty, so I can test whether the model correctly learns those relationships. Below is an illustration of the data I generated:

[Figures: synthetic prices versus tariff rates; synthetic import volumes versus tariff rates]

The first plot illustrates that as tariff rates decrease, prices also decrease. The second illustrates the increase in import volumes as tariff rates decrease. Note that this is admittedly a very simple, almost linear relationship and probably does not need the uncertainty and nonlinear modeling of a BNN...but we want to start with something simple! Now let us see if the model learns this relationship.


I started with a relatively simple training loop:


import jax
import jax.numpy as jnp
import numpyro
from jax import random, value_and_grad
from numpyro.infer import SVI, TraceMeanField_ELBO


def train(ACC, model, guide, X_train, Y_train, num_steps=5000, learning_rate=1e-3, seed=0):
    N = len(Y_train)
    rng_key = random.PRNGKey(seed)
    optimizer = numpyro.optim.Adam(learning_rate)
    K = 10
    elbo = TraceMeanField_ELBO(num_particles=K)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(rng_key, X_train, Y_train)
    max_grad_norm = 100.0

    print("\nStarting training loop...")
    params = svi.get_params(svi_state)
    opt_state = optimizer.init(params)

    for step in range(num_steps):
        rng_key, rng_loss, rng_noise = jax.random.split(rng_key, 3)

        def loss_fn(params, k):
            raw_loss = elbo.loss(k, params, model, guide, X_train, Y_train)
            return raw_loss

        loss_val, grads = value_and_grad(loss_fn)(params, rng_loss)

        # Replace non-finite gradients, then clip to a maximum global norm
        grads = jax.tree_util.tree_map(
            lambda g: jnp.nan_to_num(g, nan=0.0, posinf=1e3, neginf=-1e3), grads)
        grads, _ = clip_gradients(grads, max_grad_norm)

        optimizer = numpyro.optim.Adam(learning_rate)
        opt_state = optimizer.update(grads, opt_state)
        params = optimizer.get_params(opt_state)

        if step % 100 == 0:
            per_obs_ll = -loss_val / (10 * len(Y_train))
            print(f"Step {step}, Loss: {loss_val:.2f}, Per-obs LL: {per_obs_ll:.3f}")

        if is_loss_acceptable(loss_val, K, N, ACC):
            break

    print("Training complete.")
    params = optimizer.get_params(opt_state)
    return params, svi, rng_key


This simple training loop was sufficient to learn the relationship between tariffs, prices, and import volumes (the first learned function in Figure 0) with ACC = 0.75, and sufficient to learn the second function in Figure 0 with ACC = 0.72, so for this particular problem I did not have to employ any of the techniques for taming divergence.

To visualize the trained model, I queried my new functions to discover their predictions for prices. I created a policy in the form of an array of tariffs increasing over the next 20 time steps and used the output of each time step as the input to the following query (a sketch of this rollout is shown below). I then plotted the results. The mean is a hard line and the variance is a faded band around the mean:
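
Here is a sketch of how such a rollout might look for the price function using NumPyro's Predictive utility. The feature ordering follows f_price in Figure 0, and the variable names (tariff_policy, price_0, and so on) are assumptions rather than the actual script; the import-volume input would be rolled forward analogously with the second learned function.

import jax
import jax.numpy as jnp
from numpyro.infer import Predictive

predictive = Predictive(bnn_model, guide=bnn_guide, params=params, num_samples=200)

price, prev_tariff, volume = price_0, tariff_0, volume_0   # last observed (scaled) values
means, stds = [], []
for t, tariff in enumerate(tariff_policy):                  # the 20-step tariff policy
    x = jnp.array([[price, tariff, prev_tariff, volume, t]])
    samples = predictive(jax.random.PRNGKey(t), x)["obs"]   # predictive price draws
    means.append(float(samples.mean()))                     # plotted as the hard line
    stds.append(float(samples.std()))                       # plotted as the faded band
    price, prev_tariff = means[-1], tariff                  # feed the prediction forward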

[Figure: predicted prices over the 20-step tariff policy, mean line with uncertainty band]
The model correctly predicted the relationship of prices increasing when import volumes plummet. Because I control the data generators I can test more complex relationships and also test that the variance is learned correctly. Now that I have a model that learns from synthetic data, I can

  1. Add other economic indicators to this model
  2. Build out the second tier of the model that relates the economic indicators to prosperity
  3. Create more complex relationships in the test data to further test the model's ability
  4. Train the model on real world data
Once the fourth step is complete we can finally discuss how to use those results to shape policy. I will cover these steps in future installments.




