The Dual Form of Neural Networks Revisited: Connecting Test Time Predictions to Training Patterns via Spotlights of Attention

Hi Everyone,

In this article, we walk through the paper “The Dual Form of Neural Networks Revisited: Connecting Test Time Predictions to Training Patterns via Spotlights of Attention.” Neural networks are widely used for various tasks, yet their exact workings are still not fully understood. Multiple ways have been proposed to understand a trained network, such as activation and weight visualization. In this paper, the authors explore an alternative way of understanding a neural network using its dual form, where they ask the following question:

Is it possible to point out exactly which training samples are the original sources of a specific output?

We use some of the equations directly from the paper and discuss only some aspects of it in detail. The title summarizes the main contributions well, which can be listed as follows:

  1. Show that unnormalized dot attention is a special case of the linear layer, and establish the equivalence between the two.
  2. Write the gradient descent (GD) update equation of the weights, for both the linear layer and attention, over the whole course of training.
  3. Show that the test prediction based on the final weights can be written as a linear combination of the error signals generated during training.

The connection between the linear layer and unnormalized dot product attention

The unnormalized dot attention can be defined as Atten(K,V,\mathbf{q})=\sum_{t=1}^{T} \alpha_t v_t = VK^T\mathbf{q} .

where K=(k_1,..,k_T), V=(v_1,..,v_T), and \alpha_t=k_t^T\mathbf{q}. This equation takes a linear combination of the value vectors, with weights given by the inner product between each key vector and the query. Please note that the attention is written both in summation form and in matrix form.
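As a quick sanity check, here is a minimal NumPy sketch (with arbitrary dimensions of our own choosing, not from the paper) showing that the two forms give the same output:

```python
import numpy as np

T, d_k, d_v = 5, 4, 3                 # number of key/value pairs, key dim, value dim
K = np.random.randn(d_k, T)           # K = (k_1, ..., k_T), keys as columns
V = np.random.randn(d_v, T)           # V = (v_1, ..., v_T), values as columns
q = np.random.randn(d_k)              # query vector

# Summation form: sum_t alpha_t v_t with alpha_t = k_t^T q
out_sum = sum((K[:, t] @ q) * V[:, t] for t in range(T))

# Matrix form: V K^T q
out_mat = V @ K.T @ q

assert np.allclose(out_sum, out_mat)
```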

If we consider a linear layer with parameter W and input \mathbf{x}, the output of the linear layer is S_2(\mathbf{x}) = W\mathbf{x}. The same equation can be rewritten as \mathbf{y}=S_2(\mathbf{x}) = Atten(K,V,\mathbf{x}), where W=VK^T=\sum_{t=1}^{T} v_t k_t^T. That is, W is the sum of outer products of T pairs of vectors.
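Continuing the sketch above, a linear layer whose weight is the sum of outer products W = VK^T is exactly this attention operation:

```python
W = V @ K.T                           # W = sum_t v_t k_t^T
x = np.random.randn(d_k)              # input to the linear layer

y_linear = W @ x                      # S_2(x) = W x
y_attn = V @ (K.T @ x)                # Atten(K, V, x)

assert np.allclose(y_linear, y_attn)
```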

The dual form of linear layer trained by gradient descent

The linear layer is the core of many neural networks, such as fully connected layers, convolutional neural nets, LSTMs, and transformers. We focus on the parameter evolution of one linear layer (part of a bigger network). The loss function is L. The gradient descent update of the weights at the t-th step, for input x_t, can be written as follows:

 W_{t} = W_{t-1} - \eta \frac{\partial L}{\partial W_{t-1}} = W_{t-1} - \eta \frac{\partial L}{\partial \mathbf{y}_t}\mathbf{x}_t^T = W_{t-1} + e_t\mathbf{x}_t^T

where e_t = -\eta \frac{\partial L}{\partial \mathbf{y}_t} is the (learning-rate-scaled) backpropagation error computed as part of training. The final weight at the end of training (after T updates) is given by

 W=W_T =W_0+\sum_{t=1}^{T} e_t x_t^T
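To make the accumulation concrete, here is a minimal NumPy sketch (our own toy regression setup, not from the paper) that trains a single linear layer with plain gradient descent, stores the inputs and error signals, and checks that the final weight equals W_0 plus the sum of the outer products e_t x_t^T:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, T, eta = 4, 2, 200, 0.05

W_true = rng.standard_normal((d_out, d_in))   # hidden ground-truth map used to make toy targets
W0 = rng.standard_normal((d_out, d_in))       # initial weight W_0
W = W0.copy()

xs, es = [], []                               # training history: inputs x_t and error signals e_t
for t in range(T):
    x = rng.standard_normal(d_in)
    y_target = W_true @ x
    y = W @ x
    # Squared-error loss L = 0.5 * ||y - y_target||^2, so dL/dy = y - y_target
    e = -eta * (y - y_target)                 # e_t = -eta * dL/dy_t
    W = W + np.outer(e, x)                    # GD step: W_t = W_{t-1} + e_t x_t^T
    xs.append(x)
    es.append(e)

# Dual view of the final weight: W = W_0 + sum_t e_t x_t^T
W_dual = W0 + sum(np.outer(e, x) for e, x in zip(es, xs))
assert np.allclose(W, W_dual)
```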

Given a test sample x_{test} at the input of the layer, we compute the linear layer output as follows:

 y_{test}=Wx_{test} = W_0x_{test}+\sum_{t=1}^{T} e_t x_t^Tx_{test} = W_0x_{test}+\sum_{t=1}^{T} \beta_t^{test} e_t = W_0x_{test}+Atten(X,E,x_{test})

where \beta_t^{test}=x_t^Tx_{test}; the last equality follows from the Atten equation in the previous section. This equation is very interesting…! It connects the test time prediction y_{test} to the training inputs ( x_{t}) and the corresponding backpropagation signals e_t. The linear combination depends on the inner product between the test and training samples. This may seem trivial (it’s not..!), but establishing the connection explicitly is very interesting. We will see clear examples of this relationship in the coming sections.
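Continuing the NumPy sketch above, the test prediction can be recovered from W_0, the stored inputs, and the stored error signals alone, without ever touching the trained weight W:

```python
x_test = rng.standard_normal(d_in)

y_direct = W @ x_test                                           # prediction from the trained weight
betas = [x @ x_test for x in xs]                                # beta_t = x_t^T x_test
y_dual = W0 @ x_test + sum(b * e for b, e in zip(betas, es))    # W_0 x_test + Atten(X, E, x_test)

assert np.allclose(y_direct, y_dual)
```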

Important Remarks

  1. In the dual form, we need the inputs and the error signals for the entire course of training to compute the test sample’s output, whereas the direct linear layer needs just one matrix multiplication. So the decision on the test input depends on the whole course of training and not just on the final weights. This is kind of a strange result, and we will try to illustrate it with examples in later sections.
  2. Note that although the analysis is for a simple linear layer, it applies to all classes of neural nets that contain a linear layer (almost all types of nets). The article considers general inputs, losses, and backpropagation errors.

Some more exciting remarks are found in the paper.

Illustrations:

The paper illustrates the results using various interesting examples for different tasks. In this article, we stick to a simple two-dimensional example of concentric circles, shown in Fig. 1. The main idea is to visualize the connection between the training and the test data, as discussed above. We will see only a few examples; you can regenerate them and play around with different parameters in the colab notebook. If you have questions, please post them in the comment section.

Fig1: The plot of the training data. Different colors indicate different classes.

Let’s use a simple feedforward 3-layer network with 32 hidden dimensions and tanh activation. The network is trained for 50 epochs and achieves >90% accuracy (a straightforward problem anyway). The classification decision boundary and the test points are shown in Fig 2.
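Here is a minimal PyTorch sketch of this setup, assuming the concentric-circles data comes from scikit-learn’s make_circles; the optimizer, learning rate, and data-generation parameters are illustrative choices, not necessarily the ones used in the notebook. We also record the inputs to the final linear layer at every epoch, which we will need for the attention weights below.

```python
import torch
import torch.nn as nn
from sklearn.datasets import make_circles

# Toy concentric-circles data (two classes).
X, y = make_circles(n_samples=500, noise=0.05, factor=0.5, random_state=0)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# 3-layer feedforward network, 32 hidden units, tanh activations.
model = nn.Sequential(
    nn.Linear(2, 32), nn.Tanh(),
    nn.Linear(32, 32), nn.Tanh(),
    nn.Linear(32, 2),                     # final linear layer: the one analyzed in dual form
)
opt = torch.optim.SGD(model.parameters(), lr=0.5)
loss_fn = nn.CrossEntropyLoss()

feats = []                                # inputs to the final linear layer, one entry per epoch
for epoch in range(50):
    opt.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()
    with torch.no_grad():                 # record the final-layer inputs before the weights change
        feats.append(model[:-1](X).detach())
    opt.step()
```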

Fig 2: Classification decision boundary and the test data.

Let’s take a sample test point and compute the attention weights ( \alpha_{t}) associated with each training point. Please note that there are multiple \alpha_{t} for each training point (one for each epoch); we just consider the average value. We standardize these attention weights for plotting purposes. Figs. 3-5 show these weights for three test points: two that are well within one of the classes (the first two rows) and one that is close to the classification boundary (the last row).
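Continuing the PyTorch sketch above, one way to obtain such attention weights: for a chosen test point, take the inner product between its final-layer input and each training point’s final-layer input, average over epochs, and standardize. The specific test point and the plotting details are our own illustrative choices, not the paper’s exact recipe.

```python
x_test = torch.tensor([[0.0, 0.9]])                 # an illustrative test point
with torch.no_grad():
    h_test = model[:-1](x_test)                     # final-layer input of the test point, shape (1, 32)
    # beta_t = h_t^T h_test for every training point, at every epoch
    betas = torch.stack([h @ h_test.T for h in feats]).squeeze(-1)   # shape (epochs, n_train)
alpha = betas.mean(dim=0)                           # average over epochs: one weight per training point
alpha = (alpha - alpha.mean()) / alpha.std()        # standardize for plotting
# alpha[i] is the (standardized) attention weight linking training point i to this test point
```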

It is clear from the figures that the test predictions mainly depend on only a subset of the data. The attention weights are typically high around the test point, but this is not always the case. If the test point is well within the class boundary, then the final layer output depends mainly on the data from only one class (rows 1 and 2). Points near the classification boundary rely on training data from both classes.

Conclusion

The dual form of the linear layer connects the test predictions to the training data, and it can be used in any neural network that has a linear layer. This makes it possible to explain a test prediction in terms of the training data. The method can be computationally infeasible for large models or datasets because it needs to store the whole history of training. Understanding the evolution of \alpha_{t} throughout training could be an exciting direction.
