Explained like you are five

Temporal Fusion Transformers (TFTs) by hand.

Over the past couple of days, I've been running experiments using Temporal Fusion Transformers (aka TFTs). I haven't found any good resources that fully explore this model.

So, here's my attempt to explore Temporal Fusion Transformers layer by layer and understand how multi-horizon time series forecasting using transformers works.

In time series forecasting, the horizon refers to how far into the future you want to predict. It could be a day, a week, a month, or even a year. You've probably heard of short and long time horizons, right? That's what we mean when we use the word "horizon." TFT is designed for multi-horizon time series forecasting.

Multi-horizon time series forecasting refers to predicting multiple future time steps simultaneously. Instead of just forecasting the next immediate point, you're predicting several future points at once.

In single-step time series forecasting, the model predicts only the next time step (say, tomorrow). To forecast further ahead, your options are to train a separate model for each horizon, feed each prediction back in as input and forecast recursively, or modify your model to predict multiple steps. Multi-horizon time series forecasting takes the last route: a single forward pass produces every step of the horizon at once. A model trained on a 30-day horizon, for example, returns all 30 daily values together, and can be rolled forward 12 times to cover 360 days.
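To make the difference concrete, here is a toy sketch in plain Python. Both "models" are hypothetical stand-ins, not TFT: `one_step_model` predicts the mean of the last three values and is applied recursively (each prediction is fed back as input), while `multi_horizon_model` returns every future step from a single call by projecting the recent trend.

```python
from typing import Callable, List

def one_step_model(history: List[float]) -> float:
    """Hypothetical single-step model: predicts the mean of the last 3 values."""
    window = history[-3:]
    return sum(window) / len(window)

def recursive_forecast(history: List[float], horizon: int,
                       model: Callable[[List[float]], float]) -> List[float]:
    """Apply a single-step model repeatedly, feeding each prediction back in."""
    extended = list(history)          # copy so the caller's data is untouched
    preds = []
    for _ in range(horizon):
        nxt = model(extended)
        preds.append(nxt)
        extended.append(nxt)          # prediction errors can compound here
    return preds

def multi_horizon_model(history: List[float], horizon: int) -> List[float]:
    """Hypothetical multi-horizon model: all `horizon` steps come out of one call.
    Here it simply extends the recent linear trend."""
    slope = (history[-1] - history[-3]) / 2
    return [history[-1] + slope * (h + 1) for h in range(horizon)]

sales = [10.0, 12.0, 14.0]
print(recursive_forecast(sales, 3, one_step_model))
print(multi_horizon_model(sales, 3))  # [16.0, 18.0, 20.0]
```

The recursive version needs one model call per future step, while the multi-horizon version produces the whole horizon at once; TFT works in the second style.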

To understand time series forecasting, let's take the example of an ice cream shop. Say we are trying to forecast the footfall (the number of visitors) in the shop. We will call this the Target Variable.

To forecast footfall, we also have other data available to us. These helper variables are called covariates. They provide additional information that improves the accuracy of the forecast and our understanding of it, and they represent the various factors that influence the target variable. There are several kinds of covariates.

NB: Covariates can be categorical as well as numerical.

  1. Known Future Information or future covariates - These are events or conditions that are known in advance, such as holidays, promotions, or scheduled maintenance. For example, in our ice cream shop analogy, the dates of upcoming summer festivals would be crucial because they would help us forecast the boost in ice cream sales.
  2. Exogenous variables or past covariates - Exogenous variables are determined outside the model: nothing inside the model affects them, so the model treats them as independent inputs. They are called past covariates because we only observe them up to the present, so their future values are not available when we make a forecast. In our ice cream shop analogy, the daily temperature would be an exogenous time series: higher temperatures generally lead to increased ice cream sales.
  3. Static Metadata or static covariates - Imagine you're forecasting ice cream sales for different countries. Each country has its own sales pattern, so even though you're tracking sales over time, you are essentially dealing with a separate time series for each country. How do you distinguish these countries within your forecasting model? You add the country name as a covariate that never changes over time. This is where static covariates come in: they represent the unchanging characteristics that differentiate one time series from another.
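The three kinds of covariates differ mainly in which time steps they are available for. The sketch below is a hypothetical, hand-rolled split in plain Python; the `split_inputs` helper, the column names, the values, the window sizes and the country "Italy" are all made up for the ice cream example. The target and the past covariates are only usable in the lookback window, known future covariates cover both the lookback and the forecast window, and static covariates are a single value per series.

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ModelInputs:
    static: Dict[str, str]          # one value per series, e.g. the country
    past: Dict[str, list]           # observed only up to "today"
    known_future: Dict[str, list]   # known for the lookback AND forecast windows

def split_inputs(series: Dict[str, list], static: Dict[str, str],
                 past_cols: List[str], future_cols: List[str],
                 lookback: int, horizon: int) -> ModelInputs:
    """Hypothetical helper: slice one series into the three input groups."""
    return ModelInputs(
        static=static,
        # the target and past covariates are only usable for the lookback window
        past={c: series[c][:lookback] for c in past_cols},
        # known future covariates are usable across lookback + horizon
        known_future={c: series[c][:lookback + horizon] for c in future_cols},
    )

shop = {
    "footfall":    [120, 135, 150, 160, None, None],      # target: future unknown
    "temperature": [24.0, 26.5, 29.0, 31.0, None, None],  # past covariate
    "is_festival": [0, 0, 1, 1, 0, 1],                    # known future covariate
}
inputs = split_inputs(shop, {"country": "Italy"},
                      past_cols=["footfall", "temperature"],
                      future_cols=["is_festival"], lookback=4, horizon=2)
print(inputs.past["temperature"])          # [24.0, 26.5, 29.0, 31.0]
print(inputs.known_future["is_festival"])  # [0, 0, 1, 1, 0, 1]
```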

We are now going to go through every layer of the TFT model, compute each layer manually and understand what it does to the data. For this demonstration we will use the electricity dataset, which the Google Research team used to demo the model.

To obtain the electricity dataset, run the download_electricity function from the Google Research GitHub repository.

This dataset is not very clean, but fortunately for us Google has written a data formatter, and some other kind soul has isolated this part of the code base, translated it into PyTorch and put it in a Jupyter Notebook. You can access this code in the PyTorch TFT Model Notebook. Here are some of the columns in the formatted data:

  1. ID - There are 369 unique meters in our training data, each identified by its own ID. We need to forecast the power usage of each of these meters, and each meter behaves in its own way. As discussed above, this is the equivalent of having the country name as a covariate while forecasting customer footfall, so these IDs will be the static covariates for our forecast. The IDs are processed and label encoded as a categorical ID. To learn more about label encoding, refer to this blog post by GeeksForGeeks.
  2. Hours from start - The original hours-from-start values now live in the t column. The current hours from start column might look strange because it has been z-score normalized, so don't read too much into its actual values. Here is what the data creators think you need to know about the dataset:

Values are in kW of each 15 min. To convert values in kWh values must be divided by 4. Each column represent one client. Some clients were created after 2011. In these cases consumption were considered zero. All time labels report to Portuguese hour. However all days present 96 measures (24*4). Every year in March time change day (which has only 23 hours) the values between 1:00 am and 2:00 am are zero for all points. Every year in October time change day (which has 25 hours) the values between 1:00 am and 2:00 am aggregate the consumption of two hours.

  3. Power Usage - Read the blurb quoted above. You aren't actually 5.

  4. Day of Week - Briefly: this column was perfectly sane before processing; it has simply been z-score normalized.
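The transformations mentioned above (label encoding the meter IDs, converting 15-minute kW readings to kWh, and z-score normalizing a column) are easy to reproduce by hand. Below is a small NumPy sketch on made-up values; the meter names and readings are illustrative, not taken from the actual dataset, and the formatter in the notebook may fit its scalers differently (for example per meter, or on the training split only).

```python
import numpy as np

# Illustrative meter IDs and 15-minute readings in kW (not real data)
meter_ids = np.array(["MT_002", "MT_001", "MT_002", "MT_003"])
power_kw = np.array([14.2, 2.5, 13.9, 0.0])

# Label encoding: map each unique ID to an integer 0..K-1 (IDs sorted first)
classes, categorical_id = np.unique(meter_ids, return_inverse=True)

# kW over a 15-minute interval -> kWh consumed in that interval (divide by 4)
power_kwh = power_kw / 4.0

def z_score(x: np.ndarray) -> np.ndarray:
    """Subtract the mean and divide by the standard deviation."""
    return (x - x.mean()) / x.std()

hours_from_start = np.array([0.0, 1.0, 2.0, 3.0])
print(categorical_id)             # [1 0 1 2]
print(power_kwh)
print(z_score(hours_from_start))  # mean 0, standard deviation 1
```

This is why the normalized "hours from start" and "day of week" columns look odd: their values are centered on zero and scaled to unit spread.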


Layers of the model

  1. Embedding Layers

    • Static Embeddings: Encodes static (non-time-varying) categorical features.
    • Time-varying Embeddings: Encodes categorical time-varying features using a TimeDistributed wrapper for temporal input.
  2. Linear Layers for Time-varying Inputs

    • Applies TimeDistributed linear transformations to numerical time-varying inputs.
  3. Variable Selection Networks

    • Encoder Variable Selection: Uses a Gated Residual Network (GRN) to select relevant input features dynamically.
    • Decoder Variable Selection: Similar GRN structure to select decoder inputs.
  4. LSTM Layers

    • LSTM Encoder: Captures temporal dependencies in input sequences.
    • LSTM Decoder: Decodes the sequence for predictions.
  5. Post-LSTM Layers

    • GLU and BatchNorm: Gated Linear Unit (GLU) and BatchNorm layers normalize and gate LSTM outputs.
  6. Static Enrichment Layer

    • Enhances temporal features with static information using GRN.
  7. Position Encoding

    • Adds positional information to input sequences for sequence modeling.
  8. Multi-head Attention Layer

    • Captures dependencies across time steps with attention mechanisms.
  9. Post-Attention Layers

    • GLU, BatchNorm, and GRN layers refine the output of the attention mechanism.
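  10. Position-wise Feed-forward Layer

    • Applies a GRN with shared weights to each time step of the attention output, followed by a gated skip connection.
  11. Output Layer

    • Maps each future time step to the final forecast (quantile predictions in the TFT paper).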
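The Embedding Layers and the Linear Layers for Time-varying Inputs in the list above amount to table lookups and shared linear maps. The NumPy sketch below uses random weights and made-up sizes (`d_model = 4`, a batch of 2 series, 5 time steps), all assumptions for illustration. A categorical input becomes a row of an embedding table, and a "TimeDistributed" linear layer is simply the same weight matrix applied at every time step.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 4                  # hidden size (illustrative)
n_meters = 369               # number of distinct categorical IDs
batch, time_steps = 2, 5

# Static embedding: one learnable row per meter ID, looked up once per series
static_table = rng.normal(size=(n_meters, d_model))
meter_idx = np.array([7, 42])                   # label-encoded IDs of 2 series
static_emb = static_table[meter_idx]            # (batch, d_model)

# Time-varying categorical embedding: a hypothetical input with 24 categories,
# looked up independently at every time step
cat_table = rng.normal(size=(24, d_model))
cat_values = rng.integers(0, 24, size=(batch, time_steps))
cat_emb = cat_table[cat_values]                 # (batch, time, d_model)

# TimeDistributed linear layer for a numeric input such as power usage:
# the SAME weights and bias are applied at every time step
W = rng.normal(size=(1, d_model))
b = np.zeros(d_model)
power = rng.normal(size=(batch, time_steps, 1))
power_emb = power @ W + b                       # (batch, time, d_model)

print(static_emb.shape, cat_emb.shape, power_emb.shape)  # (2, 4) (2, 5, 4) (2, 5, 4)
```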
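The Gated Residual Network (GRN) and the Gated Linear Unit (GLU) appear in the Variable Selection Networks, the Static Enrichment Layer and the Post-Attention Layers above. Here is a NumPy sketch following the GRN equations in the original TFT paper: GRN(a, c) = LayerNorm(a + GLU(W1·ELU(W2·a + W3·c + b2) + b1)). Assumptions: random weights, equal input and output sizes (so the skip connection needs no projection), and a LayerNorm without learnable scale and shift; the notebook may use BatchNorm in places, as the list above notes. Passing a static context vector `c` is what the Static Enrichment Layer does.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def elu(x):
    # ELU with alpha = 1; np.minimum avoids overflow in the unused branch
    return np.where(x > 0, x, np.exp(np.minimum(x, 0)) - 1.0)

def layer_norm(x, eps=1e-5):
    """Normalize over the feature axis (no learnable scale/shift here)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def glu(x, W4, b4, W5, b5):
    """Sigmoid gate in (0, 1) decides how much of the linear branch passes."""
    return sigmoid(x @ W4 + b4) * (x @ W5 + b5)

def grn(a, p, c=None):
    """Gated Residual Network with an optional static context vector c."""
    pre = a @ p["W2"] + p["b2"]
    if c is not None:                      # static enrichment: add context
        pre = pre + c @ p["W3"]
    eta1 = elu(pre) @ p["W1"] + p["b1"]
    return layer_norm(a + glu(eta1, p["W4"], p["b4"], p["W5"], p["b5"]))

d = 4
params = {k: rng.normal(scale=0.5, size=(d, d)) for k in ["W1", "W2", "W3", "W4", "W5"]}
params.update({k: np.zeros(d) for k in ["b1", "b2", "b4", "b5"]})

temporal = rng.normal(size=(3, d))    # 3 time steps of d features
static_ctx = rng.normal(size=(1, d))  # one static context, broadcast over time
out = grn(temporal, params, c=static_ctx)
print(out.shape)  # (3, 4)
```

Because the GLU gate can be pushed toward zero, a GRN can learn to pass its input through almost unchanged, which lets TFT effectively skip processing it does not need.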
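The Variable Selection Networks item boils down to three steps: score every input variable, turn the scores into weights with a softmax, and take a weighted sum of per-variable transforms. This NumPy version is a simplification under stated assumptions: TFT computes both the scores and the per-variable transforms with GRNs (and conditions the scores on static context), whereas here a single linear layer and a tanh-linear layer stand in for them. The sizes and the three example variables are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)   # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

n_vars, d = 3, 4          # e.g. power usage, hours from start, day of week
time_steps = 5
# one d-dimensional embedding per variable per time step
xi = rng.normal(size=(time_steps, n_vars, d))

# 1) selection weights: flatten all variables at each step, one score per variable
W_sel = rng.normal(size=(n_vars * d, n_vars))
weights = softmax(xi.reshape(time_steps, n_vars * d) @ W_sel)   # (time, n_vars)

# 2) each variable gets its own transform (a GRN in TFT; tanh-linear here)
W_var = rng.normal(size=(n_vars, d, d))
processed = np.tanh(np.einsum("tvd,vde->tve", xi, W_var))       # (time, n_vars, d)

# 3) combine: weighted sum over variables at each time step
selected = np.einsum("tv,tve->te", weights, processed)          # (time, d)
print(selected.shape)  # (5, 4)
```

The weights in each row sum to 1, so they can be read as how much each variable contributed at that time step; this is the basis of TFT's variable importance.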
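The LSTM Layers item is a sequence-to-sequence setup: an encoder LSTM reads the lookback window, and a decoder LSTM continues from the encoder's final hidden and cell states over the forecast window. Below is a hand-rolled NumPy LSTM under illustrative assumptions (random weights, gate order i, f, o, g, zero initial states). In the TFT paper the initial states come from static covariate encoders, and PyTorch's nn.LSTM orders its gates as i, f, g, o, so these weights are not interchangeable with the notebook's.

```python
import numpy as np

rng = np.random.default_rng(2)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell(x, h, c, W, U, b):
    """One LSTM step. W: (d_in, 4H), U: (H, 4H), b: (4H,). Gate order: i, f, o, g."""
    z = x @ W + h @ U + b
    H = h.shape[-1]
    i = sigmoid(z[:H])            # input gate
    f = sigmoid(z[H:2 * H])       # forget gate
    o = sigmoid(z[2 * H:3 * H])   # output gate
    g = np.tanh(z[3 * H:])        # candidate cell update
    c = f * c + i * g             # new cell state
    h = o * np.tanh(c)            # new hidden state
    return h, c

def run_lstm(xs, h, c, params):
    """Unroll over time. xs: (time, d_in). Returns all hidden states and final (h, c)."""
    outs = []
    for x in xs:
        h, c = lstm_cell(x, h, c, *params)
        outs.append(h)
    return np.stack(outs), h, c

def init_params(d_in, H):
    return (rng.normal(scale=0.3, size=(d_in, 4 * H)),
            rng.normal(scale=0.3, size=(H, 4 * H)),
            np.zeros(4 * H))

d_in, H = 4, 8
enc_params = init_params(d_in, H)
dec_params = init_params(d_in, H)
past = rng.normal(size=(7, d_in))     # 7 lookback steps of selected past inputs
future = rng.normal(size=(3, d_in))   # 3 forecast steps of selected known inputs

h0 = np.zeros(H)                      # zeros here; TFT derives initial states
c0 = np.zeros(H)                      # from static covariate encoders
enc_out, h, c = run_lstm(past, h0, c0, enc_params)
dec_out, _, _ = run_lstm(future, h, c, dec_params)  # decoder starts from encoder state
temporal = np.concatenate([enc_out, dec_out])       # (10, 8), passed to later layers
print(temporal.shape)  # (10, 8)
```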
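The list above does not say which positional encoding the notebook uses. This sketch assumes the sinusoidal encoding from the original Transformer paper (Vaswani et al., 2017), added to the temporal features before attention: each position gets a unique pattern of sines and cosines at different frequencies.

```python
import numpy as np

def sinusoidal_positions(seq_len: int, d_model: int) -> np.ndarray:
    """Sinusoidal position encoding; d_model is assumed to be even."""
    pos = np.arange(seq_len)[:, None]               # (seq_len, 1)
    i = np.arange(0, d_model, 2)[None, :]           # even feature indices 2k
    angles = pos / np.power(10000.0, i / d_model)   # (seq_len, d_model / 2)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)                    # even features: sine
    pe[:, 1::2] = np.cos(angles)                    # odd features: cosine
    return pe

x = np.zeros((10, 8))                  # e.g. 10 time steps of LSTM output
pe = sinusoidal_positions(10, 8)
x_with_pos = x + pe
print(x_with_pos[0])  # position 0: sin(0) = 0 and cos(0) = 1 alternate
```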
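For the Multi-head Attention Layer, the TFT paper proposes an "interpretable" variant: each head has its own query and key projections, all heads share a single value projection, and the head outputs are averaged. A causal mask stops each time step from attending to later ones. This NumPy sketch follows that description with random weights and made-up sizes; the notebook's implementation may differ (for example, it may use standard multi-head attention).

```python
import numpy as np

rng = np.random.default_rng(3)

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, mask):
    """Scaled dot-product attention; mask is True where attending is NOT allowed."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, -1e9, scores)
    weights = softmax(scores)
    return weights @ v, weights

def interpretable_mha(x, Wq, Wk, Wv, Wh):
    """Per-head query/key projections, ONE shared value projection, averaged heads."""
    T = x.shape[0]
    causal = np.triu(np.ones((T, T), dtype=bool), k=1)  # block future time steps
    v = x @ Wv
    outs, attn = [], []
    for wq, wk in zip(Wq, Wk):
        o, w = attention(x @ wq, x @ wk, v, causal)
        outs.append(o)
        attn.append(w)
    return np.mean(outs, axis=0) @ Wh, np.mean(attn, axis=0)

T, d, n_heads, d_head = 6, 8, 2, 4
x = rng.normal(size=(T, d))                 # enriched temporal features
Wq = rng.normal(size=(n_heads, d, d_head))
Wk = rng.normal(size=(n_heads, d, d_head))
Wv = rng.normal(size=(d, d_head))           # shared across heads
Wh = rng.normal(size=(d_head, d))
out, avg_weights = interpretable_mha(x, Wq, Wk, Wv, Wh)
print(out.shape, avg_weights.shape)  # (6, 8) (6, 6)
```

Because the value projection is shared, averaging the head outputs equals applying the averaged attention weights to one set of values, so `avg_weights` fully describes how the layer mixes time steps.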
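In the TFT paper, the Output Layer produces quantile forecasts rather than a single number, and the model is trained with the quantile (pinball) loss. This NumPy sketch assumes the quantiles 0.1, 0.5 and 0.9 (the set used in the paper), random weights and made-up targets: a linear map turns each future time step's features into one value per quantile.

```python
import numpy as np

rng = np.random.default_rng(4)
quantiles = np.array([0.1, 0.5, 0.9])

# Output layer: the same linear map at each forecast step -> one value per quantile
d, horizon = 8, 3
decoder_features = rng.normal(size=(horizon, d))
W_out = rng.normal(size=(d, len(quantiles)))
b_out = np.zeros(len(quantiles))
preds = decoder_features @ W_out + b_out          # (horizon, n_quantiles)

def quantile_loss(y, y_hat, q):
    """Pinball loss: under-prediction costs q, over-prediction costs (1 - q)."""
    diff = y - y_hat
    return np.maximum(q * diff, (q - 1) * diff)

y_true = np.array([1.0, 2.0, 3.0])
loss = quantile_loss(y_true[:, None], preds, quantiles).mean()
print(preds.shape, float(loss))
```

With q = 0.9, under-predicting costs nine times as much as over-predicting, which pushes that output toward the upper end of the likely range.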