Achieve ~2x speed-up in LLM inference with Medusa-1 on Amazon SageMaker AI

This blog post is co-written with Moran Beladev, Manos Stergiadis, and Ilya Gusev from Booking.com.

Large language models (LLMs) have revolutionized the field of natural language processing with their ability to understand and generate humanlike text. Trained on broad, generic datasets spanning a wide range of topics and domains, LLMs use their parametric knowledge to perform increasingly complex and versatile tasks across multiple business use cases. Furthermore, companies are increasingly investing resources in customizing LLMs through few-shot learning and fine-tuning to optimize their performance for specialized applications.

However, the impressive performance of LLMs comes at the cost of significant computational requirements, driven by their large number of parameters and an autoregressive decoding process that is sequential in nature. This combination makes achieving low latency a challenge for use cases such as real-time text completion, simultaneous translation, or conversational voice assistants, where subsecond response times are critical.

Researchers developed Medusa, a framework to speed up LLM inference by adding extra heads that predict multiple tokens simultaneously. This post demonstrates how to use Medusa-1, the first version of the framework, to speed up an LLM by fine-tuning it on Amazon SageMaker AI, and confirms the speedup with a deployment and a simple load test. Medusa-1 achieves an inference speedup of around two times without sacrificing model quality, with the exact improvement varying based on model size and the data used. In this post, we demonstrate its effectiveness with a 1.8 times speedup observed on a sample dataset.

Introduction to Medusa and its benefits for LLM inference speed

LLMs generate text in a sequential manner through autoregressive sampling, where each new token is conditioned on the previous ones. Generating K tokens therefore requires K sequential executions of the model. This token-by-token processing introduces inherent latency and computational overhead because the model needs to perform a separate forward pass for each new token in the output sequence. The following diagram from Role-Play with Large Language Models illustrates this flow.
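
To illustrate what this sequential loop looks like in practice, the following minimal sketch performs greedy, token-by-token decoding with Hugging Face transformers. The model name is a placeholder chosen only to keep the example small, and the key-value caching used by production decoders is omitted for clarity.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder model, small enough to run the sketch quickly
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("The translated SQL query is", return_tensors="pt").input_ids
for _ in range(20):  # generating K tokens requires K forward passes
    with torch.no_grad():
        logits = model(input_ids).logits
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy choice for the next token
    input_ids = torch.cat([input_ids, next_token], dim=-1)      # the next step is conditioned on it

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))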

Speculative decoding tackles this challenge by using a smaller, faster draft model to generate multiple potential token continuations in parallel, which are then verified by a larger, more accurate target model. This parallelization speeds up text generation while maintaining the quality of the target model because the verification task is faster than autoregressive token generation. For a detailed explanation of the concept, refer to the paper Accelerating Large Language Model Decoding with Speculative Sampling. The speculative decoding technique can be implemented using the inference optimization toolkit on Amazon SageMaker JumpStart.
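
To make the draft-then-verify idea concrete, the following sketch implements a greedy, simplified form of speculative decoding with Hugging Face transformers. The draft and target model names are placeholders chosen only for illustration (they are not the models used in this post), and the full algorithm in the cited paper additionally handles sampling through a rejection scheme.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

draft_name, target_name = "distilgpt2", "gpt2"  # placeholder draft/target pair sharing a tokenizer
tok = AutoTokenizer.from_pretrained(target_name)
draft = AutoModelForCausalLM.from_pretrained(draft_name)
target = AutoModelForCausalLM.from_pretrained(target_name)

@torch.no_grad()
def speculative_step(input_ids, k=4):
    # 1) The draft model proposes k tokens autoregressively (cheap)
    draft_ids = draft.generate(input_ids, max_new_tokens=k, do_sample=False, pad_token_id=tok.eos_token_id)
    proposed = draft_ids[:, input_ids.shape[1]:]
    # 2) The target model scores the prompt plus all proposals in a single forward pass
    logits = target(draft_ids).logits
    # The logits at position i-1 predict token i, so these are the target's greedy
    # choices at each proposed position
    target_choice = logits[:, input_ids.shape[1] - 1:-1, :].argmax(dim=-1)
    # 3) Accept the longest matching prefix, then append one verified target token,
    #    so every step produces at least one token
    matches = (proposed == target_choice)[0].long()
    n_accept = int(matches.cumprod(dim=0).sum())
    next_token = logits[:, input_ids.shape[1] - 1 + n_accept, :].argmax(dim=-1, keepdim=True)
    return torch.cat([input_ids, proposed[:, :n_accept], next_token], dim=-1)

input_ids = tok("The capital of France is", return_tensors="pt").input_ids
for _ in range(5):
    input_ids = speculative_step(input_ids)
print(tok.decode(input_ids[0], skip_special_tokens=True))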

The paper Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads introduced Medusa as an alternative to speculative decoding. Instead of adding a separate draft model, it adds extra decoding heads to the LLM that generate candidate continuations simultaneously. These candidates are then evaluated in parallel using a tree-based attention mechanism. This parallel processing reduces the number of sequential steps needed, leading to faster inference times. The main advantage of Medusa over speculative decoding is that it eliminates the need to acquire and maintain a separate draft model while achieving higher speedups. For example, when tested on the MT-Bench dataset, the paper reports that Medusa-2 (the second version of Medusa) speeds up inference time by 2.8 times. This outperforms speculative decoding, which only manages to speed up inference time by 1.5 times on the same dataset.

The Medusa framework currently supports Llama and Mistral models. Although it offers significant speed improvements, it does come with a memory trade-off (similar to speculative decoding). For instance, adding five Medusa heads to the 7-billion-parameter Mistral model increases the total parameter count by 750 million (150 million per head), which means these additional parameters must be stored in GPU memory, leading to a higher memory requirement. However, in most cases, this increase doesn’t necessitate switching to a higher GPU memory instance. For example, you can still use an ml.g5.4xlarge instance with 24 GB of GPU memory to host your 7-billion-parameter Llama or Mistral model with extra Medusa heads.
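
As a quick sanity check on these numbers, the following back-of-the-envelope calculation assumes the Mistral-7B configuration (hidden size of 4,096 and vocabulary size of 32,000), one residual block plus an output projection per Medusa head as in the Medusa repository, and weights stored in bfloat16.

hidden_size, vocab_size, num_heads = 4096, 32000, 5  # Mistral-7B configuration, 5 Medusa heads

res_block_params = hidden_size * hidden_size + hidden_size  # linear layer (with bias) in the residual block
lm_head_params = hidden_size * vocab_size                   # per-head output projection (no bias)
params_per_head = res_block_params + lm_head_params
total_extra_params = num_heads * params_per_head

print(f"Parameters per head: {params_per_head / 1e6:.0f}M")               # ~148M, roughly the 150M quoted above
print(f"Total extra parameters: {total_extra_params / 1e6:.0f}M")         # ~740M, roughly the 750M quoted above
print(f"Extra GPU memory (bf16): {total_extra_params * 2 / 1e9:.1f} GB")  # 2 bytes per parameter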

Training Medusa heads requires additional development time and computational resources, which should be factored into project planning and resource allocation. Another important limitation to mention is that the current framework, when deployed on an Amazon SageMaker AI endpoint, only supports a batch size of one—a configuration typically used for low-latency applications.

The following diagram from the original Medusa paper authors’ FasterDecoding repository gives a visual Medusa framework overview.

There are two main variants of Medusa:

Medusa-1 – Requires a two-stage approach where you first fine-tune your LLM and then add Medusa heads and train them on top of your frozen fine-tuned LLM
Medusa-2 – Introduced later as an improvement, fine-tunes both the additional heads and the backbone LLM parameters together, enabling potentially even further latency speedups

The Medusa paper reports that across models of varying sizes, you can achieve inference speedups of around two times for Medusa-1 and around three times for Medusa-2. With Medusa-1, the predictions are identical to those of the originally fine-tuned LLM. In contrast, with Medusa-2, we might observe slightly different results compared to simple fine-tuning of the LLM because both the heads and the backbone LLM parameters are updated together. In this post, we focus on Medusa-1.

Solution overview

We cover the following steps in our solution:

Prerequisites
Load and prepare the dataset
Fine-tune an LLM using a SageMaker AI training job
Train Medusa heads on top of a frozen fine-tuned LLM using a SageMaker AI training job
Deploy the fine-tuned LLM with Medusa heads on a SageMaker AI endpoint
Demonstrate LLM inference speedup

By following this solution, you can accelerate LLM inference in your applications, leading to faster response times and improved user experience.

Prerequisites

To build the solution yourself, you need the following prerequisites:

You need an AWS account with an AWS Identity and Access Management (IAM) role that has permissions to manage the resources created as part of the solution (for example, AmazonSageMakerFullAccess and AmazonS3FullAccess). For details, refer to Creating an AWS account.
We use JupyterLab in Amazon SageMaker Studio running on an ml.t3.medium instance with a Python 3 (ipykernel) kernel. However, you can also use an Amazon SageMaker notebook instance (with a conda_pytorch_p310 kernel) or any integrated development environment (IDE) of your choice.
Be sure to set up your AWS Command Line Interface (AWS CLI) credentials correctly. For more information, refer to Configure the AWS CLI.
The solution uses an ml.g5.4xlarge instance for the SageMaker AI training jobs, and three ml.g5.4xlarge instances for the SageMaker AI endpoints. Make sure you have sufficient capacity for this instance type in your AWS account by requesting a quota increase if required. Also check the pricing of On-Demand Instances to understand the associated costs.
To replicate the solution demonstrated in this post, you need to clone this GitHub repository. Within the repository, you can use the medusa_1_train.ipynb notebook to run all the steps in this post. This repository is a modified version of the original How to Fine-Tune LLMs in 2024 on Amazon SageMaker. We added simplified Medusa training code, adapted from the original Medusa repository.

Load and prepare the dataset

Now that you have cloned the GitHub repository and opened the medusa_1_train.ipynb notebook, you will load and prepare the dataset in the notebook. We encourage you to read this post while running the code in the notebook. For this post, we use a dataset called sql-create-context, which contains 78,577 examples of natural language questions, SQL CREATE TABLE statements, and the corresponding SQL queries that answer the question using the CREATE statement as context. For demonstration purposes, we select 3,000 samples and split them into train, validation, and test sets.

You need to run the “Load and prepare the dataset” section of the medusa_1_train.ipynb notebook to prepare the dataset for fine-tuning. We also included a data exploration script to analyze the length of input and output tokens. After data exploration, we prepare the train, validation, and test sets and upload them to Amazon Simple Storage Service (Amazon S3).
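
The following condensed sketch shows what this step can look like. The notebook may differ in details such as the exact prompt formatting and split ratios, and the Hugging Face Hub dataset ID (b-mc2/sql-create-context) and S3 prefix are assumptions made for illustration.

from datasets import load_dataset
from sagemaker.s3 import S3Uploader

# Load the text-to-SQL dataset and keep 3,000 samples for demonstration purposes
dataset = load_dataset("b-mc2/sql-create-context", split="train")
dataset = dataset.shuffle(seed=42).select(range(3000))

def to_messages(sample):
    # Chat-style formatting: schema and question in, SQL query out
    system_prompt = (
        "You are a text to SQL query translator. Users will ask you questions in English "
        f"and you will generate a SQL query based on the provided SCHEMA. SCHEMA: {sample['context']}"
    )
    return {"messages": [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": sample["question"]},
        {"role": "assistant", "content": sample["answer"]},
    ]}

dataset = dataset.map(to_messages, remove_columns=dataset.column_names)

# 70% train, 15% validation, 15% test
splits = dataset.train_test_split(test_size=0.3, seed=42)
val_test = splits["test"].train_test_split(test_size=0.5, seed=42)
splits["train"].to_json("train_dataset.json", orient="records")
val_test["train"].to_json("eval_dataset.json", orient="records")
val_test["test"].to_json("test_dataset.json", orient="records")

# Upload the splits to Amazon S3 so the training jobs can read them as input channels
s3_prefix = f"s3://{sess.default_bucket()}/datasets/text-to-sql"
train_dataset_path = S3Uploader.upload("train_dataset.json", f"{s3_prefix}/train")
eval_dataset_path = S3Uploader.upload("eval_dataset.json", f"{s3_prefix}/eval")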

Fine-tune an LLM using a SageMaker AI training job

We use the Zephyr 7B β model as our backbone LLM. Zephyr is a series of language models trained to act as helpful assistants, and Zephyr 7B β is a fine-tuned version of Mistral-7B-v0.1, trained on a mix of publicly available and synthetic datasets using Direct Preference Optimization.
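
The model is referenced throughout the notebook by its Hugging Face Hub ID, which we assume is assigned to the model_id variable used in the training hyperparameters below:

model_id = "HuggingFaceH4/zephyr-7b-beta"  # backbone LLM used for fine-tuning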

To launch a SageMaker AI training job, we use the PyTorch or Hugging Face estimator. SageMaker AI starts and manages all the necessary Amazon Elastic Compute Cloud (Amazon EC2) instances for us, supplies the appropriate containers, downloads the data from our S3 bucket into the container, and uploads and runs the specified training script, in our case fine_tune_llm.py. We select the hyperparameters based on the QLoRA paper, but we encourage you to experiment with your own combinations. To expedite the execution of this code, we set the number of epochs to 1. However, for better results, it's generally recommended to set the number of epochs to at least 2 or 3.

from sagemaker.pytorch.estimator import PyTorch
from sagemaker.debugger import TensorBoardOutputConfig
import time
import os

def get_current_time():
    return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

def create_estimator(hyperparameters_dict, job_name, role, sess, train_script_path):
    metric = [
        {"Name": "loss", "Regex": r"'loss':\s*([0-9.]+)"},
        {"Name": "epoch", "Regex": r"'epoch':\s*([0-9.]+)"},
    ]

    tensorboard_s3_output_path = os.path.join(
        "s3://", sess.default_bucket(), job_name, "tensorboard"
    )
    print("Tensorboard output path:", tensorboard_s3_output_path)

    tensorboard_output_config = TensorBoardOutputConfig(
        s3_output_path=tensorboard_s3_output_path,
        container_local_output_path=hyperparameters_dict['logging_dir']
    )
    estimator = PyTorch(
        sagemaker_session=sess,
        entry_point=train_script_path,         # training script
        source_dir='train',                    # directory that includes all the files needed for training
        instance_type='ml.g5.4xlarge',         # instance type used for the training job, "local_gpu" for local mode
        metric_definitions=metric,
        instance_count=1,                      # the number of instances used for training
        role=role,                             # IAM role used in the training job to access AWS resources, e.g. S3
        volume_size=300,                       # the size of the EBS volume in GB
        framework_version='2.1.0',             # the PyTorch version used in the training job
        py_version='py310',                    # the Python version used in the training job
        hyperparameters=hyperparameters_dict,  # the hyperparameters passed to the training job
        disable_output_compression=True,       # don't compress output to save training time and cost
        tensorboard_output_config=tensorboard_output_config
    )
    return estimator

# hyperparameters, which are passed into the training job
sft_hyperparameters = {
    ### SCRIPT PARAMETERS ###
    'train_dataset_path': '/opt/ml/input/data/train/train_dataset.json',  # path where sagemaker will save training dataset
    'eval_dataset_path': '/opt/ml/input/data/eval/eval_dataset.json',     # path where sagemaker will save evaluation dataset
    'model_id': model_id,
    'max_seq_len': 256,                    # max sequence length for model and packing of the dataset
    'use_qlora': True,                     # use QLoRA
    ### TRAINING PARAMETERS ###
    'num_train_epochs': 1,                 # number of training epochs
    'per_device_train_batch_size': 1,      # batch size per device during training
    'gradient_accumulation_steps': 16,     # number of steps before performing a backward/update pass
    'gradient_checkpointing': True,        # use gradient checkpointing to save memory
    'optim': "adamw_8bit",                 # use the 8-bit AdamW optimizer
    'logging_steps': 15,                   # log every 15 steps
    'save_strategy': "steps",              # save a checkpoint every save_steps
    'save_steps': 15,
    'save_total_limit': 2,
    'eval_strategy': "steps",
    'eval_steps': 15,
    'learning_rate': 1e-4,                 # learning rate, based on QLoRA paper
    'bf16': True,                          # use bfloat16 precision
    'max_grad_norm': 10,                   # max gradient norm based on QLoRA paper
    'warmup_ratio': 0.03,                  # warmup ratio based on QLoRA paper
    'lr_scheduler_type': "constant",       # use constant learning rate scheduler
    'output_dir': '/opt/ml/checkpoints/',  # temporary output directory for model checkpoints
    'merge_adapters': True,                # merge LoRA adapters into the model for easier deployment
    'report_to': "tensorboard",            # report metrics to TensorBoard
    'logging_dir': "/opt/ml/output/tensorboard"  # TensorBoard logging directory
}

sft_job_name = f"sft-qlora-text-to-sql-{get_current_time()}"
data = {
    'train': train_dataset_path,
    'eval': eval_dataset_path
}

sft_estimator = create_estimator(sft_hyperparameters, sft_job_name, role, sess, "fine_tune_llm.py")

sft_estimator.fit(job_name=sft_job_name, inputs=data, wait=False)

When our training job has completed successfully after approximately 1 hour, we can use the fine-tuned model artifact for the next step: training the Medusa heads on top of it. To visualize the training metrics in TensorBoard, you can follow the guidance in the documentation Load and visualize output tensors using the TensorBoard application.
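
If you prefer to inspect the training curves outside of SageMaker Studio, one option (a sketch, assuming the AWS CLI is configured) is to sync the event files from the TensorBoard output path printed by create_estimator and launch TensorBoard locally:

import subprocess

# Sync the TensorBoard event files written by the training job and view them locally
tensorboard_s3_output_path = f"s3://{sess.default_bucket()}/{sft_job_name}/tensorboard"
subprocess.run(["aws", "s3", "sync", tensorboard_s3_output_path, "./tensorboard-logs"], check=True)
# Then run in a terminal: tensorboard --logdir ./tensorboard-logs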

Train Medusa heads on top of a frozen fine-tuned LLM using a SageMaker AI training job

For training Medusa heads, we can reuse the functions previously mentioned to launch the training job. We selected hyperparameters based on a combination of what the Medusa paper reported and what we found to be best performing after a few experiments. We set the number of Medusa heads to 5 and used the 8-bit AdamW optimizer, as recommended by the paper. For simplicity, we maintained a constant learning rate of 1e-4 with a constant scheduler, similar to the previous fine-tuning step. Although the paper recommends an increased learning rate and a cosine scheduler, we found that our chosen combination of hyperparameters performed well on this dataset. However, we encourage you to experiment with your own hyperparameter settings to potentially achieve even better results.

# hyperparameters, which are passed into the training job
medusa_hyperparameters = {
    ### SCRIPT PARAMETERS ###
    'train_dataset_path': '/opt/ml/input/data/train/train_dataset.json',  # path where sagemaker will save training dataset
    'eval_dataset_path': '/opt/ml/input/data/eval/eval_dataset.json',     # path where sagemaker will save evaluation dataset
    'model_path': '/opt/ml/input/data/fine-tuned-model/',
    'max_seq_len': 256,                    # max sequence length for model and packing of the dataset
    'medusa_num_heads': 5,
    ### TRAINING PARAMETERS ###
    'num_train_epochs': 3,                 # number of training epochs
    'per_device_train_batch_size': 1,      # batch size per device during training
    'gradient_accumulation_steps': 16,     # number of steps before performing a backward/update pass
    'gradient_checkpointing': True,        # use gradient checkpointing to save memory
    'optim': "adamw_8bit",                 # use the 8-bit AdamW optimizer
    'logging_steps': 15,                   # log every 15 steps
    'save_strategy': "steps",              # save a checkpoint every save_steps
    'save_steps': 15,
    'save_total_limit': 2,
    'eval_strategy': "steps",
    'eval_steps': 15,
    'learning_rate': 1e-4,                 # learning rate
    'bf16': True,                          # use bfloat16 precision
    'max_grad_norm': 10,                   # max gradient norm based on QLoRA paper
    'warmup_ratio': 0.03,                  # warmup ratio based on QLoRA paper
    'lr_scheduler_type': "constant",       # use constant learning rate scheduler
    'output_dir': '/opt/ml/checkpoints/',  # temporary output directory for model checkpoints
    'report_to': "tensorboard",            # report metrics to TensorBoard
    'logging_dir': "/opt/ml/output/tensorboard"  # TensorBoard logging directory
}

medusa_train_job_name = f"medusa-text-to-sql-{get_current_time()}"
data = {
    'train': train_dataset_path,
    'eval': eval_dataset_path,
    'fine-tuned-model': fine_tuned_model_path
}

medusa_estimator = create_estimator(medusa_hyperparameters, medusa_train_job_name, role, sess, "train_medusa_heads.py")

medusa_estimator.fit(job_name=medusa_train_job_name, inputs=data, wait=False)

We found that after 3 epochs, the evaluation loss of the Medusa heads was converging, as can be observed in the TensorBoard graph in the following image.

Besides the hyperparameters, the main difference is that we pass train_medusa_heads.py as the training entry point. In it, we first add the Medusa heads and freeze the fine-tuned LLM, and then create a custom MedusaSFTTrainer class, which is a subclass of the transformers SFTTrainer.

# Add medusa heads and freeze base model
add_medusa_heads(
    model,
    medusa_num_heads=script_args.medusa_num_heads,
)
freeze_layers(model)
model.config.torch_dtype = torch_dtype
model.config.use_cache = False

logger.info("Finished loading model and medusa heads")

tokenizer = AutoTokenizer.from_pretrained(script_args.model_path, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

################
# Training
################
trainer = MedusaSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
    dataset_kwargs={
        "add_special_tokens": False,   # we template with special tokens
        "append_concat_token": False,  # no need to add an additional separator token
    },
    medusa_num_heads=script_args.medusa_num_heads,
    medusa_heads_coefficient=script_args.medusa_heads_coefficient,
    medusa_decay_coefficient=script_args.medusa_decay_coefficient,
    medusa_scheduler=script_args.medusa_scheduler,
    train_only_medusa_heads=script_args.train_only_medusa_heads,
    medusa_lr_multiplier=script_args.medusa_lr_multiplier
)
trainer.train()

In the add_medusa_heads() function, we add the residual blocks of the Medusa heads, and also override the forward pass for our model to make sure not to train the frozen backbone LLM:

def add_medusa_heads(
    model,
    medusa_num_heads,
):
    """
    Args:
        model (nn.Module): The base language model to be used.
        medusa_num_heads (int, optional): Number of additional tokens to predict.
    """
    hidden_size = model.lm_head.weight.shape[-1]
    vocab_size = model.lm_head.weight.shape[0]
    model.config.medusa_num_layers = 1
    model.config.medusa_num_heads = medusa_num_heads
    model.medusa_num_heads = medusa_num_heads
    # Create a list of Medusa heads
    model.medusa_heads = nn.ModuleList(
        [
            nn.Sequential(
                ResBlock(hidden_size),
                nn.Linear(hidden_size, vocab_size, bias=False),
            )
            for _ in range(medusa_num_heads)
        ]
    )

    # Ensure the medusa heads' dtype and device align with the base model
    model.medusa_heads.to(model.dtype).to(model.device)
    logger.info(f"Loading medusa heads in {str(model.dtype)} to device {model.device}")

    for i in range(medusa_num_heads):
        # Initialize the weights of each medusa head using the base model's weights
        model.medusa_heads[i][-1].weight.data[:] = model.lm_head.weight.data[:]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train_only_medusa_heads: bool = False,
    ):
        """Forward pass of the MedusaModel.

        Returns:
            torch.Tensor: A tensor containing predictions from all Medusa heads.
            (Optional) Original predictions from the base model's LM head.
        """
        # Skip gradient computation for the frozen backbone when training only the Medusa heads
        maybe_grad = torch.no_grad() if train_only_medusa_heads else nullcontext()
        with maybe_grad:
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = outputs[0]
            medusa_logits = [self.lm_head(hidden_states)]
        for i in range(self.medusa_num_heads):
            medusa_logits.append(self.medusa_heads[i](hidden_states))
        return torch.stack(medusa_logits, dim=0)

    model.forward = types.MethodType(forward, model)

After the model training is finished (which takes approximately 1 hour), we prepare the model artifacts for deployment and upload them to Amazon S3. Your final model artifact contains both the original fine-tuned model from the previous step under the base-model prefix and the trained Medusa heads in a file named medusa_heads.safetensors.
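
One way to obtain the S3 location of the trained artifacts for the deployment step is to read it from the estimator. This is a sketch; because we set disable_output_compression=True, recent versions of the SageMaker Python SDK return the uncompressed output location as an S3DataSource dictionary rather than a plain URI.

model_data = medusa_estimator.model_data
medusa_trained_model_path = (
    model_data["S3DataSource"]["S3Uri"] if isinstance(model_data, dict) else model_data
)
print(medusa_trained_model_path)  # prefix containing base-model/ and medusa_heads.safetensors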

Deploy the fine-tuned LLM with Medusa heads on a SageMaker AI endpoint

The Medusa framework is supported by the Text Generation Inference (TGI) server. After training the LLM with Medusa heads, we deploy it to a SageMaker AI real-time endpoint using the Hugging Face Inference Container set up with TGI.

First, we create a SageMaker AI HuggingFaceModel object and then deploy the model to an endpoint with the following function:

import json
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

def deploy_model(endpoint_name, instance_type, model_s3_path=None, hf_model_id=None):
    llm_image = get_huggingface_llm_image_uri(
        "huggingface",
        version="2.2.0",
        session=sess,
    )

    print(f"llm image uri: {llm_image}")

    model_data = None
    if model_s3_path:
        model_data = {'S3DataSource': {'S3Uri': model_s3_path, 'S3DataType': 'S3Prefix', 'CompressionType': 'None'}}
        hf_model_id = "/opt/ml/model"
    else:
        assert hf_model_id, "You need to provide either a pretrained HF model id, or S3 model data to deploy"
    config = {
        'HF_MODEL_ID': hf_model_id,            # path to where sagemaker stores the model
        'SM_NUM_GPUS': json.dumps(1),          # number of GPUs used per replica
        'MAX_INPUT_LENGTH': json.dumps(1024),  # max length of input text
        'MAX_TOTAL_TOKENS': json.dumps(2048),  # max length of the generation (including input text)
    }

    llm_model = HuggingFaceModel(
        name=endpoint_name,
        role=role,
        image_uri=llm_image,
        model_data=model_data,
        env=config
    )

    deployed_llm = llm_model.deploy(
        endpoint_name=endpoint_name,
        initial_instance_count=1,
        instance_type=instance_type,
        container_startup_health_check_timeout=300,
    )
    return deployed_llm

We deploy three LLMs on three SageMaker AI endpoints:

Base LLM which isn’t fine-tuned
The LLM that we fine-tuned
The fine-tuned LLM that also has trained Medusa heads

You can deploy the three models in parallel by using a function that we included in the notebook, or you can deploy the models one by one by running the code below:

base_deployed_llm = deploy_model(f"base-{get_current_time()}", instance_type="ml.g5.4xlarge", model_s3_path=None, hf_model_id=model_id)
sft_deployed_llm = deploy_model(f"sft-{get_current_time()}", instance_type="ml.g5.4xlarge", model_s3_path=fine_tuned_model_path)
medusa_deployed_llm = deploy_model(f"medusa-{get_current_time()}", instance_type="ml.g5.4xlarge", model_s3_path=medusa_trained_model_path)

After the status for each endpoint becomes InService, which should take around 15 minutes, we can invoke them for inference. We send the following input:

"You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA. SCHEMA: CREATE TABLE table_name_32 (time VARCHAR, heat VARCHAR, name VARCHAR)
What was Liu Limin's time in heat 4?"

We can observe the following responses:

The base LLM response contains extra words that aren’t needed:

"To retrieve the time of Liu Limin in heat 4 based on the provided SCHEMA, we need to write a SQL query. Since the table name is not provided, let's assume it's "my_table". Assuming the table has a primary key (which is not mentioned in the SCHEMA), we can join the table with itself to compare the heat and name of Liu Limin in heat 4 with all the records in the table. Here's the SQL query: ```sql SELECT t1.time FROM my_table t1 JOIN my_table t2 ON t1.name = t2.name AND t2.heat = 4 WHERE t1.name = 'Liu Limin' AND t1.heat <> 4; ``` Explanation: 1. We're selecting the `time` column from the first table (t1) in the FROM clause. 2. We're joining the table with itself (my_table t1 JOIN my_table t2) to compare the name and heat of Liu Limin in heat 4 with all the records in the table. 3. We're filtering the results using the WHERE clause. We're selecting only the records where the name is 'Liu Limin' and the heat is not equal to 4 (i.e., not heat 4). This is to ensure that we're selecting the time of Liu Limin in heat 3. Note: This query assumes that the table has a unique primary key. If the table doesn't have a primary key, you may need to add additional conditions to the JOIN and WHERE clauses to ensure that we're selecting the correct records."

The fine-tuned LLM response is improved significantly, and contains only the required output:

'SELECT time FROM table_name_32 WHERE heat = 4 AND name = "liu limin"'

The fine-tuned LLM with trained Medusa heads provides the exact same response as the fine-tuned model, demonstrating that Medusa-1, by design, maintains the output (quality) of the original model:

'SELECT time FROM table_name_32 WHERE heat = 4 AND name = "liu limin"'

Demonstrate LLM inference speedup

To measure the inference speed improvements, we compare the response times of the deployed fine-tuned LLM and the fine-tuned LLM with Medusa heads on 450 test observations with the following code:

import time
import numpy as np
from tqdm import tqdm

def request(sample, deployed_llm):
    prompt = tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=True)
    outputs = deployed_llm.predict({
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 512,
            "do_sample": False,
            "return_full_text": False,
        }
    })
    return {"role": "assistant", "content": outputs[0]["generated_text"].strip()}

def predict(deployed_llm, test_dataset):
    predicted_answers = []
    latencies = []

    for sample in tqdm(test_dataset):
        start_time = time.time()
        predicted_answer = request(sample["messages"][:2], deployed_llm)
        end_time = time.time()

        latency = end_time - start_time
        latencies.append(latency)
        predicted_answers.append(predicted_answer)

    # Calculate p90 and average latencies
    p90_latency = np.percentile(latencies, 90)
    avg_latency = np.mean(latencies)

    print(f"P90 Latency: {p90_latency:.2f} seconds")
    print(f"Average Latency: {avg_latency:.2f} seconds")

    return predicted_answers

First, we run predictions using the fine-tuned LLM:

sft_predictions = predict(sft_deployed_llm, test_dataset)
P90 Latency: 1.28 seconds
Average Latency: 0.95 seconds

Then, we run predictions using the fine-tuned LLM with Medusa heads:

medusa_predictions = predict(medusa_deployed_llm, test_dataset)
P90 Latency: 0.80 seconds
Average Latency: 0.53 seconds

The prediction runs should take around 8 and 4 minutes respectively. We can observe that the average latency decreased from 950 to 530 milliseconds, which is an improvement of 1.8 times. You can achieve even higher improvements if your dataset contains longer inputs and outputs. In our dataset, we only had an average of 18 input tokens and 30 output tokens.
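
For reference, the token counts quoted above can be measured with the tokenizer. The following sketch assumes each test sample stores its [system, user, assistant] turns under the messages key, as in the predict() function earlier, and counts only the user question and the reference SQL answer:

input_token_counts, output_token_counts = [], []
for sample in test_dataset:
    question = sample["messages"][1]["content"]  # user turn: the natural language question
    answer = sample["messages"][2]["content"]    # assistant turn: the reference SQL query
    input_token_counts.append(len(tokenizer(question).input_ids))
    output_token_counts.append(len(tokenizer(answer).input_ids))

print(f"Average input tokens:  {np.mean(input_token_counts):.0f}")
print(f"Average output tokens: {np.mean(output_token_counts):.0f}")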

We want to once again highlight that, with this technique, the output quality is fully maintained and all prediction outputs are the same. The model responses for the test set of 450 observations are identical with and without Medusa heads:

match_percentage = sum(a["content"] == b["content"] for a, b in zip(sft_predictions, medusa_predictions)) / len(sft_predictions) * 100
print(f"Predictions with the fine-tuned model with medusa heads are the same as without medusa heads: {match_percentage:.2f}% of test set")

Predictions with the fine-tuned model with medusa heads are the same as without medusa heads: 100.00% of test set

In your own run, you might notice that a few observations don't match exactly, giving you around a 99% match, due to small floating point differences caused by GPU optimizations.

Cleanup

At the end of this experiment, don’t forget to delete the SageMaker AI endpoints you created:

base_deployed_llm.delete_model()
base_deployed_llm.delete_endpoint()
sft_deployed_llm.delete_model()
sft_deployed_llm.delete_endpoint()
medusa_deployed_llm.delete_model()
medusa_deployed_llm.delete_endpoint()

Conclusion

In this post, we demonstrated how to fine-tune and deploy an LLM with Medusa heads using the Medusa-1 technique on Amazon SageMaker AI to accelerate LLM inference. By using this framework and SageMaker AI's scalable infrastructure, we showed how to achieve up to twofold speedups in LLM inference while maintaining model quality. This solution is particularly beneficial for applications requiring low-latency text generation, such as customer service chat assistants, content creation, and recommendation systems.

As a next step, you can explore fine-tuning your own LLM with Medusa heads on your own dataset and benchmark the results for your specific use case, using the provided GitHub repository.

About the authors

Daniel Zagyva is a Senior ML Engineer at AWS Professional Services. He specializes in developing scalable, production-grade machine learning solutions for AWS customers. His experience extends across different areas, including natural language processing, generative AI and machine learning operations.

Aleksandra Dokic is a Senior Data Scientist at AWS Professional Services. She enjoys supporting customers to build innovative AI/ML solutions on AWS and she is excited about business transformations through the power of data.

Moran Beladev is a Senior ML Manager at Booking.com. She is leading the content intelligence track, which is focused on building, training, and deploying content models (computer vision, NLP, and generative AI) using the most advanced technologies and models. Moran is also a PhD candidate, researching the application of NLP models to social graphs.

Manos Stergiadis is a Senior ML Scientist at Booking.com. He specializes in generative NLP and has experience researching, implementing and deploying large deep learning models at scale.

Ilya Gusev is a Senior Machine Learning Engineer at Booking.com. He leads the development of several LLM systems inside Booking.com. His work focuses on building production ML systems that help millions of travelers plan their trips effectively.

Laurens van der Maas is a Machine Learning Engineer at AWS Professional Services. He works closely with customers building their machine learning solutions on AWS, specializes in natural language processing, experimentation and responsible AI, and is passionate about using machine learning to drive meaningful change in the world.
