Using Transformers for Tabular Data: A Complete Guide with AWS SageMaker

Quick Note: If you’re already familiar with training on SageMaker or are short on time, feel free to jump directly to the Putting It All Together section. This section links to the GitHub repository containing the complete project source files. You can always return to this guide for a more detailed view of the project components.


In the vast landscape of machine learning, the use of tabular data for predictive modeling remains a cornerstone. However, the methods we employ to handle such data are constantly evolving. In this post, we train a TabTransformer—a deep learning architecture designed to harness the power of attention—to solve a binary classification problem.

At its core, the TabTransformer leverages the self-attention mechanism, which is fundamental to the Transformer architecture. This approach enables the TabTransformer to generate rich, context-aware embeddings for categorical variables, enhancing the model’s ability to capture intricate relationships in the data. Check out this online textbook and lecture to learn about self-attention and the Transformer.

This guide walks you through the steps for data processing, model training, and the deployment of a simple Flask-based endpoint for inference. To offer a bird’s-eye view of the project’s infrastructure, we utilize the following resources and tools:

  • AWS CloudFormation: Orchestrates a cohesive stack of AWS services
  • Aurora MySQL: Database for hyperparameter tuning results
  • SageMaker:
    • Notebook instance for model development and training
    • Processing Jobs: Preprocess datasets for training
    • Training Jobs: Train models using the preprocessed dataset
  • Secrets Manager: Safeguards database secrets
  • Elastic Container Registry & Docker:
    • Docker: Creates custom images for processing and serving the model
    • ECR: Stores custom Docker images
  • Hydra (configuration management tool):
    • AWS-specific settings: S3 bucket name, resource specs, training framework version
    • Training metadata: categorical/numerical features, unique categories, fixed hyperparameters
    • Additional training configs: computing resources, file paths, spot instance setup
  • Poetry (dependency management): Manages dependencies and packages the source code of the project
  • Optuna (hyperparameter optimization framework): Efficiently searches for optimal hyperparameters using various algorithms and samplers

Setting Up AWS Resources with CloudFormation

The following CloudFormation template is adapted from the aws-examples repository and sets up the AWS resources required for the project. Here is the architecture diagram for a visual representation:

CloudFormation Template

The following CloudFormation template (link) is written in YAML. Note that understanding every intricacy of this template is not necessary for completing the project.

By default, Optuna retains its hyperparameter optimization history in memory. For a more durable and persistent logging solution, one can employ relational databases to store the hyperparameter tuning results.

We will use Aurora MySQL, provisioned through Amazon Relational Database Service (RDS), as the storage backend that records the trials as Optuna navigates the search space during hyperparameter optimization.

  • The Aurora database is deployed within a dedicated Virtual Private Cloud (Amazon VPC), ensuring isolation from other AWS networks. Given that there’s no requirement for this database to establish connections with the public internet, it’s strategically deployed within a private subnet of the VPC. A subnet is a range of IP addresses in the VPC. A subnet must reside in a single Availability Zone. After adding subnets, we can deploy the database in the VPC.

  • For model development, an Amazon SageMaker notebook instance is initiated. This serves as a Jupyter environment and is located within the same VPC. As this instance demands internet accessibility during the development phase, it’s positioned in a public subnet.

  • To connect to the Aurora database, we need to configure the security group (the notebook instance’s virtual firewall) and set up the route table to define the correct network traffic path. In addition, a training container is launched within the VPC for specific training tasks.

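  • To support model training and hosting in containers, a NAT gateway is provisioned so that resources in the private subnets can make outbound connections to the internet without being reachable from it, and the security group is configured to allow outbound connections.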
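To make the subnet layout concrete, the following sketch carves a VPC address range into two public and two private subnets with Python’s standard ipaddress module. The 10.0.0.0/16 VPC range, the /24 subnet size, and the subnet names are assumptions for illustration; the actual CIDR blocks are whatever the CloudFormation template defines.

```python
import ipaddress


def plan_subnets(vpc_cidr: str = "10.0.0.0/16", new_prefix: int = 24) -> dict:
    """Split a VPC CIDR block into two public and two private subnets (illustrative layout)."""
    vpc = ipaddress.ip_network(vpc_cidr)
    # subnets() yields consecutive, non-overlapping blocks of the requested prefix length
    blocks = vpc.subnets(new_prefix=new_prefix)
    names = ["PublicSubnet1", "PublicSubnet2", "PrivateSubnet1", "PrivateSubnet2"]
    return {name: str(next(blocks)) for name in names}


for name, cidr in plan_subnets().items():
    print(f"{name:15} {cidr}")
```

The sketch only handles the address arithmetic; placing each subnet in a single Availability Zone is left to the template.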

The sections below provide some additional details on the most important components in this template.

Relational Database Service (RDS)

The template provisions an RDS Aurora cluster. Aurora is a MySQL and PostgreSQL-compatible relational database engine. The DBCluster and DBInstance resources define the actual database cluster (group of instances) and instance respectively.

The template also utilizes AWS Secrets Manager (DBSecret resource) to create and store the database credentials securely. This ensures that the database username and password are not hard-coded in the template, providing an added layer of security.

  • DBCluster: Defines the Aurora database cluster.
  • DBInstance: Defines the Aurora database instance within the cluster.
  • DBSecret: Creates and stores the database credentials in AWS Secrets Manager.
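The notebook and training code need these credentials to reach the database. A minimal sketch, assuming the secret is a JSON string with username, password, host, and (optionally) port keys; verify this against the secret your stack actually creates. It fetches the secret with boto3’s get_secret_value and builds the SQLAlchemy-style URL that Optuna accepts as its storage argument. The function names and the database name are hypothetical.

```python
import json
from urllib.parse import quote_plus


def fetch_secret_string(secret_id: str, region_name: str) -> str:
    """Return the raw secret string stored in AWS Secrets Manager (needs boto3 and AWS credentials)."""
    import boto3  # imported lazily so build_storage_url can be used without AWS access

    client = boto3.client("secretsmanager", region_name=region_name)
    return client.get_secret_value(SecretId=secret_id)["SecretString"]


def build_storage_url(secret_string: str, db_name: str) -> str:
    """Build a SQLAlchemy URL for Optuna's RDB storage using the PyMySQL driver."""
    secret = json.loads(secret_string)
    user = quote_plus(secret["username"])
    # Generated passwords can contain characters such as '@' or '/', which must be escaped in a URL
    password = quote_plus(secret["password"])
    port = int(secret.get("port", 3306))  # 3306 is the default MySQL port
    return f"mysql+pymysql://{user}:{password}@{secret['host']}:{port}/{db_name}"
```

The resulting URL can be passed to optuna.create_study(study_name=..., storage=url, load_if_exists=True) so that every trial is written to Aurora rather than kept in memory.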

To keep the template up to date, see the release notes.

SageMaker Notebook Instance and Lifecycle Configuration

The template provisions a SageMaker Notebook instance (NotebookInstance resource). This instance serves as the development environment for modeling and training.

Additionally, the NotebookLifecycleConfig resource defines a lifecycle configuration policy. This policy contains a base64-encoded shell script that, when decoded and executed, installs VS Code (a popular code editor) onto the SageMaker Notebook instance.

Feel free to add other custom shell commands. To encode the script, we can use the following website.

  • NotebookInstance: Defines the SageMaker Notebook instance.
  • NotebookLifecycleConfig: Contains the script that installs VS Code onto the notebook instance during its startup.
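Instead of relying on an external website, the lifecycle script can also be encoded locally. A minimal sketch using Python’s standard base64 module; the function names are hypothetical and the script body is a placeholder, not the VS Code installer from the template.

```python
import base64


def encode_lifecycle_script(script: str) -> str:
    """Base64-encode a shell script for a notebook lifecycle configuration."""
    return base64.b64encode(script.encode("utf-8")).decode("ascii")


def decode_lifecycle_script(content: str) -> str:
    """Reverse the encoding, e.g. to inspect a script embedded in the template."""
    return base64.b64decode(content).decode("utf-8")


# Placeholder start-up script; replace with your own commands
on_start = """#!/bin/bash
set -e
echo "Notebook instance is starting"
"""
encoded = encode_lifecycle_script(on_start)
print(encoded)
```

Within a CloudFormation template, the intrinsic function Fn::Base64 can perform the same encoding at deploy time.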

IAM Policy Configuration

Identity and Access Management (IAM) in AWS helps securely control access to resources in AWS. The NotebookExecutionRole resource defines an IAM role that the SageMaker Notebook instance assumes to access other AWS services. This role has policies attached that grant full access to SageMaker and S3 services.

  • NotebookExecutionRole: An IAM role with attached policies granting full access to SageMaker and S3.

These permissions can be scoped down to fine-grained, least-privilege policies if desired.
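As an example of tightening the role, here is a sketch of an IAM policy document that limits S3 access to the project’s prefix instead of granting full S3 access. The bucket and prefix mirror the placeholders in main.yaml; the exact actions a project needs may differ, and the SageMaker permissions the role also requires are not covered here.

```python
import json


def scoped_s3_policy(bucket: str, prefix: str) -> dict:
    """IAM policy document granting S3 access only under one bucket prefix."""
    prefix = prefix.strip("/")
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Listing is a bucket-level action, so restrict it with the s3:prefix condition key
                "Sid": "ListProjectPrefix",
                "Effect": "Allow",
                "Action": ["s3:ListBucket"],
                "Resource": f"arn:aws:s3:::{bucket}",
                "Condition": {"StringLike": {"s3:prefix": [prefix, f"{prefix}/*"]}},
            },
            {
                # Object-level actions are restricted through the object ARN
                "Sid": "ReadWriteProjectObjects",
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject", "s3:DeleteObject"],
                "Resource": f"arn:aws:s3:::{bucket}/{prefix}/*",
            },
        ],
    }


print(json.dumps(scoped_s3_policy("your-s3-bucket", "s3-key-for-project"), indent=2))
```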

VPC and Security Group

The template also sets up a new VPC with two public and two private subnets, and configures route tables for them.

Security groups act as virtual firewalls to control inbound and outbound traffic. The RDSSecurityGroup and SageMakerSecurityGroup resources define security groups for the RDS database and the SageMaker Notebook instance, respectively.

  • VPC, PublicSubnet*, PrivateSubnet*: Define the VPC and associated subnets.
  • RDSSecurityGroup: Defines the security group for the RDS database.
  • SageMakerSecurityGroup: Defines the security group for the SageMaker Notebook instance.

Stack Creation

The template linked above can be saved as optuna_template.yaml and uploaded to S3. A stack can then be created from it, for example through the CloudFormation console or the AWS CLI.

With this CloudFormation setup, the infrastructure required for the project will be fully provisioned and properly secured, ensuring a smooth and efficient development and deployment process.

Note: Some values, such as DBInstanceType and SageMakerInstanceType, are parameterized in the template, which means they can be adjusted in the console when creating the stack. The stack name will be referenced later during model training, so make a note of it.
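Stack creation can also be scripted from Python. A minimal sketch using boto3’s CloudFormation client; the helper names, the template URL, and the parameter values are hypothetical, and both IAM capabilities are acknowledged because the template creates an IAM role.

```python
def create_stack_request(stack_name: str, template_url: str, parameters: dict) -> dict:
    """Assemble keyword arguments for the CloudFormation CreateStack API."""
    return {
        "StackName": stack_name,
        "TemplateURL": template_url,
        "Parameters": [
            {"ParameterKey": key, "ParameterValue": value} for key, value in parameters.items()
        ],
        # The template creates an IAM role, which CloudFormation only does after explicit acknowledgement
        "Capabilities": ["CAPABILITY_IAM", "CAPABILITY_NAMED_IAM"],
    }


def create_stack(request: dict, region_name: str) -> str:
    """Create the stack and block until it finishes (needs boto3 and AWS credentials)."""
    import boto3  # imported lazily so create_stack_request works without AWS access

    cfn = boto3.client("cloudformation", region_name=region_name)
    stack_id = cfn.create_stack(**request)["StackId"]
    # The waiter polls until CREATE_COMPLETE and raises if creation fails or rolls back
    cfn.get_waiter("stack_create_complete").wait(StackName=request["StackName"])
    return stack_id


# Hypothetical template location and parameter values
request = create_stack_request(
    stack_name="optuna-stack",
    template_url="https://your-s3-bucket.s3.amazonaws.com/optuna_template.yaml",
    parameters={"DBInstanceType": "db.r5.large", "SageMakerInstanceType": "ml.t3.medium"},
)
```

Calling create_stack(request, region_name=...) then returns the stack ID once the stack is ready; the stack name matches stack_name in main.yaml.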

Setting Up Project in the SageMaker Notebook Instance

Every SageMaker notebook instance comes equipped with a dedicated storage volume. To set up our project, we’ll begin by creating a shell script within this storage. From this point on, you can use either JupyterLab (provided by SageMaker) or VS Code, which we installed via the lifecycle configuration above.

Open the terminal, navigate to the SageMaker directory, and create a shell script:

$ cd /home/ec2-user/SageMaker
# You can use any text editor of your choice
$ nano

Copy and paste the following commands into the script:


# Activate conda environment
source ~/anaconda3/etc/profile.d/ && conda activate tensorflow2_p310 && pip3 install poetry

# This creates a new project with directories src and tests and a pyproject.toml file
poetry new income-classification --name src 
# Use sed to update the Python version constraint in pyproject.toml
sed --in-place 's/python = "^3.10"/python = ">=3.10, <3.12"/' income-classification/pyproject.toml 

# Install project, test, notebook dependencies
cd income-classification
poetry add "polars==0.18.15" "tensorflow==2.13.0" "tensorflow-io==0.32.0" "hydra-core==1.3.2" "boto3==1.26.131" "optuna==3.1.0" "s3fs==2023.6.0" "pymysql==1.1.0"
poetry add "pytest==7.4.2" --group test
poetry add "scikit-learn==1.3.1" "ipykernel==6.25.2" "ipython==8.15.0" "kaleido==0.2.1" "matplotlib==3.8.0" --group notebook

poetry install

The script above accomplishes the following:

  • Activates the pre-installed tensorflow2_p310 conda environment and installs Poetry

  • Creates a Poetry-managed project named income-classification

  • Modifies the Python version constraint in pyproject.toml so that it stays within the Python versions TensorFlow 2.13 supports (at the time of writing, the Python version of the tensorflow2_p310 conda environment is 3.10)

  • Installs dependencies into separate groups:

    • Packages used for training
    • Packages used for testing
    • Packages used in the Jupyter notebook for interactive development

    Note: We employ exact versioning (==) for all dependencies. This approach isn’t about future-proofing the package. Instead, the objective of packaging the training code is to firmly lock in the set of dependencies that have proven to work, ensuring consistency during testing, training, and notebook usage.

  • Installs the src package

Run the script:

$ bash

To confirm that the src package has been installed in the conda environment:

$ source activate tensorflow2_p310
$ conda list src
# Name                    Version                   Build  Channel
src                       0.1.0                    pypi_0    pypi

The pyproject.toml file should resemble:

[tool.poetry]
name = "src"
version = "0.1.0"
description = ""
authors = ["Yang Wu"]
readme = ""

[tool.poetry.dependencies]
python = ">=3.10, <3.12"
polars = "0.18.15"
tensorflow = "2.13.0"
tensorflow-io = "0.32.0"
hydra-core = "1.3.2"
boto3 = "1.26.131"
optuna = "3.1.0"
s3fs = "2023.6.0"
pymysql = "1.1.0"

[tool.poetry.group.test.dependencies]
pytest = "7.4.2"

[tool.poetry.group.notebook.dependencies]
scikit-learn = "1.3.1"
ipykernel = "6.25.2"
ipython = "8.15.0"
kaleido = "0.2.1"
matplotlib = "3.8.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Checkpoint I

The income-classification directory should now be organized as follows:

├── poetry.lock
├── pyproject.toml
├── src
│   └──
└── tests

Configuration Management with Hydra

Managing configurations for machine learning projects can be cumbersome, especially as projects grow and involve multiple components. Hardcoding configurations directly into scripts or notebooks can lead to difficulties in maintenance, collaboration, and deployment. Hydra is a framework that simplifies complex configurations in applications, making them easier to manage, reuse, and override.

Why Use Hydra?

  • Centralization: All configurations are centralized in one or more YAML files, making it easier to track and make changes.

  • Flexibility: Enables easy override of configurations from the command line.

  • Structure: Organizes configurations hierarchically, making it simple to define and override specific subsets of configurations.

  • Reproducibility: By tracking configurations, it ensures consistent setups across different runs, aiding in reproducibility.

Main Configuration: main.yaml

The main.yaml serves as the central configuration file for the project, providing a single source of truth for various project settings, such as paths, AWS resources, metadata related to raw data, and more.

  • AWS Config: Defines the S3 bucket, key, model directory, output paths, and other AWS-specific configurations.
  • Spot Training: Configurations for using AWS Spot Instances, including max runtime, retries, and checkpoint locations.
  • File System: Directories for source code, notebooks, and containers.
  • Meta Data: URLs for training and test data, CSV headers, target labels, and other metadata.
  • Optuna: Configuration specific to Optuna hyperparameter optimization, including the stack name.

Additionally, the main.yaml configuration can point to group configurations stored as separate YAML files. In the provided configuration file below, there is one group configuration called tf_keras. Navigate to the src directory and create a config directory:

$ cd src && \
  mkdir -p config && \
  cd config && \
  nano main.yaml

This configuration file should be adapted to your use case:

# AWS config
s3_bucket: your-s3-bucket
s3_key: s3-key-for-project
model_dir: /opt/ml/model
output_path: s3://your-s3-bucket/s3-key-for-project/models
code_location: s3://your-s3-bucket/s3-key-for-project/code
volume_size: 30
# Spot training
use_spot_instances: true
max_run: 86400
max_retry_attempts: 2
checkpoint_s3_uri: s3://your-s3-bucket/s3-key-for-project/checkpoints
# File system
src_dir_path: /home/ec2-user/SageMaker/income-classification/src
notebook_dir_path: /home/ec2-user/SageMaker/income-classification/notebooks
docker_dir_path: /home/ec2-user/SageMaker/income-classification/docker
# Meta data for ingestion and uploading to s3
processing_job_output: /opt/ml/processing
csv_header:
  - age
  - workclass
  - fnlwgt
  - education
  - education_num
  - marital_status
  - occupation
  - relationship
  - race
  - gender
  - capital_gain
  - capital_loss
  - hours_per_week
  - native_country
  - income_bracket
target: income_bracket
validation_size: 0.2
random_seed: 12
header: false
# Optuna
stack_name: optuna-stack
defaults:
  - _self_
  - tf_keras: tf_keras

Group Configuration: tf_keras.yaml

Group configurations allow for the modular organization of related configurations. For instance, we might have different group configurations for various machine learning frameworks or algorithms. The tf_keras.yaml contains configurations specific to a Keras implementation of the model:

  • AWS: AWS-specific configurations for the TensorFlow framework, like instance types, ECR repository, and base job name.
  • Training Metadata: Includes settings for the number of epochs, batch size, deterministic behavior, and other training-related configurations.
  • Feature Metadata: Specifies the categorical and numerical features, their vocabularies, and default values.

This approach of separating configurations into main and group files ensures modularity. Imagine also implementing models in PyTorch; we can have a separate group configuration (pytorch.yaml) for PyTorch-specific settings, while the main configurations remain in main.yaml and are shared by all group configurations.

Create the tf_keras subdirectory under config:

$ mkdir -p tf_keras && \
  cd tf_keras && \
  nano tf_keras.yaml

This configuration should be modified if necessary (e.g., training script name, training job name, instance type, ECR repository name, etc.):

framework_version: '2.13.0'
py_version: py310
ecr_repository: name-of-your-ecr-repository

instance_type: ml.p3.8xlarge
inference_instance_type: ml.c5.xlarge
instance_count: 1

train_base_job_name: tab-transformer-hpo
preprocess_base_job_name: tab-transformer-preprocess
endpoint_name: tf-keras-endpoint
model_name: tf-keras-model
study_name: tf_keras_hpo

# Stop training if the loss is not improving in this many epochs
patience: 3
# An int specifying the number of times this tf.data.Dataset is repeated (if None, it repeats indefinitely)
num_epochs: 1
# If set to False, the transformation is allowed to yield elements out of order to trade determinism for performance
deterministic: false
# In distributed training, this can be multiplied by the number of replicas to get a global batch size, so that each of the 'x' GPUs processes (global batch size / x) = batch_size samples
batch_size: 32
na_value: '?'
num_oov_indices: 0

  - ' <=50K'
  - ' >50K'
weight_feat: fnlwgt
  - age
  - education_num
  - capital_gain
  - capital_loss
  - hours_per_week
    - ' ?'
    - ' Federal-gov'
    - ' Local-gov'
    - ' Never-worked'
    - ' Private'
    - ' Self-emp-inc'
    - ' Self-emp-not-inc'
    - ' State-gov'
    - ' Without-pay'
    - ' 10th'
    - ' 11th'
    - ' 12th'
    - ' 1st-4th'
    - ' 5th-6th'
    - ' 7th-8th'
    - ' 9th'
    - ' Assoc-acdm'
    - ' Assoc-voc'
    - ' Bachelors'
    - ' Doctorate'
    - ' HS-grad'
    - ' Masters'
    - ' Preschool'
    - ' Prof-school'
    - ' Some-college'
    - ' Divorced'
    - ' Married-AF-spouse'
    - ' Married-civ-spouse'
    - ' Married-spouse-absent'
    - ' Never-married'
    - ' Separated'
    - ' Widowed'
    - ' ?'
    - ' Adm-clerical'
    - ' Armed-Forces'
    - ' Craft-repair'
    - ' Exec-managerial'
    - ' Farming-fishing'
    - ' Handlers-cleaners'
    - ' Machine-op-inspct'
    - ' Other-service'
    - ' Priv-house-serv'
    - ' Prof-specialty'
    - ' Protective-serv'
    - ' Sales'
    - ' Tech-support'
    - ' Transport-moving'
    - ' Husband'
    - ' Not-in-family'
    - ' Other-relative'
    - ' Own-child'
    - ' Unmarried'
    - ' Wife'
    - ' Amer-Indian-Eskimo'
    - ' Asian-Pac-Islander'
    - ' Black'
    - ' Other'
    - ' White'
    - ' Female'
    - ' Male'
    - ' ?'
    - ' Cambodia'
    - ' Canada'
    - ' China'
    - ' Columbia'
    - ' Cuba'
    - ' Dominican-Republic'
    - ' Ecuador'
    - ' El-Salvador'
    - ' England'
    - ' France'
    - ' Germany'
    - ' Greece'
    - ' Guatemala'
    - ' Haiti'
    - ' Holand-Netherlands'
    - ' Honduras'
    - ' Hong'
    - ' Hungary'
    - ' India'
    - ' Iran'
    - ' Ireland'
    - ' Italy'
    - ' Jamaica'
    - ' Japan'
    - ' Laos'
    - ' Mexico'
    - ' Nicaragua'
    - ' Outlying-US(Guam-USVI-etc)'
    - ' Peru'
    - ' Philippines'
    - ' Poland'
    - ' Portugal'
    - ' Puerto-Rico'
    - ' Scotland'
    - ' South'
    - ' Taiwan'
    - ' Thailand'
    - ' Trinadad&Tobago'
    - ' United-States'
    - ' Vietnam'
    - ' Yugoslavia'
# Default values for numerical and categorical features
  - 0.0
  - NA

Checkpoint II

Before moving on to modeling, we can check the current directory tree, which should resemble:

├── poetry.lock
├── pyproject.toml
├── src
│   ├── config
│   │   ├── main.yaml
│   │   └── tf_keras
│   │       └── tf_keras.yaml
│   └──
└── tests

Processing Job

Efficient and accurate data processing is pivotal for a machine learning project. By leveraging Amazon SageMaker’s processing jobs, we can seamlessly ingest, clean, split, and upload our data to S3 in preparation for model training.

The diagram below provides a visual representation of how SageMaker orchestrates a processing job. SageMaker takes our processing script, retrieves our data from S3 (if applicable), and then deploys a processing container. This container image can be a built-in SageMaker image or a custom one we provide. The advantage of processing jobs is that Amazon SageMaker handles the underlying infrastructure, ensuring resources are provisioned only for the duration of the job and then reclaimed afterward. Upon completion, the output of the processing job is stored in the specified Amazon S3 bucket.

Two additional resources to learn about SageMaker processing jobs:

Processing Entry Script

Create a processing script (link) in src. This script performs the following steps:

  • Initialization and Logger Setup: Sets up a logger to monitor progress and fetches configurations from main.yaml using Hydra.
  • Data Loading: Retrieves the training and test datasets from their respective URLs, ensuring the correct headers are applied.
  • Ad-hoc Data Cleaning: Standardizes labels in the test dataset by removing any trailing dots, ensuring label consistency across datasets.
  • Validation Split: Uses a stratified shuffle split to segment the training dataset, thereby creating a validation subset. The dimensions of all three datasets (train, val, and test) are then logged.
  • Target Distribution Reporting: Calculates and logs the class distribution of the target in all three datasets.
  • Saving Processed Data: The preprocessed datasets are saved as CSV files in the designated output directory, unless the script is in test mode.

It’s important to note that the script heavily relies on the configurations from the main.yaml file (e.g., processing_job_output, csv_header, train_data_url). Any modifications to this configuration file will automatically be reflected in the processing script, eliminating the need to manually track these variables throughout the script, thus reducing potential errors.
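The cleaning, splitting, and reporting steps can be sketched without any dependencies. This is not the project’s script, which relies on Polars and scikit-learn (the packages installed in the processing image); the helper names are hypothetical, the data are plain Python lists, and the per-class rounding may allocate a sample differently from scikit-learn. The split parameters mirror validation_size: 0.2 and random_seed: 12 from main.yaml.

```python
import random
from collections import Counter, defaultdict


def clean_labels(labels: list) -> list:
    """Drop trailing dots so ' >50K.' in the test file matches ' >50K' in the training file."""
    return [label.rstrip(".") for label in labels]


def stratified_split(labels: list, val_size: float, seed: int) -> tuple:
    """Return (train_idx, val_idx) so each class keeps roughly the same share in both splits."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    train_idx, val_idx = [], []
    for label in sorted(by_class):  # fixed class order keeps the split reproducible
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = round(len(indices) * val_size)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


def class_distribution(labels: list) -> dict:
    """Share of each class, as reported for the train, val, and test sets."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: count / total for label, count in sorted(counts.items())}


labels = [" <=50K"] * 15 + [" >50K"] * 5
train_idx, val_idx = stratified_split(labels, val_size=0.2, seed=12)
print(len(train_idx), len(val_idx))  # 16 4
print(class_distribution([labels[i] for i in val_idx]))
```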

Custom Docker Image for Processing

To facilitate a streamlined data processing workflow, we’ll create a custom Docker image. In this project, we will not only build an image to process our data but also build another image to serve our model once trained. To manage the creation and deployment of these images, we will create a parameterized bash script that automates steps such as building the custom images and pushing them to Amazon Elastic Container Registry (ECR). This step requires a private ECR repository to store the Docker images; to create one, see the official AWS documentation. The bash script, shown below, takes three primary arguments:

  1. Docker image tag: groups different images and acts as an alias for the image ID
  2. Mode (‘preprocess’ or ‘serve’): either builds the processing image or the serving image
  3. ECR repository name: name of the ECR repository we created
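For reference, the image name the script pushes follows the standard private ECR format, <account_id>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>. The sketch below mirrors the script’s argument checks in Python, which can be handy when passing the same image URI to a SageMaker processor from the notebook. The serve.Dockerfile name, the helper names, and the example account, region, repository, and tag are assumptions.

```python
# Mode -> Dockerfile; "serve.Dockerfile" is an assumed name for the serving image
DOCKERFILES = {"preprocess": "preprocess.Dockerfile", "serve": "serve.Dockerfile"}


def ecr_image_uri(account_id: str, region: str, repository: str, tag: str) -> str:
    """Fully qualified name of an image in a private ECR repository."""
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


def resolve_build(mode: str, account_id: str, region: str, repository: str, tag: str) -> tuple:
    """Validate the script's three arguments and return (dockerfile, image_uri)."""
    if not tag or not repository:
        raise ValueError("Both the image tag and the ECR repository name are required.")
    if mode not in DOCKERFILES:
        raise ValueError("mode must be either 'serve' or 'preprocess'")
    return DOCKERFILES[mode], ecr_image_uri(account_id, region, repository, tag)


# Hypothetical account, region, repository, and tag
print(resolve_build("preprocess", "123456789012", "us-east-1", "tab-transformer", "preprocess-v1"))
```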

Create a docker directory in the root directory of the project and add the following text files:

  1. Automate image building & pushing with This bash script includes steps to create and deploy custom Docker images.

# Always anchor the execution to the directory it is in, so we can run this bash script from anywhere
SCRIPT_DIR=$(python3 -c "import os; print(os.path.dirname(os.path.realpath('$0')))")

# Set BUILD_CONTEXT as the parent directory of SCRIPT_DIR
BUILD_CONTEXT=$(dirname "$SCRIPT_DIR")

# Check if arguments are passed, otherwise prompt
if [ "$#" -eq 3 ]; then
    image_tag=$1
    mode=$2
    ecr_repo=$3
else
    read -p "Enter the custom image tag name: " image_tag
    read -p "Serve or preprocess: " mode
    read -p "Enter the ECR repository name: " ecr_repo
fi

# Check if the image tag is provided where [-z string]: True if the string is null (an empty string)
if [ -z "$image_tag" ] || [ -z "$ecr_repo" ]; then
  echo "Please provide both the custom image tag name and the ECR repository name."
  exit 1
fi

# Choose Dockerfile based on mode
if [ "$mode" == "serve" ]; then
    dockerfile="serve.Dockerfile"  # assumed file name for the serving image
elif [ "$mode" == "preprocess" ]; then
    dockerfile="preprocess.Dockerfile"
else
    echo "Invalid mode specified, which must either be 'serve' or 'preprocess'."
    exit 1
fi

# Variables
account_id=$(aws sts get-caller-identity --query Account --output text)
region=$(aws configure get region)
registry="$account_id.dkr.ecr.$region.amazonaws.com"
image_name="$registry/$ecr_repo:$image_tag"

# Log in to ECR
aws ecr get-login-password --region "$region" | docker login --username AWS --password-stdin "$registry"

# Docker buildkit is required to use dockerfile specific ignore files
DOCKER_BUILDKIT=1 docker build \
    -f "$SCRIPT_DIR/$dockerfile" \
    -t "$image_name" \
    "$BUILD_CONTEXT"

docker push "$image_name"
  2. Write the Dockerfile preprocess.Dockerfile: This Dockerfile defines our custom image, installs the necessary dependencies, and sets up the processing script as its primary entry point. More details on the special naming convention can be found here.
FROM python:3.10.12-slim-bullseye

# Only copy files not listed in the Dockerfile-specific .dockerignore file
COPY ./src/ ./

RUN pip install polars==0.18.15 \
                scikit-learn==1.3.1 \
                hydra-core==1.3.2

# Ensure python I/O is unbuffered so log messages are immediate
# Disable the generation of bytecode '.pyc' files
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

ENTRYPOINT ["python3", ""]
  3. Define Docker ignore rules using preprocess.Dockerfile.dockerignore: To keep our Docker image lightweight, this file specifies which files or directories to exclude during the Docker build process. Note that we only copy the processing script and the configuration YAML files onto the image. More details on syntax can be found here.

Checkpoint III

By the end of the data processing step, the directory tree should resemble:

├── docker
│   ├──
│   ├── preprocess.Dockerfile
│   └── preprocess.Dockerfile.dockerignore
├── poetry.lock
├── pyproject.toml
├── src
│   ├── config
│   │   ├── main.yaml
│   │   └── tf_keras
│   │       └── tf_keras.yaml
│   ├──
│   └──
└── tests

Training Job

Amazon SageMaker provides a robust platform for training machine learning models at scale. The infrastructure revolves around the concept of training jobs. These jobs are essentially encapsulated environments wherein models are trained using the data, training algorithms, and compute resources we specify.

The diagram below, taken from AWS’s official documentation, offers a visual representation of how SageMaker orchestrates a training job. Once a training job is initiated, SageMaker handles the heavy lifting: it deploys the ML compute instances, applies the training code and dataset to train the model, and subsequently saves the model artifacts in the designated S3 bucket.

Key Aspects of a SageMaker Training Job:

  • Training Data: Stored in an Amazon S3 bucket, the training data should reside in the same AWS Region as the training job. In our case, this is the data output by the processing job.

  • Compute Resources: These are the machine learning compute instances (EC2 instances) that SageMaker provisions and manages for the duration of the training job. They are separate from the notebook instance, which is itself backed by an EC2 instance with a storage volume and pre-installed conda environments.

  • Output: Results from the training job, including model artifacts, are stored in a specified S3 bucket.

  • Training Code: The location of the training code is typically specified via an Amazon Elastic Container Registry path if we are using a SageMaker built-in algorithm. In this project, we will use our custom training code in the src package.

For this specific project, while SageMaker offers a plethora of built-in algorithms and pre-trained models, we opt for a more tailored approach by using custom code. Our choice of deep-learning framework is TensorFlow, one of the most popular and versatile deep learning frameworks available. We predominantly use Keras, a high-level API that offers a more intuitive and streamlined interface for building and training models.

A few additional resources for SageMaker training jobs:

  • For documentation on the SageMaker Python SDK’s TensorFlow API, see here.

  • Example usage with TensorFlow and the SageMaker Python SDK.

Local Mode with SageMaker’s Python SDK

With the SageMaker Python SDK, we can take advantage of the Local Mode feature. This powerful tool lets us create estimators, processors, and pipelines, then deploy them right in our local environment (SageMaker Notebook Instance). It’s an excellent way for us to test our deep learning training and processing scripts before transitioning them to SageMaker’s comprehensive training or hosting platforms.

Local Mode is compatible with any custom images we might want to use. To utilize Local Mode, we need to have Docker Compose V2 installed. We can use the installation guidelines from Docker. It’s crucial to ensure that our docker-compose version aligns with our docker engine installation. To determine a compatible version, refer to the Docker Engine release notes.

To check the compatibility of our Docker Engine with Docker Compose, run the following commands:

$ docker --version
$ docker-compose --version

After executing these, we should cross-reference these versions with those listed in the Docker Engine release notes to ensure compatibility. For reference, as of writing this tutorial, the versions on SageMaker notebooks are currently:

  • Docker: 20.10.23, build 7155243
  • Docker Compose: v2.21.0
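Reading the versions can be automated; deciding whether a given pair is compatible still requires the Docker Engine release notes. A minimal sketch with hypothetical helper names that parses the --version output and checks that Compose is V2, as Local Mode requires.

```python
import re
import subprocess


def parse_version(output: str) -> tuple:
    """Extract the first MAJOR.MINOR.PATCH triple from a `--version` output line."""
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", output)
    if match is None:
        raise ValueError(f"No version number found in: {output!r}")
    return tuple(int(part) for part in match.groups())


def read_version_output(command: list) -> str:
    """Run e.g. ['docker', '--version'] and return its standard output."""
    return subprocess.run(command, capture_output=True, text=True, check=True).stdout


def is_compose_v2(compose_output: str) -> bool:
    """Local Mode needs Docker Compose V2, i.e. a major version of 2."""
    return parse_version(compose_output)[0] == 2
```

On the notebook instance, is_compose_v2(read_version_output(["docker-compose", "--version"])) performs the check against the installed binary.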

If Local Mode fails, try switching back to an older version of docker-compose; see the following GitHub issues for more details:

# Select a compatible version
# Download docker compose based on version, kernel operating system (uname -s), and machine hardware (uname -m)
$ sudo curl --location${DOCKER_COMPOSE_VERSION}/docker-compose-`uname -s`-`uname -m` --output /usr/local/bin/docker-compose
# Make the Docker Compose binary executable
$ sudo chmod +x /usr/local/bin/docker-compose

Managed Spot Training

Another powerful feature of Amazon SageMaker is called Managed Spot Training, which allows us to train machine learning models using Amazon EC2 Spot instances. These Spot instances can be significantly cheaper compared to on-demand instances, potentially reducing the cost of training by up to \(90\%\).
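The 90% figure is a ceiling, not a guarantee. After a spot job finishes, one way to check the actual savings is to compare the two durations SageMaker records for it, TrainingTimeInSeconds and BillableTimeInSeconds, as returned by DescribeTrainingJob. A minimal sketch; the helper names are hypothetical.

```python
def spot_savings_percent(training_seconds: int, billable_seconds: int) -> float:
    """Savings relative to paying for the full training time at the on-demand rate."""
    if training_seconds <= 0:
        raise ValueError("training_seconds must be positive")
    return (1 - billable_seconds / training_seconds) * 100


def savings_for_job(job_name: str, region_name: str) -> float:
    """Read both durations of a finished training job (needs boto3 and AWS credentials)."""
    import boto3  # imported lazily so spot_savings_percent can be used without AWS access

    sm = boto3.client("sagemaker", region_name=region_name)
    job = sm.describe_training_job(TrainingJobName=job_name)
    return spot_savings_percent(job["TrainingTimeInSeconds"], job["BillableTimeInSeconds"])
```

For example, a job that trained for 4000 seconds but was billed for 1000 seconds saved 75%.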

Benefits of Using Managed Spot Training

  • Cost-Efficient: Spot instances can be much cheaper than on-demand instances, leading to substantial cost savings.

  • Managed Interruptions: Amazon SageMaker handles Spot instance interruptions; when checkpointing is configured, training resumes from the last saved checkpoint instead of starting over.

  • Monitoring: Metrics and logs generated during the training runs are readily available in Amazon CloudWatch.

To enable spot training, we need to specify the following parameters when launching the training job:

  • max_run: Represents the maximum time (in seconds) the training job is allowed to run.

  • max_wait: The maximum total time (in seconds) the spot training job may take, covering both the time spent waiting for Spot capacity and the training time itself. It must be equal to or greater than max_run.

  • max_retry_attempts: In the event of training failures, this parameter defines the maximum number of retry attempts.

  • use_spot_instances: Set this to True to use Spot instances for training. For on-demand instances, set this to False.

  • checkpoint_s3_uri: This is the S3 URI where training checkpoints will be saved, ensuring that in the event of interruptions, the training can be resumed from the last saved state.
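These parameters can be collected from the configuration in one place. A minimal sketch with a hypothetical helper; since main.yaml defines no max_wait, it defaults to max_run here, although a larger value leaves room for waiting on Spot capacity.

```python
from typing import Optional


def spot_training_kwargs(
    use_spot_instances: bool,
    max_run: int,
    max_retry_attempts: int,
    checkpoint_s3_uri: str,
    max_wait: Optional[int] = None,
) -> dict:
    """Collect the spot-related estimator arguments, enforcing max_wait >= max_run."""
    kwargs = {
        "use_spot_instances": use_spot_instances,
        "max_run": max_run,
        "max_retry_attempts": max_retry_attempts,
        "checkpoint_s3_uri": checkpoint_s3_uri,
    }
    if use_spot_instances:
        # Fall back to the smallest valid value when max_wait is not configured
        max_wait = max_run if max_wait is None else max_wait
        if max_wait < max_run:
            raise ValueError(f"max_wait ({max_wait}) must be >= max_run ({max_run})")
        kwargs["max_wait"] = max_wait
    return kwargs


# Values taken from main.yaml
print(spot_training_kwargs(True, 86400, 2, "s3://your-s3-bucket/s3-key-for-project/checkpoints"))
```

The resulting dictionary can be unpacked into the SageMaker Python SDK’s TensorFlow estimator alongside entry_point, role, framework_version, py_version, instance_type, and instance_count. For Local Mode, the SDK accepts instance_type="local" (or "local_gpu"); the spot settings only apply to jobs that run on SageMaker’s managed instances.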

The availability and potential interruption of spot instances are influenced by several factors including the type of instance (e.g., Multi-GPU, Single GPU, Multi-CPU), the geographical region, and the specific availability zone. For GPU-intensive tasks like training, there’s a possibility of encountering an ‘insufficient capacity error’. This happens when AWS lacks the requisite on-demand capacity for a particular Amazon EC2 instance type in a designated region or availability zone. It’s important to remember that capacity isn’t a fixed value; it fluctuates based on the time of day and the prevailing workloads within a given Region or Availability Zone.

To mitigate such capacity issues, there are several strategies we can adopt:

  • Consider switching to a different instance type that may have more available capacity.
  • Try changing to a different size within the same instance family, which might offer a balance between performance and availability.
  • Given that our CloudFormation template launches a notebook instance, another approach is to launch the instance using the desired type but specify subnets across more availability zones (e.g., the template above provisions subnets in two availability zones). This diversifies the launch attempts and may increase the likelihood of successful provisioning. Always cross-check that the SageMaker instance types are available in the chosen Region.
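The first two strategies can be automated by trying a ranked list of instance types until one launches. A minimal sketch: InsufficientCapacity and launch_with_fallback are hypothetical names, and launch stands in for whatever starts the job (for example, a wrapper around an estimator’s fit call) and translates the capacity error it surfaces into InsufficientCapacity.

```python
from typing import Callable, Sequence


class InsufficientCapacity(Exception):
    """Hypothetical stand-in for the capacity error your launch call surfaces."""


def launch_with_fallback(launch: Callable[[str], str], instance_types: Sequence[str]) -> tuple:
    """Try each instance type in order and return (instance_type, result) for the first success."""
    skipped = {}
    for instance_type in instance_types:
        try:
            return instance_type, launch(instance_type)
        except InsufficientCapacity as exc:
            skipped[instance_type] = str(exc)  # record why this type was skipped
    raise RuntimeError(f"No capacity for any candidate instance type: {skipped}")


def fake_launch(instance_type: str) -> str:
    """Simulated launcher: pretend the first choice has no capacity."""
    if instance_type == "ml.p3.8xlarge":
        raise InsufficientCapacity("no ml.p3.8xlarge capacity in this zone")
    return f"job-on-{instance_type}"


print(launch_with_fallback(fake_launch, ["ml.p3.8xlarge", "ml.p3.2xlarge", "ml.g5.12xlarge"]))
```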

TabTransformer Architecture

The TabTransformer architecture is built upon the Transformer, which became popular because of its state-of-the-art performance on natural language processing tasks. As detailed in the paper “TabTransformer: Tabular Data Modeling Using Contextual Embeddings”, the authors adapted the attention-based architecture for tabular data. This architecture is composed of three core components:

  1. Column Embedding Layer: This layer learns parametric embeddings for each categorical feature in the dataset. In the original paper, the continuous features are not embedded; they are simply normalized and concatenated with the transformed categorical feature embeddings before being fed into the final MLP. A related deep-learning architecture, the FT-Transformer, instead transforms both categorical and continuous features into embeddings.

  2. Stack of N Transformer Layers: These layers apply successive multi-head attention operations to transform the column embeddings into contextual embeddings. Stacking multiple Transformer layers yields a deeper and richer representation of the input data, and the number of layers N can be treated as a hyperparameter.

  3. Multi-Layer Perceptron (MLP): Positioned after the Transformer layers, a final MLP processes the combined features to produce predictions (or outputs for other downstream tasks). Note that this is separate from the position-wise feed-forward network inside each Transformer layer.
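To make the data flow concrete, here is a minimal pure-Python sketch of how the input to the final MLP is assembled, following the paper’s description: each categorical value is looked up in its column’s embedding table, the embeddings pass through the Transformer stack, the contextual embeddings are flattened, and the normalized continuous features are concatenated. The embedding tables, feature names, and the `transformer_stack` callable are illustrative assumptions; normalization is done with layer normalization here, and an identity function stands in for the N Transformer layers.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]


def layer_norm(values: Vector, eps: float = 1e-6) -> Vector:
    """Shift and scale a vector to zero mean and (approximately) unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def build_mlp_input(
    row: Dict[str, object],
    categorical_cols: List[str],
    continuous_cols: List[str],
    embedding_tables: Dict[str, Dict[object, Vector]],
    transformer_stack: Callable[[List[Vector]], List[Vector]],
) -> Vector:
    # 1. Column embedding layer: look up one vector per categorical column.
    column_embeddings = [embedding_tables[col][row[col]] for col in categorical_cols]
    # 2. The stack of N Transformer layers produces contextual embeddings.
    contextual = transformer_stack(column_embeddings)
    flat = [x for emb in contextual for x in emb]
    # Continuous features skip the Transformer: they are only normalized.
    continuous = layer_norm([float(row[col]) for col in continuous_cols])
    # 3. The concatenation is what the final MLP receives.
    return flat + continuous


def identity_stack(embeddings: List[Vector]) -> List[Vector]:
    """Placeholder for the N Transformer layers (returns its input unchanged)."""
    return embeddings


# Toy embedding tables with 2-dimensional embeddings (hypothetical values).
tables = {
    "workclass": {"private": [0.1, 0.2], "gov": [0.3, -0.1]},
    "education": {"hs": [0.0, 0.5], "ba": [-0.2, 0.4]},
}
row = {"workclass": "gov", "education": "ba", "age": 40, "hours_per_week": 50}
x = build_mlp_input(row, ["workclass", "education"], ["age", "hours_per_week"], tables, identity_stack)
print(len(x))  # 2 columns x 2 dims + 2 continuous features = 6
```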

Every Transformer layer in the architecture diagram is based on the original Transformer model. It comprises a multi-head self-attention mechanism followed by a position-wise feed-forward network. The architectural diagram from the TabTransformer paper is illustrated below:

Each transformer layer consists of:

  • Multi-Head Self-Attention: Allows simultaneous focus on different parts of the input.
  • Position-wise Feed-Forward Network: Processes each position in the input independently.
  • Addition (Skip Connection) & Layer Normalization: Applied after each multi-head self-attention sublayer and each position-wise feed-forward sublayer to help train deeper models.
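The structure of a single layer can be sketched in pure Python. This is a deliberately simplified, hypothetical illustration rather than the project’s Keras code: it uses one attention head, omits the learned query/key/value projections (they are the identity here), and takes the feed-forward weights as plain nested lists. It does follow the order of operations listed above: self-attention, a skip connection with layer normalization, the position-wise feed-forward network, and a second skip connection with layer normalization.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # one row per output unit


def layer_norm(v: Vector, eps: float = 1e-6) -> Vector:
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(w: Matrix, v: Vector) -> Vector:
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def self_attention(tokens: List[Vector]) -> List[Vector]:
    """Single-head scaled dot-product attention with identity Q/K/V projections."""
    d = len(tokens[0])
    outputs = []
    for query in tokens:
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in tokens]
        weights = softmax(scores)  # how strongly this token attends to every token
        outputs.append([sum(w * tok[j] for w, tok in zip(weights, tokens)) for j in range(d)])
    return outputs


def transformer_layer(tokens: List[Vector], w1: Matrix, w2: Matrix) -> List[Vector]:
    # Self-attention sublayer, then skip connection + layer norm.
    attended = self_attention(tokens)
    x = [layer_norm([a + b for a, b in zip(t, at)]) for t, at in zip(tokens, attended)]
    # Position-wise feed-forward: the same two-layer ReLU network applied to every token.
    ff = [matvec(w2, [max(0.0, h) for h in matvec(w1, t)]) for t in x]
    # Second skip connection + layer norm.
    return [layer_norm([a + b for a, b in zip(t, f)]) for t, f in zip(x, ff)]


# Two column embeddings of dimension 3 and small illustrative FFN weights (3 -> 4 -> 3).
tokens = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.0]]
w1 = [[0.2, -0.1, 0.0], [0.0, 0.3, 0.1], [-0.2, 0.1, 0.4], [0.1, 0.1, 0.1]]
w2 = [[0.1, 0.0, -0.1, 0.2], [0.0, 0.2, 0.1, -0.1], [0.3, -0.2, 0.0, 0.1]]
out = transformer_layer(tokens, w1, w2)
print([[round(v, 3) for v in tok] for tok in out])
```

Because every column embedding attends to every other, the output for one column depends on the values of all the other categorical columns in the row, which is what makes the embeddings "contextual".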

Focal Loss Function

In binary classification problems, particularly those with a skewed class distribution, standard loss functions like cross entropy can be dominated by the majority class, leading to suboptimal model performance. To address class imbalance in our dataset, we use the Focal Loss function, which was originally introduced for object detection, where the foreground and background classes are highly imbalanced.

Cross-Entropy Loss

We start from the traditional cross-entropy (CE) loss for binary classification, which is given by:

\[ \operatorname{CE}(p, y)= \begin{cases}-\log (p) & \text { if } y=1 \\ -\log (1-p) & \text { otherwise }\end{cases} \]

where:

  • \(y \in\{ \pm 1\}\) denotes the ground-truth class
  • \(p \in[0,1]\) is the model’s estimated probability for the positive class \(y=1\)

To simplify notation, the original paper introduces the term \(p_t\) defined as:

\[ p_t= \begin{cases}p & \text { if } y=1 \\ 1-p & \text { otherwise }\end{cases} \]

This allows us to concisely express the cross-entropy loss as:

\[\operatorname{CE}(p, y)=\operatorname{CE}\left(p_t\right)=-\log \left(p_t\right)\]

Alpha-balanced Cross-Entropy

To counter the domination of the majority class in the loss computation, we can employ a weighting factor \(\alpha \in[0,1]\) for the positive class and its complement \(1-\alpha\) for the negative class. Analogous to \(p_t\), we define \(\alpha_t=\alpha\) if \(y=1\) and \(\alpha_t=1-\alpha\) otherwise. In practice, \(\alpha\) can be set from the inverse class frequency or treated as a hyperparameter optimized via cross-validation. The \(\alpha\)-balanced cross-entropy (CE) loss becomes:

\[ \operatorname{CE}\left(p_t\right)=-\alpha_t \log \left(p_t\right) \]

This alpha-balanced formulation ensures that each class’s contribution to the loss is adjusted based on its prevalence, offering a direct mechanism to tackle class imbalance.
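The \(p_t\) and \(\alpha_t\) notation translates directly into code. Below is a minimal pure-Python sketch (labels taken as \(y \in \{0, 1\}\); the helper names are illustrative) that computes the \(\alpha\)-balanced cross entropy for a single prediction, plus one common way to derive \(\alpha\) from inverse class frequencies. The normalization in `alpha_from_counts` is an assumption; the text only says \(\alpha\) can come from inverse class frequency or be tuned.

```python
import math


def p_t(p: float, y: int) -> float:
    """Probability the model assigns to the true class (y in {0, 1})."""
    return p if y == 1 else 1.0 - p


def alpha_t(alpha: float, y: int) -> float:
    """Class weight: alpha for the positive class, 1 - alpha for the negative class."""
    return alpha if y == 1 else 1.0 - alpha


def balanced_ce(p: float, y: int, alpha: float) -> float:
    """Alpha-balanced cross entropy: -alpha_t * log(p_t)."""
    return -alpha_t(alpha, y) * math.log(p_t(p, y))


def alpha_from_counts(n_pos: int, n_neg: int) -> float:
    """Inverse-frequency weight for the positive class.

    Normalizing (1/n_pos, 1/n_neg) to sum to 1 gives alpha = n_neg / (n_pos + n_neg).
    """
    return n_neg / (n_pos + n_neg)


alpha = alpha_from_counts(n_pos=100, n_neg=900)  # the rare positive class gets weight 0.9
print(balanced_ce(0.2, 1, alpha), balanced_ce(0.2, 0, alpha))
```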

Focal Loss

The focal loss introduces an additional modulating factor, \(\left(1-p_t\right)^\gamma\), to the CE loss. This factor down-weights the contribution of easy examples regardless of class membership, allowing the model to focus more on hard-to-classify examples. The focal loss is thus reformulated as:

\[ \operatorname{FL}\left(p_{\mathrm{t}}\right)=-\left(1-p_{\mathrm{t}}\right)^\gamma \log \left(p_{\mathrm{t}}\right) \]

Substituting the original expressions for \(p_{t}\):

\[ \operatorname{FL}\left(p, y\right)=\begin{cases} -(1-p)^\gamma \log(p) & \text { if } y=1 \\ -p^\gamma \log(1 - p) & \text { otherwise }\end{cases} \]

This is because when \(y=1\), \(p_t=p\):

\[ -(1-p)^\gamma \log(p) \]

When \(y=-1\) or \(y=0\), \(p_t=1 - p\):

\[ \begin{aligned} -\left(1-(1-p)\right)^\gamma \log(1-p) &= -(1-1+p)^{\gamma} \log(1-p) \\ &= -p^{\gamma} \log(1-p) \end{aligned} \]

Key Properties of the Focal Loss Function

Intuitively, the focal loss function has two main properties:

Key Property I

Misclassified (harder) examples for both positive and negative classes exhibit smaller \(p_t\) values:

  • For positive examples (\(y=1\)), \(p_t=p\). A misclassification implies that the predicted probability \(p\) is small when it should ideally be close to 1.
  • For negative examples (\(y=-1\) or \(y=0\)), \(p_t=1-p\). Here, a misclassification means that the predicted probability \(p\) is large (when it should be close to 0), which makes \(p_t=1-p\) small.

Given this, the modulating factor \(\left(1-p_t\right)^\gamma\) will be close to 1 for misclassified examples:

\[ \begin{aligned} \mathrm{FL}\left(p_t\right) & =-\left(1-\text { small } p_t \text { close to } 0\right)^\gamma \log \left(p_t\right) \\ & =-(\text { a value close to } 1)^\gamma \log \left(p_t\right) \end{aligned} \]

For moderate values of \(\gamma\) (the original paper explores \(\gamma \in[0,5]\)), a value close to 1 raised to the power \(\gamma\) remains approximately 1, so misclassified examples keep nearly their full cross-entropy loss.

On the other hand, for correctly classified (easier) examples:

  • For positive examples (\(y=1\)), the predicted probability \(p_t=p\) is close to 1, indicating a correct classification.
  • For negative examples (\(y=-1\) or \(y=0\)), \(p_t=1-p\) is close to 1 since \(p\) is close to 0.

As \(p_t\) approaches 1 for correctly classified examples, their contribution to the loss is reduced:

\[ \begin{aligned} \mathrm{FL}\left(p_t\right) & =-\left(1-\text { large } p_t \text { close to } 1\right)^\gamma \log \left(p_t\right) \\ & =-(\text { a value close to } 0)^\gamma \log \left(p_t\right) \end{aligned} \]

As can be seen, for \(\gamma > 0\) the modulating factor \((\text{a value close to 0})^{\gamma}\) becomes small for well-classified examples. The aggregate effect is that the optimization focuses on the harder examples, since they now dominate the loss computation.
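A quick numerical check of Key Property I: the snippet below (pure Python, with \(p_t\) values chosen for illustration) evaluates the modulating factor \((1-p_t)^\gamma\) for \(\gamma=2\) at a hard example (\(p_t=0.1\)) and an easy example (\(p_t=0.9\)) and compares the resulting focal loss with the plain cross entropy. The hard example keeps 81% of its cross-entropy loss, while the easy example keeps only 1%.

```python
import math


def modulating_factor(pt: float, gamma: float) -> float:
    """(1 - p_t)^gamma: close to 1 for hard examples, close to 0 for easy ones."""
    return (1.0 - pt) ** gamma


gamma = 2.0
for name, pt in [("hard (misclassified)", 0.1), ("easy (well classified)", 0.9)]:
    ce = -math.log(pt)                      # plain cross entropy
    fl = modulating_factor(pt, gamma) * ce  # focal loss without alpha balancing
    print(f"{name}: factor={modulating_factor(pt, gamma):.2f}, CE={ce:.4f}, FL={fl:.4f}")
```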

Key Property II

The tunable focusing parameter \(\gamma \geq 0\) determines the extent to which easy examples are down-weighted; with \(\gamma=0\), the focal loss reduces to the cross-entropy loss. For instance, with \(\gamma=2\), an example classified with \(p_t=0.9\) has a loss 100 times lower than its CE loss, because \((1-0.9)^2=0.01\). When \(p_t \approx 0.968\), the loss is reduced by a factor of roughly 1000. In practice, an alpha-balanced variant of the Focal Loss is employed:

\[ \mathrm{FL}\left(p_t\right)=-\alpha_t\left(1-p_t\right)^\gamma \log \left(p_t\right) \]

By combining modulation and class balancing, the focal loss offers a principled way to manage class imbalance: \(\alpha_t\) rebalances the contributions of the two classes, while the modulating factor shifts the optimization toward hard examples, which in imbalanced datasets often belong to the under-represented class.
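Putting both pieces together, the \(\alpha\)-balanced focal loss for a batch can be sketched in pure Python as below. It mirrors the formula above with labels \(y \in \{0, 1\}\), plus a small clipping constant `eps` (an assumption added to avoid \(\log 0\)); the defaults \(\alpha=0.25\), \(\gamma=2\) are the values the focal loss paper reports working well, not necessarily what this project tunes to. The project itself trains with TensorFlow/Keras, whose `tf.keras.losses.BinaryFocalCrossentropy` exposes `apply_class_balancing`, `alpha`, and `gamma` arguments; whether the training script uses that class is not stated here, so treat this as a reference for the math only. The ratio helper reproduces the 100× and roughly 1000× reductions quoted above.

```python
import math
from typing import Sequence


def focal_loss(
    probs: Sequence[float],
    labels: Sequence[int],
    alpha: float = 0.25,
    gamma: float = 2.0,
    eps: float = 1e-7,
) -> float:
    """Mean alpha-balanced focal loss; probs are P(y=1), labels are 0 or 1."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)        # clip to keep log() finite
        pt = p if y == 1 else 1.0 - p          # probability of the true class
        at = alpha if y == 1 else 1.0 - alpha  # class-balancing weight
        total += -at * (1.0 - pt) ** gamma * math.log(pt)
    return total / len(probs)


def reduction_vs_ce(pt: float, gamma: float) -> float:
    """How many times smaller the (unweighted) focal loss is than cross entropy."""
    return 1.0 / (1.0 - pt) ** gamma


print(round(reduction_vs_ce(0.9, 2.0)))    # prints 100
print(round(reduction_vs_ce(0.968, 2.0)))  # prints 977, i.e. roughly 1000x
```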

Training Code Overview

Having established the theoretical foundation, we now turn to the practical implementation. Our custom training code consists of the following modules:

  • (link): a suite of utility functions and classes to support logging, configuration and Optuna hyperparameter optimization setup, as well as data sampling for testing. It also contains utility functions for integrating with AWS services like Secrets Manager and S3 for cloud-based training.
    get_logger: Sets up logging to facilitate tracking and debugging throughout the training process.
    parser: Extracts command-line arguments necessary for the training process, such as the database details, directory paths, and mode (test or train).
    add_additional_args: A decorator that extends the base argument parser, allowing for additional arguments to be dynamically added without altering the base parser.
    get_secret: Retrieves secrets (like database credentials) from the AWS Secrets Manager.
    get_db_url: Constructs a database URL for connecting to the Optuna database, using secrets fetched from AWS Secrets Manager.
    create_study: Establishes an Optuna study instance for hyperparameter optimization.
    study_report: Outputs the results of the study, including pruned and completed trials, the best trial score, and the optimal parameters.
    StudyVisualizer: Provides visualization capabilities for understanding the hyperparameter tuning results using Optuna’s built-in visualizations.
    dataset_from_csv: Creates a TensorFlow Dataset directly from a CSV file, with processing capabilities to handle features, labels, and weights.
    stratified_sample: Produces a stratified sample from a dataframe, ensuring all unique categorical values are represented. It’s especially useful for generating small samples for testing purposes.
    test_sample: Reads CSV files from S3, creates a stratified sample, and then converts it to a TensorFlow Dataset. This is primarily used for testing in SageMaker’s local mode.
  • (link): this module contains the main training and hyperparameter optimization logic for a TabTransformer model. It establishes the trainer class and the objective function for Optuna optimization. The script is also equipped to handle both local testing and cloud training.
    TabTransformerTrainer:
    - Initialization: Sets up the training environment, datasets, and configurations.
    - Optimizer Creation: Sets up the Adam optimizer based on hyperparameters.
    - Loss Function Creation: Chooses between Binary Crossentropy and Focal Loss based on hyperparameters.
    - Metrics Creation: Establishes evaluation metrics, including Binary Accuracy, Recall, Precision, and AUC-PR.
    - Input Creation & Encoding: Processes and encodes input features for the model.
    - Model Creation: Constructs the TabTransformer model.
    - Model Training: Trains the model with the Keras fit API, using distributed training for cloud-based runs and CPU-based training for local testing.
    tf_objective:
    - Creates a TabTransformerTrainer instance per trial.
    - Trains and evaluates the model for each hyperparameter set.
    - Retrains the model with the best hyperparameters on the full dataset, formed by concatenating the training and validation sets.
    main:
    - Sets up logging and configuration.
    - Initializes a distributed training strategy for cloud-based training (single host, multiple devices). In testing (local mode), CPU-based training is used on a small subset of the data.
    - Loads datasets.
    - Executes Optuna optimization.
    - Retrieves and saves the best model in the SavedModel format.
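The idea behind `stratified_sample` (keep the sample small while still stratifying by label and guaranteeing that every unique categorical value appears) can be illustrated with a pure-Python sketch. The project’s helper works on a dataframe; this hypothetical version uses a list of dicts, and its signature, the `fraction` parameter, and the seeding are assumptions made for illustration.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence

Row = Dict[str, object]


def stratified_sample(
    rows: Sequence[Row],
    label_col: str,
    categorical_cols: Sequence[str],
    fraction: float,
    seed: int = 0,
) -> List[Row]:
    rng = random.Random(seed)
    # 1. Stratify: sample the same fraction (at least one row) from each label group.
    by_label: Dict[object, List[int]] = defaultdict(list)
    for i, row in enumerate(rows):
        by_label[row[label_col]].append(i)
    chosen = set()
    for indices in by_label.values():
        k = min(len(indices), max(1, round(fraction * len(indices))))
        chosen.update(rng.sample(indices, k))
    # 2. Coverage: add one row for every categorical value not yet represented.
    for col in categorical_cols:
        seen = {rows[i][col] for i in chosen}
        for i, row in enumerate(rows):
            if row[col] not in seen:
                chosen.add(i)
                seen.add(row[col])
    return [rows[i] for i in sorted(chosen)]


# 100 common rows plus one row with a rare category that plain sampling could miss.
rows = [{"color": "red" if i % 2 else "blue", "label": int(i < 20)} for i in range(100)]
rows.append({"color": "purple", "label": 0})
sample = stratified_sample(rows, "label", ["color"], fraction=0.05)
print(len(sample), sorted({r["color"] for r in sample}))
```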

More details on distributed training with Keras’ API can be found here. Finally, in order to use additional libraries during training, we create a requirements.txt file inside src:



In the training script, we use Optuna for hyperparameter optimization. The table below describes each hyperparameter:

transformer_num_layers: Number of transformer blocks/layers.
transformer_num_heads: Number of attention heads in the multi-head attention mechanism within each transformer block.
transformer_embedding_dims: Dimensionality of the embeddings in the transformer block.
transformer_dropout_rate: Dropout rate for dropping out entire tokens to attend to.
mlp_num_hidden_layers: Number of hidden layers in the Multi-Layer Perceptron (MLP) classifier.
mlp_dropout_rate: Dropout rate used in the MLP.
use_focal_loss: Boolean indicating if focal loss should be used (alternatively, binary cross-entropy loss is used).
adam_learning_rate: Learning rate for the Adam optimizer.
adam_beta_1: The exponential decay rate for the first moment estimates in the Adam optimizer.
adam_beta_2: The exponential decay rate for the second moment estimates in the Adam optimizer.
adam_epsilon: A small constant to prevent division by zero in the Adam optimizer.
adam_clipnorm: Gradient clipping norm for the Adam optimizer.
fit_epochs: Number of epochs for model training.
mlp_hidden_units_multiples: List of multipliers to compute the number of units in each hidden layer of the MLP.
loss_apply_class_balancing: Boolean indicating if class balancing should be applied when using focal loss.
loss_alpha: Alpha parameter for the focal loss (modulates the class imbalance).
loss_gamma: Gamma parameter for the focal loss (modulates the rate at which easy examples are down-weighted).

These hyperparameters are tuned using Optuna to optimize the model’s performance on the validation set.
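To make the table concrete, here is a hedged sketch of what a search space over a subset of these hyperparameters might look like. The ranges are illustrative assumptions, not the project’s configuration, and plain `random` sampling stands in for Optuna’s sampler. It also shows one plausible reading of `mlp_hidden_units_multiples`: each multiple is multiplied by the size of the MLP input to get that hidden layer’s width.

```python
import random
from typing import Dict, List


def sample_hyperparameters(rng: random.Random) -> Dict[str, object]:
    """Draw one configuration; every range here is an illustrative placeholder."""
    num_hidden = rng.randint(1, 3)
    use_focal = rng.random() < 0.5
    params: Dict[str, object] = {
        "transformer_num_layers": rng.randint(1, 6),
        "transformer_num_heads": rng.choice([2, 4, 8]),
        "transformer_embedding_dims": rng.choice([16, 32, 64]),
        "transformer_dropout_rate": rng.uniform(0.0, 0.3),
        "mlp_num_hidden_layers": num_hidden,
        # One width multiple per hidden layer of the MLP.
        "mlp_hidden_units_multiples": [rng.choice([1, 2, 4]) for _ in range(num_hidden)],
        "mlp_dropout_rate": rng.uniform(0.0, 0.3),
        # Sample the learning rate on a log scale between 1e-4 and 1e-2.
        "adam_learning_rate": 10 ** rng.uniform(-4, -2),
        "use_focal_loss": use_focal,
    }
    if use_focal:
        # Focal-loss parameters are only relevant when focal loss is selected.
        params["loss_alpha"] = rng.uniform(0.1, 0.9)
        params["loss_gamma"] = rng.uniform(0.0, 5.0)
    return params


def mlp_hidden_units(input_dim: int, multiples: List[int]) -> List[int]:
    """Width of each MLP hidden layer: its multiple times the MLP input size."""
    return [m * input_dim for m in multiples]


print(sample_hyperparameters(random.Random(42)))
print(mlp_hidden_units(input_dim=10, multiples=[2, 1]))  # prints [20, 10]
```

In an Optuna objective, each `rng` call would become a `trial.suggest_int`, `trial.suggest_float` (with `log=True` for the learning rate), or `trial.suggest_categorical` call, and sampling the focal-loss parameters only when focal loss is selected yields a conditional search space.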

Checkpoint IV

The directory tree after adding the training code should resemble:

├── docker
│   ├──
│   ├── preprocess.Dockerfile
│   └── preprocess.Dockerfile.dockerignore
├── poetry.lock
├── pyproject.toml
├── src
│   ├── config
│   │   ├── main.yaml
│   │   └── tf_keras
│   │       └── tf_keras.yaml
│   ├──
│   ├──
│   ├──
│   ├── requirements.txt
│   └──
└── tests

Model Inference

In this section, we discuss the deployment process of our trained model for real-time inference. This involves creating a second custom Docker image that serves the model, managing necessary dependencies, and ensuring efficient request handling.

Flask-Based Model Inference

The inference script (link) establishes a Flask server for serving a trained Keras model and managing inference requests. The script initiates a Flask application offering two endpoints: /ping and /invocations. The former serves as a health check, while the latter manages model inference.

Key Functions and Descriptions

serve(): Initializes the Flask application and configures its logging. Sets up two main routes: /ping (health check) and /invocations (model inference).
load_trained_model(): Loads the trained model into memory. Uses the @lru_cache() decorator to ensure the model is loaded just once, optimizing response times.
predict(): Core function for model inference. Manages data preprocessing (extracting data, reshaping, converting to a TensorFlow dataset), model inference (using the loaded model, transforming predictions), and error handling (data parsing issues, inference errors). Returns the model’s predictions as a JSON string.
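The load-once behavior of `load_trained_model()` and the parse, predict, and serialize flow of `predict()` can be illustrated without Flask or TensorFlow. In this stdlib-only sketch, the model is a hypothetical stand-in that returns a fixed score, the JSON payload layout `{"instances": [{feature: value, ...}, ...]}` is an assumption, and errors are reported as a JSON object with an `error` key rather than however the real script handles them.

```python
import functools
import json
from typing import Callable, Dict, List

LOAD_CALLS = 0  # counts how often the (expensive) load actually runs


@functools.lru_cache(maxsize=None)
def load_trained_model() -> Callable[[Dict[str, List[object]]], List[float]]:
    """Hypothetical loader; the real script would load the SavedModel from /opt/ml/model."""
    global LOAD_CALLS
    LOAD_CALLS += 1

    def model(columns: Dict[str, List[object]]) -> List[float]:
        n_rows = len(next(iter(columns.values())))
        return [0.5] * n_rows  # placeholder probability for every row

    return model


def predict(body: str) -> str:
    try:
        instances = json.loads(body)["instances"]
        # Reshape row-oriented records into one list per feature (column-oriented),
        # the layout tf.data.Dataset.from_tensor_slices accepts for a dict of features.
        columns = {key: [rec[key] for rec in instances] for key in instances[0]}
    except (ValueError, KeyError, IndexError, TypeError) as err:
        return json.dumps({"error": f"could not parse request: {err}"})
    try:
        probabilities = load_trained_model()(columns)  # cached after the first call
    except Exception as err:  # surface inference failures to the caller
        return json.dumps({"error": f"inference failed: {err}"})
    return json.dumps({"predictions": probabilities})


print(predict('{"instances": [{"age": 40, "workclass": "gov"}, {"age": 25, "workclass": "private"}]}'))
# prints {"predictions": [0.5, 0.5]}
print(predict("not json"))
```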

To deploy the model as a SageMaker endpoint, we create a serve script inside the src directory. SageMaker starts hosting by running the command docker run <image-id> serve. To ensure the script runs with the appropriate Python interpreter, we add a shebang line at the beginning of the script. Together with the executable permission and PATH entry configured in the Dockerfile, this removes the need for a .py extension, so the script can be invoked directly by the docker run command.

Custom Docker Image for Inference

To enable efficient model serving, we will create a custom Docker image specifically tailored for inference. This image will encapsulate our trained model, its dependencies, and the logic to handle incoming requests and respond with predictions. Similar to the processing job section, we automate the build steps using the bash script. Create the following additional text files in the docker directory:

  1. Write the Dockerfile serve.Dockerfile: This Dockerfile defines our serving image. It installs the necessary dependencies, copies the serving script into the image, and makes it executable and discoverable on the PATH so that SageMaker can invoke it as serve.
# See
FROM tensorflow/tensorflow:2.13.0

RUN apt-get update && \
    apt-get install -y --no-install-recommends nginx curl && \
    pip3 install \
    sagemaker \
    sagemaker-training \
    polars \
    boto3 \
    flask

# Ensure python I/O is unbuffered so log messages are immediate
ENV PYTHONUNBUFFERED=TRUE
# Disable the generation of bytecode '.pyc' files
ENV PYTHONDONTWRITEBYTECODE=TRUE
# Set environment variable for the source code directory
ENV PROGRAM_PATH='/opt/program'
# Ensure PROGRAM_PATH is included in the search path for executables, so any executables (e.g., 'serve') in that directory will take precedence over other executables
ENV PATH="/opt/program:${PATH}"

# Copy src directory into the container, ignoring all files except for `serve`
COPY src $PROGRAM_PATH

# Change permissions of the 'serve' script
RUN chmod +x $PROGRAM_PATH/serve

  2. Define Docker Ignore Rules using serve.Dockerfile.dockerignore: For serving, we only need the serving script serve, since SageMaker copies the trained model artifacts into the container at runtime.

Using this custom Docker image, we can now seamlessly deploy our trained model to any environment supporting Docker containers, including cloud-based platforms like AWS SageMaker, ensuring a consistent and scalable inference experience.

How SageMaker Handles Model Artifacts

When setting up our model with SageMaker using the CreateModel request, the container definition includes a ModelDataUrl parameter, which specifies the S3 location of the model artifacts. SageMaker uses this URL to download the artifacts and extract them into the /opt/ml/model directory, making them available for inference when the container starts.

It’s vital that the ModelDataUrl points to a tar.gz compressed file; otherwise, SageMaker won’t retrieve it. Because our model was trained within SageMaker, its artifacts are already stored in Amazon S3 as a single compressed tar file (model.tar.gz).
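If you ever need to package artifacts yourself (for example, for a model trained outside SageMaker), the archive can be built with Python’s standard `tarfile` module. This sketch assumes a local directory holding the exported SavedModel; adding each entry with a relative `arcname` keeps the directory’s contents at the root of the archive, so after extraction they would sit directly under /opt/ml/model. Whether your serving code expects that layout or a subdirectory is an assumption to check, and the paths are placeholders; the upload to S3 is left out.

```python
import os
import tarfile


def package_model(model_dir: str, output_path: str = "model.tar.gz") -> str:
    """Compress the contents of model_dir into a gzipped tarball for ModelDataUrl."""
    with tarfile.open(output_path, "w:gz") as tar:
        for name in sorted(os.listdir(model_dir)):
            # arcname=name drops the local path prefix so entries sit at the archive root;
            # tar.add recurses into subdirectories such as variables/.
            tar.add(os.path.join(model_dir, name), arcname=name)
    return output_path
```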

The notebook linked in the next section handles all of these steps.

Checkpoint V

After adding the inference script and Dockerfile, the directory tree should contain the following:

├── docker
│   ├──
│   ├── preprocess.Dockerfile
│   ├── preprocess.Dockerfile.dockerignore
│   ├── serve.Dockerfile
│   └── serve.Dockerfile.dockerignore
├── poetry.lock
├── pyproject.toml
├── src
│   ├── config
│   │   ├── main.yaml
│   │   └── tf_keras
│   │       └── tf_keras.yaml
│   ├──
│   ├──
│   ├──
│   ├── requirements.txt
│   ├── serve
│   └──
└── tests

Finally, more details on building custom images for model hosting can be found here.

Putting It All Together

The culmination of our efforts is the ability to launch every step and job (data preprocessing, model training, and building and pushing our Docker images) through a single interface using the SageMaker Python SDK. Using a Jupyter notebook, we can run these tasks end to end while documenting them and keeping the flexibility to modify or scale them as needed.

For a comprehensive view of how everything ties together, you can explore the complete notebook here. This notebook serves as a blueprint, illustrating the integration of all the steps discussed in the guide, thus providing a holistic view of the end-to-end process.