GraphStorm 0.3: Scalable, multi-task learning on graphs with user-friendly APIs

GraphStorm is a low-code enterprise graph machine learning (GML) framework to build, train, and deploy graph ML solutions on complex enterprise-scale graphs in days instead of months. With GraphStorm, you can build solutions that directly take into account the structure of relationships or interactions between billions of entities, which are inherently embedded in most real-world data, including fraud detection scenarios, recommendations, community detection, and search/retrieval problems.

Today, we are launching GraphStorm 0.3, adding native support for multi-task learning on graphs. Specifically, GraphStorm 0.3 allows you to define multiple training targets on different nodes and edges within a single training loop. In addition, GraphStorm 0.3 adds new APIs to customize GraphStorm pipelines: you now only need 12 lines of code to implement a custom node classification training loop. To help you get started with the new API, we have published two Jupyter notebook examples: one for node classification, and one for a link prediction task. We also released a comprehensive study, from our KDD 2024 paper, of co-training language models (LMs) and graph neural networks (GNNs) on large graphs with rich text features, using the Microsoft Academic Graph (MAG) dataset. The study showcases the performance and scalability of GraphStorm on text-rich graphs and the best practices for configuring GML training loops for better performance and efficiency.

Native support for multi-task learning on graphs

Many enterprise applications have graph data associated with multiple tasks on different nodes and edges. For example, retail organizations want to conduct fraud detection on both sellers and buyers. Scientific publishers want to find more related works to cite in their papers and need to select the right subject for their publication to be discoverable. To better model such applications, customers have asked us to support multi-task learning on graphs.

GraphStorm 0.3 supports multi-task learning on graphs with the six most common tasks: node classification, node regression, edge classification, edge regression, link prediction, and node feature reconstruction. You can specify the training targets through a YAML configuration file. For example, a scientific publisher can use the following YAML configuration to simultaneously define a paper subject classification task on paper nodes and a link prediction task on paper-citing-paper edges:

version: 1.0
gsf:
  basic: # basic settings of the backbone GNN model


  multi_task_learning:
    - node_classification:            # define a node classification task for paper subject prediction.
        target_ntype: "paper"         # the paper nodes are the training targets.
        label_field: "label_class"    # the node feature "label_class" contains the training labels.
        mask_fields:
          - "train_mask_class"        # the train mask is named train_mask_class.
          - "val_mask_class"          # the validation mask is named val_mask_class.
          - "test_mask_class"         # the test mask is named test_mask_class.
        num_classes: 10               # there are 10 different classes (subjects) to predict.
        task_weight: 1.0              # the task weight is 1.0.

    - link_prediction:                # define a link prediction task for paper citation recommendation.
        num_negative_edges: 4         # sample 4 negative edges for each positive edge during training.
        num_negative_edges_eval: 100  # sample 100 negative edges for each positive edge during evaluation.
        train_negative_sampler: joint # share the negative edges between positive edges (to speed up training).
        train_etype:
          - "paper,citing,paper"      # the target edge type for link prediction training is "paper, citing, paper".
        mask_fields:
          - "train_mask_lp"           # the train mask is named train_mask_lp.
          - "val_mask_lp"             # the validation mask is named val_mask_lp.
          - "test_mask_lp"            # the test mask is named test_mask_lp.
        task_weight: 0.5              # the task weight is 0.5.
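
To make the role of task_weight more concrete, here is a minimal Python sketch. It assumes that the per-task losses are combined as a weighted sum, which is the usual way to balance tasks in multi-task training. The function name and the loss values are hypothetical and only serve as an illustration; this is not GraphStorm's internal code.

```python
# Illustrative sketch only: combine_task_losses is a hypothetical helper,
# not part of the GraphStorm API.

def combine_task_losses(task_losses, task_weights):
    """Return the weighted sum of per-task losses."""
    if task_losses.keys() != task_weights.keys():
        raise ValueError("every task needs exactly one loss and one weight")
    return sum(task_weights[name] * loss for name, loss in task_losses.items())


# Weights taken from the YAML configuration above.
task_weights = {"node_classification": 1.0, "link_prediction": 0.5}

# Made-up loss values for a single mini-batch (assumption for illustration).
task_losses = {"node_classification": 0.8, "link_prediction": 1.2}

total_loss = combine_task_losses(task_losses, task_weights)
print(f"total loss: {total_loss:.2f}")
```

With these weights, the link prediction loss contributes half as much to the combined objective as the node classification loss.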

For more details about how to run graph multi-task learning with GraphStorm, refer to Multi-task Learning in GraphStorm in our documentation.

New APIs to customize GraphStorm pipelines and components

Since GraphStorm’s release in early 2023, customers have mainly used its command line interface (CLI), which abstracts away the complexity of the graph ML pipeline for you to quickly build, train, and deploy models using common recipes. However, customers are telling us that they want an interface that allows them to customize the training and inference pipeline of GraphStorm to their specific requirements more easily. Based on customer feedback for the experimental APIs we released in GraphStorm 0.2, GraphStorm 0.3 introduces refactored graph ML pipeline APIs. With the new APIs, you only need 12 lines of code to define a custom node classification training pipeline, as illustrated by the following example:

import graphstorm as gs
gs.initialize()

# Load the partitioned ACM graph.
acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json')

# Build the dataloaders for the training, validation, and test sets of the 'paper' nodes.
train_dataloader = gs.dataloading.GSgnnNodeDataLoader(dataset=acm_data, target_idx=acm_data.get_node_train_set(ntypes=['paper']), fanout=[20, 20], batch_size=64)
val_dataloader = gs.dataloading.GSgnnNodeDataLoader(dataset=acm_data, target_idx=acm_data.get_node_val_set(ntypes=['paper']), fanout=[100, 100], batch_size=256, train_task=False)
test_dataloader = gs.dataloading.GSgnnNodeDataLoader(dataset=acm_data, target_idx=acm_data.get_node_test_set(ntypes=['paper']), fanout=[100, 100], batch_size=256, train_task=False)

# RgcnNCModel is a custom model class that this snippet does not define; define or import it before running.
model = RgcnNCModel(g=acm_data.g, num_hid_layers=2, hid_size=128, num_classes=14)
evaluator = gs.eval.GSgnnClassificationEvaluator(eval_frequency=100)

trainer = gs.trainer.GSgnnNodePredictionTrainer(model)
trainer.setup_evaluator(evaluator)

trainer.fit(train_dataloader, val_dataloader, test_dataloader, num_epochs=5)
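
The fanout and batch_size arguments in the dataloaders above together bound how many nodes a sampled mini-batch can touch. The following sketch works through that arithmetic under simplifying assumptions: each fanout entry is read as the maximum number of neighbors sampled per node at one GNN layer, the graph has a single edge type, and no neighbor is sampled twice. On a heterogeneous graph, where the fanout may apply per edge type, the real numbers will differ. The helper max_sampled_nodes is hypothetical and not part of GraphStorm.

```python
# Illustrative sketch only: max_sampled_nodes is a hypothetical helper,
# not part of the GraphStorm API.

def max_sampled_nodes(batch_size, fanout):
    """Upper bound on the nodes in one sampled mini-batch (seeds plus every hop)."""
    frontier = batch_size   # seed (target) nodes
    total = frontier
    for per_node in fanout:   # one fanout entry per GNN layer
        frontier *= per_node  # each frontier node adds at most `per_node` neighbors
        total += frontier
    return total


# Training dataloader settings from the example above.
print(max_sampled_nodes(64, [20, 20]))      # 64 + 1,280 + 25,600 = 26944

# Validation and test dataloader settings from the example above.
print(max_sampled_nodes(256, [100, 100]))   # 256 + 25,600 + 2,560,000 = 2585856
```

Under these assumptions, the smaller training fanout keeps each training step cheap, while the larger evaluation fanout lets each prediction draw on a much larger neighborhood.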

To help you get started with the new APIs, we have also released new Jupyter notebook examples in our Documentation and Tutorials page.

Comprehensive study of LM+GNN for large graphs with rich text features

Many enterprise applications have graphs with text features. In retail search applications, for example, shopping log data provides insights on how text-rich product descriptions, search queries, and customer behavior are related. Foundational large language models (LLMs) alone are not suitable to model such data because the underlying data distributions and relationships don’t correspond to what LLMs learn from their pre-training data corpora. GML, on the other hand, is great for modeling related data (graphs), but until now, GML practitioners had to manually combine their GML models with LLMs to model text features and get the best performance for their use cases. This manual work was challenging and time-consuming, especially when the underlying graph dataset was large.

GraphStorm 0.2 introduced built-in techniques to train language models (LMs) and GNN models together efficiently at scale on massive text-rich graphs. Since then, customers have been asking us for guidance on how GraphStorm’s LM+GNN techniques should be employed to optimize performance. To address this, with GraphStorm 0.3, we released an LM+GNN benchmark using a large graph dataset, Microsoft Academic Graph (MAG), on two standard graph ML tasks: node classification and link prediction. The dataset is a heterogeneous graph that contains hundreds of millions of nodes and billions of edges, and the majority of its nodes have rich text features. The detailed statistics of the dataset are shown in the following table.

| Dataset | Num. of nodes | Num. of edges | Num. of node/edge types | Num. of nodes in NC training set | Num. of edges in LP training set | Num. of nodes with text features |
|---|---|---|---|---|---|---|
| MAG | 484,511,504 | 7,520,311,838 | 4/4 | 28,679,392 | 1,313,781,772 | 240,955,156 |

We benchmark two main LM+GNN methods in GraphStorm: pre-trained BERT+GNN, a baseline method that is widely adopted, and fine-tuned BERT+GNN, introduced by GraphStorm developers in 2022. With the pre-trained BERT+GNN method, we first use a pre-trained BERT model to compute embeddings for node text features and then train a GNN model for prediction. With the fine-tuned BERT+GNN method, we first fine-tune the BERT model on the graph data and then use the resulting fine-tuned BERT model to compute embeddings, which are used to train a GNN model for prediction. GraphStorm provides different ways to fine-tune the BERT model, depending on the task type. For node classification, we fine-tune the BERT model on the training set with the node classification task; for link prediction, we fine-tune the BERT model with the link prediction task. In the experiment, we use 8 r5.24xlarge instances for data processing and 4 g5.48xlarge instances for model training and inference. The fine-tuned BERT+GNN approach has up to 40% better performance (link prediction on MAG) compared to pre-trained BERT+GNN.

The following table shows the model performance of the two methods and the overall computation time of the whole pipeline, starting from data processing and graph construction. NC means node classification and LP means link prediction. For pre-trained BERT+GNN, LM Time Cost is the time spent computing BERT embeddings; for fine-tuned BERT+GNN, it is the time spent fine-tuning the BERT models.

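| Dataset | Task | Data processing time | Target | Pre-trained BERT+GNN: LM Time Cost | Pre-trained BERT+GNN: One epoch time | Pre-trained BERT+GNN: Metric | Fine-tuned BERT+GNN: LM Time Cost | Fine-tuned BERT+GNN: One epoch time | Fine-tuned BERT+GNN: Metric |
|---|---|---|---|---|---|---|---|---|---|
| MAG | NC | 553 min | paper subject | 206 min | 135 min | Acc: 0.572 | 1423 min | 137 min | Acc: 0.633 |
| MAG | LP | 553 min | cite | 198 min | 2195 min | MRR: 0.487 | 4508 min | 2172 min | MRR: 0.684 |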
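
The "up to 40%" figure can be checked directly against the metrics reported in the table. The short sketch below computes the relative gain of fine-tuned BERT+GNN over pre-trained BERT+GNN for both tasks; the helper name relative_gain is ours and only serves this illustration.

```python
# Relative gain of fine-tuned BERT+GNN over pre-trained BERT+GNN,
# computed from the metrics reported in the table above.

def relative_gain(baseline, improved):
    """Return the improvement over the baseline as a fraction of the baseline."""
    return (improved - baseline) / baseline


# (pre-trained BERT+GNN, fine-tuned BERT+GNN)
results = {
    "NC (accuracy)": (0.572, 0.633),
    "LP (MRR)": (0.487, 0.684),
}

for task, (pretrained, finetuned) in results.items():
    print(f"{task}: {relative_gain(pretrained, finetuned):.1%}")
# NC (accuracy): 10.7%
# LP (MRR): 40.5%
```

The gain is much larger for link prediction than for node classification, which is why the 40% figure is quoted for the link prediction task on MAG.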

We also benchmark GraphStorm on large synthetic graphs to showcase its scalability. We generate three synthetic graphs with 1 billion, 10 billion, and 100 billion edges. The corresponding training set sizes are 8 million, 80 million, and 800 million, respectively. The following table shows the computation time of graph preprocessing, graph partition, and model training. Overall, GraphStorm enables graph construction and model training on 100 billion scale graphs within hours!

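| Graph size | Data pre-process: # instances | Data pre-process: time | Graph partition: # instances | Graph partition: time | Model training: # instances | Model training: time |
|---|---|---|---|---|---|---|
| 1B | 4 | 19 min | 4 | 8 min | 4 | 1.5 min |
| 10B | 8 | 31 min | 8 | 41 min | 8 | 8 min |
| 100B | 16 | 61 min | 16 | 416 min | 16 | 50 min |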
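
To put the "within hours" claim in numbers, the following sketch adds up the stage times reported in the table. It assumes the three stages run one after another, so the end-to-end time is their sum, and it also reports instance-minutes as a rough proxy for compute cost. The summarize helper is ours and only serves this illustration.

```python
# Stage times from the scalability table above: (stage, number of instances, minutes).
benchmarks = {
    "1B": [("pre-process", 4, 19), ("partition", 4, 8), ("training", 4, 1.5)],
    "10B": [("pre-process", 8, 31), ("partition", 8, 41), ("training", 8, 8)],
    "100B": [("pre-process", 16, 61), ("partition", 16, 416), ("training", 16, 50)],
}


def summarize(stages):
    """Return (end-to-end minutes, instance-minutes), assuming sequential stages."""
    total_minutes = sum(minutes for _, _, minutes in stages)
    instance_minutes = sum(count * minutes for _, count, minutes in stages)
    return total_minutes, instance_minutes


for size, stages in benchmarks.items():
    total_minutes, instance_minutes = summarize(stages)
    print(f"{size}: {total_minutes} min end to end "
          f"({total_minutes / 60:.1f} h), {instance_minutes} instance-minutes")
```

Under this assumption, the 100-billion-edge graph takes 527 minutes (just under 9 hours) end to end, with graph partitioning accounting for most of that time.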

More benchmark details and results are available in our KDD 2024 paper.

Conclusion

GraphStorm 0.3 is published under the Apache-2.0 license to help you tackle your large-scale graph ML challenges, and now offers native support for multi-task learning and new APIs to customize pipelines and other components of GraphStorm. Refer to the GraphStorm GitHub repository and documentation to get started.

About the Authors

Xiang Song is a senior applied scientist at AWS AI Research and Education (AIRE), where he develops deep learning frameworks including GraphStorm, DGL, and DGL-KE. He led the development of Amazon Neptune ML, a new capability of Neptune that uses graph neural networks for graphs stored in a graph database. He is now leading the development of GraphStorm, an open-source graph machine learning framework for enterprise use cases. He received his Ph.D. in computer systems and architecture from Fudan University, Shanghai, in 2014.

Jian Zhang is a senior applied scientist who has been using machine learning techniques to help customers solve various problems, such as fraud detection, decoration image generation, and more. He has successfully developed graph-based machine learning solutions, particularly graph neural network solutions, for customers in China, the USA, and Singapore. As an advocate for AWS’s graph capabilities, Zhang has given many public presentations about GNNs, the Deep Graph Library (DGL), Amazon Neptune, and other AWS services.

Florian Saupe is a Principal Technical Product Manager at AWS AI/ML research supporting science teams like the graph machine learning group, and ML Systems teams working on large-scale distributed training, inference, and fault resilience. Before joining AWS, Florian led technical product management for automated driving at Bosch, was a strategy consultant at McKinsey & Company, and worked as a control systems/robotics scientist, a field in which he holds a PhD.
