This blog post is co-written with Moran Beladev, Manos Stergiadis, and Ilya Gusev from Booking.com.
Large language models (LLMs) have revolutionized the field of natural language processing with their ability to understand and generate humanlike text. Trained on broad, generic datasets spanning a wide range of topics and domains, LLMs use their parametric knowledge to perform increasingly complex and versatile tasks across multiple business use cases. Furthermore, companies are increasingly investing resources in customizing LLMs through few-shot learning and fine-tuning to optimize their performance for specialized applications.
However, the impressive performance of LLMs comes at the cost of significant computational requirements, driven by their large number of parameters and their autoregressive decoding process, which is sequential in nature. This combination makes achieving low latency a challenge for use cases such as real-time text completion, simultaneous translation, or conversational voice assistants, where subsecond response times are critical.
Researchers developed Medusa, a framework to speed up LLM inference by adding extra heads to predict multiple tokens simultaneously. This post demonstrates how to use Medusa-1, the first version of the framework, to speed up an LLM by fine-tuning it on Amazon SageMaker AI, and confirms the speedup with a deployment and a simple load test. Medusa-1 achieves an inference speedup of around two times without sacrificing model quality, with the exact improvement varying based on model size and the data used. In this post, we demonstrate its effectiveness with a 1.8 times speedup observed on a sample dataset.
Introduction to Medusa and its benefits for LLM inference speed
LLMs generate text sequentially using autoregressive sampling, with each new token conditioned on the previous ones. Generating K tokens necessitates K sequential executions of the model. This token-by-token processing introduces inherent latency and computational overhead because the model needs to perform a separate forward pass for each new token in the output sequence. The following diagram from Role-Play with Large Language Models illustrates this flow.
Speculative decoding tackles this challenge by using a smaller, faster draft model to generate multiple potential token continuations in parallel, which are then verified by a larger, more accurate target model. This parallelization speeds up text generation while maintaining the quality of the target model because the verification task is faster than autoregressive token generation. For a detailed explanation of the concept, refer to the paper Accelerating Large Language Model Decoding with Speculative Sampling. The speculative decoding technique can be implemented using the inference optimization toolkit on Amazon SageMaker Jumpstart.
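To make the idea concrete, the following minimal sketch illustrates the draft-then-verify loop. The draft_model and target_model objects and their methods are hypothetical placeholders used only for illustration; real implementations accept or reject candidates with rejection sampling over both models' token distributions.

```python
# Conceptual sketch of speculative decoding; `draft_model` and `target_model` are
# hypothetical placeholders, not part of any specific library.
def speculative_decode(draft_model, target_model, tokens, k=4, max_new_tokens=64):
    generated = 0
    while generated < max_new_tokens:
        # 1. The small draft model cheaply proposes k candidate tokens.
        draft = draft_model.propose(tokens, k)
        # 2. The large target model scores all candidates in a single forward pass
        #    and returns the prefix it agrees with.
        accepted = target_model.verify(tokens, draft)
        # 3. Keep the accepted prefix plus one token from the target model itself,
        #    so output quality matches plain autoregressive decoding.
        tokens = tokens + accepted + [target_model.next_token(tokens + accepted)]
        generated += len(accepted) + 1
    return tokens
```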
The paper Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads introduced Medusa as an alternative to speculative decoding. Instead of adding a separate draft model, it adds extra decoding heads to the LLM that generate candidate continuations simultaneously. These candidates are then evaluated in parallel using a tree-based attention mechanism. This parallel processing reduces the number of sequential steps needed, leading to faster inference times. The main advantage of Medusa over speculative decoding is that it eliminates the need to acquire and maintain a separate draft model while achieving higher speedups. For example, when tested on the MT-Bench dataset, the paper reports that Medusa-2 (the second version of Medusa) speeds up inference time by 2.8 times. This outperforms speculative decoding, which only manages to speed up inference time by 1.5 times on the same dataset.
The Medusa framework currently supports Llama and Mistral models. Although it offers significant speed improvements, it does come with a memory trade-off (similar to speculative decoding). For instance, adding five Medusa heads to the 7-billion-parameter Mistral model increases the total parameter count by 750 million (150 million per head), which means these additional parameters must be stored in GPU memory, leading to a higher memory requirement. However, in most cases, this increase doesn’t necessitate switching to a higher GPU memory instance. For example, you can still use an ml.g5.4xlarge instance with 24 GB of GPU memory to host your 7-billion-parameter Llama or Mistral model with extra Medusa heads.
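As a back-of-the-envelope check using the numbers above, and assuming the head parameters are stored in half precision, the additional memory is roughly:

```python
# Rough estimate of the extra GPU memory for five Medusa heads,
# assuming half-precision storage (2 bytes per parameter).
num_heads = 5
params_per_head = 150_000_000   # ~150 million parameters per head

extra_bytes = num_heads * params_per_head * 2
print(f"~{extra_bytes / 1024**3:.1f} GiB of additional GPU memory")  # ~1.4 GiB
```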
Training Medusa heads requires additional development time and computational resources, which should be factored into project planning and resource allocation. Another important limitation to mention is that the current framework, when deployed on an Amazon SageMaker AI endpoint, only supports a batch size of one—a configuration typically used for low-latency applications.
The following diagram from the original Medusa paper authors' FasterDecoding repository gives a visual overview of the Medusa framework.
There are two main variants of Medusa:
Medusa-1 – Requires a two-stage approach where you first fine-tune your LLM and then add Medusa heads and train them on top of your frozen fine-tuned LLM
Medusa-2 – Introduced later as an improvement, fine-tunes both the additional heads and the backbone LLM parameters together, enabling potentially even further latency speedups
The Medusa paper reports that across models of varying sizes, you can achieve inference speedups of around two times for Medusa-1 and around three times for Medusa-2. With Medusa-1, the predictions are identical to those of the originally fine-tuned LLM. In contrast, with Medusa-2, we might observe slightly different results compared to simple fine-tuning of the LLM because both the heads and the backbone LLM parameters are updated together. In this post, we focus on Medusa-1.
Solution overview
We cover the following steps in our solution:
Prerequisites
Load and prepare the dataset
Fine-tune an LLM using a SageMaker AI training job
Train Medusa heads on top of a frozen fine-tuned LLM using a SageMaker AI training job
Deploy the fine-tuned LLM with Medusa heads on a SageMaker AI endpoint
Demonstrate LLM inference speedup
By following this solution, you can accelerate LLM inference in your applications, leading to faster response times and improved user experience.
Prerequisites
To build the solution yourself, you need the following prerequisites:
You need an AWS account with an AWS Identity and Access Management (IAM) role that has permissions to manage resources created as part of the solution (for example AmazonSageMakerFullAccess and AmazonS3FullAccess). For details, refer to Creating an AWS account.
We use JupyterLab in Amazon SageMaker Studio running on an ml.t3.medium instance with a Python 3 (ipykernel) kernel. However, you can also use an Amazon SageMaker notebook instance (with a conda_pytorch_p310 kernel) or any integrated development environment (IDE) of your choice.
Be sure to set up your AWS Command Line Interface (AWS CLI) credentials correctly. For more information, refer to Configure the AWS CLI.
The solution uses an ml.g5.4xlarge instance for the SageMaker AI training jobs, and three ml.g5.4xlarge instances for the SageMaker AI endpoints. Make sure you have sufficient capacity for these instances in your AWS account by requesting a quota increase if required. Also check the pricing of the on-demand instances to understand the associated costs.
To replicate the solution demonstrated in this post, you need to clone this GitHub repository. Within the repository, you can use the medusa_1_train.ipynb notebook to run all the steps in this post. This repository is a modified version of the original How to Fine-Tune LLMs in 2024 on Amazon SageMaker. We added simplified Medusa training code, adapted from the original Medusa repository.
Load and prepare the dataset
Now that you have cloned the GitHub repository and opened the medusa_1_train.ipynb notebook, you will load and prepare the dataset in the notebook. We encourage you to read this post while running the code in the notebook. For this post, we use the sql-create-context dataset, which contains 78,577 examples of natural language questions, SQL CREATE TABLE statements, and the corresponding SQL queries that answer each question using the CREATE statement as context. For demonstration purposes, we select 3,000 samples and split them into train, validation, and test sets.
You need to run the "Load and prepare the dataset" section of the medusa_1_train.ipynb notebook to prepare the dataset for fine-tuning. We also included a data exploration script to analyze the length of input and output tokens. After data exploration, we prepare the train, validation, and test sets and upload them to Amazon Simple Storage Service (Amazon S3).
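The following condensed sketch shows roughly what that section of the notebook does. The dataset ID, split ratios, and S3 layout are assumptions and may differ slightly from the notebook.

```python
from datasets import load_dataset
import sagemaker
from sagemaker.s3 import S3Uploader

# Load the sql-create-context dataset (dataset ID assumed; adjust if yours differs).
dataset = load_dataset("b-mc2/sql-create-context", split="train")

# For demonstration purposes, select 3,000 samples.
dataset = dataset.shuffle(seed=42).select(range(3000))

# 70/15/15 split into train, validation, and test sets (the test set has 450 samples).
splits = dataset.train_test_split(test_size=0.3, seed=42)
holdout = splits["test"].train_test_split(test_size=0.5, seed=42)
train_ds, val_ds, test_ds = splits["train"], holdout["train"], holdout["test"]

# Save each split locally, then upload to Amazon S3 for the training jobs.
sess = sagemaker.Session()
s3_prefix = f"s3://{sess.default_bucket()}/medusa-1/data"
for name, ds in [("train", train_ds), ("validation", val_ds), ("test", test_ds)]:
    ds.to_json(f"{name}.json")
    S3Uploader.upload(f"{name}.json", f"{s3_prefix}/{name}")
```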
Fine-tune an LLM using SageMaker AI training job
We use the Zephyr 7B β model as our backbone LLM. Zephyr is a series of language models trained to act as helpful assistants, and Zephyr 7B β is a fine-tuned version of Mistral-7B-v0.1, trained on a mix of publicly available and synthetic datasets using Direct Preference Optimization.
To launch a SageMaker AI training job, we need to use the PyTorch or Hugging Face estimator. SageMaker AI starts and manages all the necessary Amazon Elastic Compute Cloud (Amazon EC2) instances for us, supplies the appropriate containers, downloads the data from our S3 bucket to the container, and uploads and runs the specified training script, in our case fine_tune_llm.py. We select the hyperparameters based on the QLoRA paper, but we encourage you to experiment with your own combinations. To expedite the execution of this code, we set the number of epochs to 1. However, for better results, it's generally recommended to set the number of epochs to at least 2 or 3.
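The following sketch shows how launching the job with the Hugging Face estimator could look. The hyperparameter names, framework versions, and script location are illustrative assumptions; the notebook defines the exact values.

```python
import sagemaker
from sagemaker.huggingface import HuggingFace

# Hyperparameters forwarded to fine_tune_llm.py (names and values are illustrative).
hyperparameters = {
    "model_id": "HuggingFaceH4/zephyr-7b-beta",
    "epochs": 1,                      # increase to 2-3 for better results
    "per_device_train_batch_size": 1,
    "lr": 2e-4,
}

huggingface_estimator = HuggingFace(
    entry_point="fine_tune_llm.py",
    source_dir="scripts",             # assumed location of the training script
    instance_type="ml.g5.4xlarge",
    instance_count=1,
    role=sagemaker.get_execution_role(),
    transformers_version="4.36",      # framework versions are assumptions
    pytorch_version="2.1",
    py_version="py310",
    hyperparameters=hyperparameters,
)

# Start the training job; the channels point to the S3 prefixes prepared earlier.
huggingface_estimator.fit({
    "train": f"{s3_prefix}/train",
    "validation": f"{s3_prefix}/validation",
})
```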
When our training job has completed successfully after approximately 1 hour, we can use the fine-tuned model artifact for the next step, training the Medusa heads on top of it. To visualize the training metrics in TensorBoard, you can follow the guidance in this documentation: Load and visualize output tensors using the TensorBoard application.
Train Medusa heads on top of a frozen fine-tuned LLM using a SageMaker AI training job
For training Medusa heads, we can reuse the functions previously mentioned to launch the training job. We selected hyperparameters based on a combination of what the Medusa paper reported and what we found to be best performing after a few experiments. We set the number of Medusa heads to 5 and used the 8-bit AdamW optimizer, as recommended by the paper. For simplicity, we maintained a constant learning rate of 1e-4 with a constant scheduler, similar to the previous fine-tuning step. Although the paper recommends an increased learning rate and a cosine scheduler, we found that our chosen combination of hyperparameters performed well on this dataset. However, we encourage you to experiment with your own hyperparameter settings to potentially achieve even better results.
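Launching this job follows the same estimator pattern as before, now with train_medusa_heads.py as the entry point. The hyperparameter and channel names below are illustrative assumptions; the training script in the repository defines the real arguments.

```python
import sagemaker
from sagemaker.huggingface import HuggingFace

# Hyperparameter names are illustrative; train_medusa_heads.py defines the real arguments.
medusa_hyperparameters = {
    "medusa_num_heads": 5,
    "epochs": 3,
    "lr": 1e-4,
    "lr_scheduler_type": "constant",
    "optimizer": "paged_adamw_8bit",  # 8-bit AdamW, as recommended by the paper
}

medusa_estimator = HuggingFace(
    entry_point="train_medusa_heads.py",
    source_dir="scripts",             # assumed location of the training script
    instance_type="ml.g5.4xlarge",
    instance_count=1,
    role=sagemaker.get_execution_role(),
    transformers_version="4.36",      # framework versions are assumptions
    pytorch_version="2.1",
    py_version="py310",
    hyperparameters=medusa_hyperparameters,
)

# The fine-tuned artifact from the previous job is passed as an additional input channel
# (the channel name is an assumption).
medusa_estimator.fit({
    "train": f"{s3_prefix}/train",
    "validation": f"{s3_prefix}/validation",
    "basemodel": huggingface_estimator.model_data,
})
```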
We found that after 3 epochs, the evaluation loss of Medusa heads was converging, which can be observed in the TensorBoard graph in the following image.
Besides the hyperparameters, the main difference is that we pass train_medusa_heads.py as the training entry point, where we first add the Medusa heads, then freeze the fine-tuned LLM, and create a custom MedusaSFTTrainer class, which is a subclass of the transformers SFTTrainer.
In the add_medusa_heads() function, we add the residual blocks of the Medusa heads, and also override the forward pass for our model to make sure not to train the frozen backbone LLM:
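The full implementation lives in train_medusa_heads.py; the following is a simplified sketch of the idea, adapted conceptually from the original Medusa repository. The forward-pass override and the per-head loss computation are omitted here.

```python
import torch.nn as nn

class ResBlock(nn.Module):
    """A single residual block used inside each Medusa head."""
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

def add_medusa_heads(model, medusa_num_heads=5, medusa_num_layers=1):
    """Attach Medusa heads to a causal LM and freeze the backbone (sketch only)."""
    hidden_size = model.config.hidden_size
    vocab_size = model.config.vocab_size

    # Freeze the fine-tuned backbone so only the Medusa heads receive gradients.
    for param in model.parameters():
        param.requires_grad = False

    # Head k is a small residual MLP with its own LM head that predicts
    # the token at position t + k + 1 from the hidden state at position t.
    model.medusa_heads = nn.ModuleList(
        nn.Sequential(
            *(ResBlock(hidden_size) for _ in range(medusa_num_layers)),
            nn.Linear(hidden_size, vocab_size, bias=False),
        )
        for _ in range(medusa_num_heads)
    )
    return model
```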
After the model training is finished (which takes about 1 hour), we prepare the model artifacts for deployment and upload them to Amazon S3. Your final model artifact contains both the original fine-tuned model from the previous step under the base-model prefix and the trained Medusa heads in a file named medusa_heads.safetensors.
Deploy the fine-tuned LLM with Medusa heads on a SageMaker AI endpoint
The Medusa framework is supported by the Text Generation Inference (TGI) server. After training the LLM with Medusa heads, we deploy it to a SageMaker AI real-time endpoint using the Hugging Face Inference Container set up with TGI.
First, we create a SageMaker AI HuggingFaceModel object and then deploy the model to an endpoint with the following function:
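A sketch of such a helper function follows; the TGI container version and the environment settings are assumptions and may differ from the notebook.

```python
import sagemaker
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

role = sagemaker.get_execution_role()
# TGI container image; the version pinned in the notebook may differ.
llm_image = get_huggingface_llm_image_uri("huggingface", version="2.0.2")

def deploy_model(endpoint_name, model_s3_path=None, hf_model_id=None):
    """Deploy an S3 model artifact or a Hugging Face Hub model to a TGI-backed endpoint."""
    env = {"MAX_INPUT_LENGTH": "1024", "MAX_TOTAL_TOKENS": "2048"}
    # Serve the artifact shipped in model_data, or pull the model from the Hub.
    env["HF_MODEL_ID"] = "/opt/ml/model" if model_s3_path else hf_model_id

    model = HuggingFaceModel(
        model_data=model_s3_path,
        role=role,
        image_uri=llm_image,
        env=env,
    )
    return model.deploy(
        initial_instance_count=1,
        instance_type="ml.g5.4xlarge",
        endpoint_name=endpoint_name,
        container_startup_health_check_timeout=900,
    )
```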
We deploy three LLMs on three SageMaker AI endpoints:
Base LLM which isn’t fine-tuned
The LLM that we fine-tuned
The fine-tuned LLM that also has trained Medusa heads
You can deploy the three models in parallel by using a function that we included in the notebook, or you can deploy the models one by one by running the code below:
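Assuming the deploy_model helper sketched above, deploying the three variants one by one could look like the following; the S3 paths are placeholders for your own training job outputs.

```python
# S3 paths are placeholders; use the artifact locations from your training jobs.
base_predictor = deploy_model(
    "zephyr-7b-base", hf_model_id="HuggingFaceH4/zephyr-7b-beta"
)
finetuned_predictor = deploy_model(
    "zephyr-7b-finetuned", model_s3_path="s3://<bucket>/medusa-1/finetuned/model.tar.gz"
)
medusa_predictor = deploy_model(
    "zephyr-7b-medusa", model_s3_path="s3://<bucket>/medusa-1/medusa/model.tar.gz"
)
```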
After the status for each endpoint becomes InService, which should take around 15 minutes, we can invoke them for inference. We send the following input:
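The exact prompt from the notebook isn't reproduced here; with a hypothetical question and schema, an invocation looks roughly like this.

```python
# Hypothetical prompt; the notebook builds its prompts from the test set with its own template.
prompt = (
    "You are a text-to-SQL assistant.\n"
    "Schema: CREATE TABLE head (age INTEGER)\n"
    "Question: How many heads of the departments are older than 56?\n"
    "Answer:"
)
payload = {"inputs": prompt, "parameters": {"max_new_tokens": 128, "do_sample": False}}

for predictor in [base_predictor, finetuned_predictor, medusa_predictor]:
    print(predictor.predict(payload)[0]["generated_text"])
```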
We can observe the following responses:
The base LLM response contains extra words that aren’t needed:
The fine-tuned LLM response is improved significantly, and contains only the required output:
The fine-tuned LLM with trained Medusa heads provides the exact same response as the fine-tuned model, demonstrating that Medusa-1, by design, maintains the output (quality) of the original model:
Demonstrate LLM inference speedup
To measure the inference speed improvement, we compare the response times of the deployed fine-tuned LLM and the fine-tuned LLM with Medusa heads on the same 450 test observations, running predictions first with the fine-tuned LLM and then with the fine-tuned LLM with Medusa heads.
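A simplified version of this comparison follows. Here, test_prompts is assumed to hold the 450 formatted test prompts, and the notebook's version also stores the generated SQL for the exact-match check described below.

```python
import time

def measure_latency(predictor, prompts, max_new_tokens=128):
    """Invoke the endpoint once per prompt and record per-request latency in milliseconds."""
    latencies, outputs = [], []
    for prompt in prompts:
        payload = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": max_new_tokens, "do_sample": False},
        }
        start = time.perf_counter()
        response = predictor.predict(payload)
        latencies.append((time.perf_counter() - start) * 1000)
        outputs.append(response[0]["generated_text"])
    return latencies, outputs

# `test_prompts` is assumed to hold the 450 formatted test prompts.
ft_latencies, ft_outputs = measure_latency(finetuned_predictor, test_prompts)
medusa_latencies, medusa_outputs = measure_latency(medusa_predictor, test_prompts)

print(f"Fine-tuned:          {sum(ft_latencies) / len(ft_latencies):.0f} ms average latency")
print(f"Fine-tuned + Medusa: {sum(medusa_latencies) / len(medusa_latencies):.0f} ms average latency")
```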
The prediction runs should take around 8 and 4 minutes respectively. We can observe that the average latency decreased from 950 to 530 milliseconds, which is an improvement of 1.8 times. You can achieve even higher improvements if your dataset contains longer inputs and outputs. In our dataset, we only had an average of 18 input tokens and 30 output tokens.
We want to highlight once again that, with this technique, the output quality is fully maintained: the model responses for the test set of 450 observations are the same with and without Medusa heads.
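Reusing the outputs collected in the previous step, a simple exact-match check looks like this:

```python
matches = sum(a == b for a, b in zip(ft_outputs, medusa_outputs))
print(f"{matches}/{len(ft_outputs)} identical responses "
      f"({100 * matches / len(ft_outputs):.1f}% exact match)")
```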
You might notice in your run that a few observations don't match exactly (for example, a 99% match rate) due to small floating point differences caused by GPU optimizations.
Cleanup
At the end of this experiment, don’t forget to delete the SageMaker AI endpoints you created:
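For example, if you kept references to the predictors returned by the deployment step, you can delete the endpoints as follows:

```python
for predictor in [base_predictor, finetuned_predictor, medusa_predictor]:
    predictor.delete_model()
    predictor.delete_endpoint()
```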
Conclusion
In this post, we demonstrated how to fine-tune and deploy an LLM with Medusa heads using the Medusa-1 technique on Amazon SageMaker AI to accelerate LLM inference. By using this framework and SageMaker AI scalable infrastructure, we showed how to achieve up to twofold speedups in LLM inference while maintaining model quality. This solution is particularly beneficial for applications requiring low-latency text generation, such as customer service chat assistants, content creation, and recommendation systems.
As a next step, you can explore fine-tuning your own LLM with Medusa heads on your own dataset and benchmark the results for your specific use case, using the provided GitHub repository.
About the authors
Daniel Zagyva is a Senior ML Engineer at AWS Professional Services. He specializes in developing scalable, production-grade machine learning solutions for AWS customers. His experience extends across different areas, including natural language processing, generative AI and machine learning operations.
Aleksandra Dokic is a Senior Data Scientist at AWS Professional Services. She enjoys supporting customers to build innovative AI/ML solutions on AWS and she is excited about business transformations through the power of data.
Moran Beladev is a Senior ML Manager at Booking.com. She leads the content intelligence track, which is focused on building, training, and deploying content models (computer vision, NLP, and generative AI) using the most advanced technologies and models. Moran is also a PhD candidate, researching the application of NLP models to social graphs.
Manos Stergiadis is a Senior ML Scientist at Booking.com. He specializes in generative NLP and has experience researching, implementing and deploying large deep learning models at scale.
Ilya Gusev is a Senior Machine Learning Engineer at Booking.com. He leads the development of several LLM systems inside Booking.com. His work focuses on building production ML systems that help millions of travelers plan their trips effectively.
Laurens van der Maas is a Machine Learning Engineer at AWS Professional Services. He works closely with customers building their machine learning solutions on AWS, specializes in natural language processing, experimentation and responsible AI, and is passionate about using machine learning to drive meaningful change in the world.