
Introducing TensorFlow Graph Neural Networks

Excerpt

Introducing TensorFlow GNN, a library to build Graph Neural Networks on the TensorFlow platform.

Summary

Main Summary

TensorFlow introduces a new library, TensorFlow GNN (Graph Neural Networks), designed to make it easier to build and train neural networks on graph-structured data within the TensorFlow ecosystem. It lets developers and data scientists work with highly interconnected data more efficiently and at larger scale. The library is designed for complex heterogeneous graphs, which makes it especially useful for real-world applications where the relationships between entities are diverse and information-rich. TensorFlow GNN offers a modular architecture that makes it easy to implement custom models while staying compatible with the rest of the TensorFlow stack, including its distributed training infrastructure. The release reinforces Google's commitment to accessible deep learning tools specialized for graph-structured data and opens new possibilities in areas such as recommender systems, computational biology, and social network analysis.

Key Elements

  • Integration with TensorFlow: The library is fully integrated with the TensorFlow ecosystem, so existing tools such as TensorBoard, TFX, and distributed training can be used without disrupting the workflow
  • Support for heterogeneous graphs: TensorFlow GNN is designed to handle graphs with multiple node and edge types, which better reflects the complexity of real-world data in applications such as social networks or biological systems
  • Modular, extensible architecture: It provides reusable components and a flexible structure that let users build custom models while following good software engineering and reproducibility practices
  • Built for scalability: The library is designed with performance and scalability in mind, enabling efficient processing of large, complex graphs in production environments

Analysis and Implications

The release of TensorFlow GNN is an important step toward making graph machine learning more accessible, giving developers who work with complex relational data a robust, approachable tool. It could significantly speed up the adoption of advanced graph analysis techniques in industries that depend on understanding intricate relationships between entities. The standardization TensorFlow GNN offers could also establish new best practices for building graph-based models, making it easier for different organizations to collaborate and share solutions.

Additional Context

Graph Neural Networks have produced breakthrough results in many domains, from drug discovery to financial fraud detection, and TensorFlow GNN positions TensorFlow as a leading platform for these emerging applications. The library builds on TensorFlow's optimized backend, which aims to deliver consistent performance both in research settings and in large-scale production deployments.

Content

Posted by Sibon Li, Jan Pfeifer, Bryan Perozzi and Douglas Yarrington

Today, we are excited to release TensorFlow Graph Neural Networks (GNNs), a library designed to make it easy to work with graph-structured data using TensorFlow. We have used an earlier version of this library in production at Google in a variety of contexts (for example, spam and anomaly detection, traffic estimation, and YouTube content labeling) and as a component in our scalable graph mining pipelines. In particular, given the myriad types of data at Google, our library was designed with heterogeneous graphs in mind. We are releasing this library to encourage collaboration with researchers in industry.

Why use GNNs?

Graphs are all around us, in the real world and in our engineered systems. A set of objects, places, or people and the connections between them is generally describable as a graph. More often than not, the data we see in machine learning problems is structured or relational, and thus can also be described with a graph. And while fundamental research on GNNs is perhaps decades old, recent advances in the capabilities of modern GNNs have led to advances in domains as varied as traffic prediction, rumor and fake news detection, modeling disease spread, physics simulations, and understanding why molecules smell.

Graphs can model the relationships between many different types of data, including web pages (left), social connections (center), or molecules (right).

A graph represents the relations (edges) between a collection of entities (nodes or vertices). We can characterize each node, edge, or the entire graph, and thereby store information in each of these pieces of the graph. Additionally, we can ascribe directionality to edges to describe information or traffic flow, for example.
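A minimal plain-Python sketch of this representation, for illustration only: nodes of each type live in a named node set, a directed edge set stores parallel lists of source and target node indices, and features can be attached to nodes, edges, or the whole graph (the context). The class names, the `in_degree` helper, and the toy movie/user data are hypothetical; TF-GNN's own container for this is `GraphTensor`, which this sketch does not reproduce.

```python
from dataclasses import dataclass, field


@dataclass
class NodeSet:
    # Number of nodes plus per-node features (feature name -> one value per node).
    size: int
    features: dict = field(default_factory=dict)


@dataclass
class EdgeSet:
    # Directed edges: edge i goes from node source[i] of source_set
    # to node target[i] of target_set.
    source_set: str
    target_set: str
    source: list
    target: list
    features: dict = field(default_factory=dict)  # one value per edge

    @property
    def size(self) -> int:
        return len(self.source)


@dataclass
class Graph:
    node_sets: dict
    edge_sets: dict
    context: dict = field(default_factory=dict)  # graph-level features

    def in_degree(self, edge_set_name: str) -> list:
        """Count incoming edges per target node; this relies on edge direction."""
        es = self.edge_sets[edge_set_name]
        counts = [0] * self.node_sets[es.target_set].size
        for t in es.target:
            counts[t] += 1
        return counts


# Hypothetical toy graph: 2 users, 3 movies, "watched" edges pointing movie -> user.
graph = Graph(
    node_sets={
        "user": NodeSet(size=2),
        "movie": NodeSet(size=3, features={"year": [1999, 2010, 2021]}),
    },
    edge_sets={
        "watched": EdgeSet("movie", "user", source=[0, 1, 2, 0], target=[0, 0, 1, 1],
                           features={"weight": [1.0, 0.5, 2.0, 1.0]}),
    },
    context={"num_users": 2},
)
print(graph.in_degree("watched"))  # [2, 2]
```

Because each edge records which end is the source and which is the target, direction is explicit: here the `watched` edges point from movies to users, so counting incoming edges per user gives each user's number of watched movies.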

GNNs can be used to answer questions about several characteristics of these graphs. At the graph level, we try to predict characteristics of the entire graph; for example, we can detect the presence of certain “shapes,” such as cycles in a graph that might represent sub-molecules or close social relationships. At the node level, GNNs can classify the nodes of a graph or predict partitions and affinity within it, much like image classification or segmentation. Finally, at the edge level, GNNs can discover connections between entities, for example by “pruning” edges to identify the state of objects in a scene.
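The three levels differ mainly in what is read out at the end. As a hedged plain-Python illustration, assume a GNN has already produced one state vector per node: a node-level task scores each node, an edge-level task scores a pair of endpoint states (a dot product is used here as one common choice, not necessarily what TF-GNN models use), and a graph-level task first pools all node states into one vector (a mean here). All function names, states, and weights below are hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def node_scores(states, w):
    """Node-level task: one score (e.g., a class logit) per node."""
    return [dot(h, w) for h in states]


def edge_scores(states, source, target):
    """Edge-level task: score each candidate edge from its two endpoint states."""
    return [dot(states[s], states[t]) for s, t in zip(source, target)]


def graph_score(states, w):
    """Graph-level task: mean-pool all node states into one vector, then score it."""
    n = len(states)
    pooled = [sum(h[i] for h in states) / n for i in range(len(states[0]))]
    return dot(pooled, w)


# Hypothetical node states, e.g. the output of a few rounds of message passing.
states = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
w = [1.0, -1.0]
print(node_scores(states, w))                # [1.0, 0.0, -2.0]
print(edge_scores(states, [0, 1], [1, 2]))   # [0.5, 1.0]
print(graph_score(states, w))                # approximately -0.333
```

The same node states feed all three readouts; only the final aggregation and scoring step changes with the level of the question being asked.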

Structure

TF-GNN provides building blocks for implementing GNN models in TensorFlow. Beyond the modeling APIs, our library also provides extensive tooling for the difficult task of working with graph data: a Tensor-based graph data structure, a data handling pipeline, and some example models to help users get started quickly.

The various components of TF-GNN that make up the workflow.

The initial release of the TF-GNN library contains a number of utilities and features for use by beginners and experienced users alike, including:

  • A high-level Keras-style API to create GNN models that can easily be composed with other types of models. GNNs are often used in combination with ranking, deep-retrieval (dual-encoders) or mixed with other types of models (image, text, etc.)
    • GNN API for heterogeneous graphs. Many of the graph problems we approach at Google and in the real world contain different types of nodes and edges. Hence we chose to provide an easy way to model this.
  • A well-defined schema to declare the topology of a graph, and tools to validate it. This schema describes the shape of its training data and serves to guide other tools.
  • A GraphTensor composite tensor type which holds graph data, can be batched, and has graph manipulation routines available.
  • A library of operations on the GraphTensor structure:
    • Various efficient broadcast and pooling operations on nodes and edges, and related tools.
    • A library of standard baked convolutions, that can be easily extended by ML engineers/researchers.
    • A high-level API for product engineers to quickly build GNN models without necessarily worrying about their details.
  • An encoding of graph-shaped training data on disk, as well as a library used to parse this data into a data structure from which your model can extract the various features.
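To illustrate what declaring and validating a graph topology involves, here is a small plain-Python sketch. The dictionary layout of `SCHEMA` and the `validate` function are hypothetical stand-ins (TF-GNN's actual schema is a protocol buffer with its own tooling); the checks shown are that edge sets only reference declared node sets, and that every edge index points at an existing node of the declared type.

```python
# Hypothetical schema: node set names, and edge sets with their endpoint node sets.
SCHEMA = {
    "node_sets": ["user", "movie", "genre"],
    "edge_sets": {
        "watched": {"source": "movie", "target": "user"},
        "has_genre": {"source": "movie", "target": "genre"},
    },
}


def validate(schema, node_sizes, edges):
    """Return a list of problems; an empty list means the graph matches the schema."""
    errors = []
    for name in schema["node_sets"]:
        if name not in node_sizes:
            errors.append(f"missing node set {name!r}")
    for name, spec in schema["edge_sets"].items():
        # The schema itself must only reference declared node sets.
        for end in ("source", "target"):
            if spec[end] not in schema["node_sets"]:
                errors.append(f"{name!r}: {end} {spec[end]!r} is not a declared node set")
        if name not in edges:
            errors.append(f"missing edge set {name!r}")
            continue
        source, target = edges[name]
        if len(source) != len(target):
            errors.append(f"{name!r}: source and target lengths differ")
        # Every edge endpoint must index an existing node of the declared type.
        for end, idx in (("source", source), ("target", target)):
            n = node_sizes.get(spec[end], 0)
            bad = [i for i in idx if not 0 <= i < n]
            if bad:
                errors.append(f"{name!r}: {end} indices {bad} out of range for {spec[end]!r}")
    return errors


sizes = {"user": 2, "movie": 3, "genre": 2}
good = {"watched": ([0, 1, 2], [0, 0, 1]), "has_genre": ([0, 2], [1, 0])}
bad = {"watched": ([0, 5], [0, 1]), "has_genre": ([0], [1])}
print(validate(SCHEMA, sizes, good))  # []
print(validate(SCHEMA, sizes, bad))   # one error: movie index 5 does not exist
```

Keeping the topology in one declared place means data readers, validators, and models can all be checked against the same description of the graph.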

Example usage

In the example below, we build a model using the TF-GNN Keras API to recommend movies to a user based on what they watched and genres that they liked.

We use the ConvGNNBuilder class to configure how edges and nodes are processed: WeightedSumConvolution (defined below) computes the messages along each edge set, and on each pass through the GNN, NextStateFromConcat updates each node's state by concatenating its previous state with the incoming pooled messages and feeding the result through a Dense (fully connected) layer:

    import tensorflow as tf
    import tensorflow_gnn as tfgnn

    # Model hyper-parameters:
    h_dims = {'user': 256, 'movie': 64, 'genre': 128}
    
    # Model builder initialization:
    gnn = tfgnn.keras.ConvGNNBuilder(
        lambda edge_set_name: WeightedSumConvolution(),
        lambda node_set_name: tfgnn.keras.layers.NextStateFromConcat(
            tf.keras.layers.Dense(h_dims[node_set_name]))
    )
    
    # Two rounds of message passing to target node sets:
    model = tf.keras.models.Sequential([
        gnn.Convolve({'genre'}),  # sends messages from movie to genre
        gnn.Convolve({'user'}),  # sends messages from movie and genre to users
        tfgnn.keras.layers.Readout(node_set_name="user"),
        tf.keras.layers.Dense(1)
    ])
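To make the builder's behavior concrete, here is a plain-Python sketch of what one `gnn.Convolve({'user'})` step is meant to compute, under simplifying assumptions: each edge set that points into the receiver node set is pooled separately (a plain unweighted sum stands in for `WeightedSumConvolution`), and the pooled results are concatenated with the node's old state and passed through one linear layer, mirroring `NextStateFromConcat(Dense(...))`. The helper names, toy states, and weights are hypothetical, and this is not how TF-GNN implements it internally.

```python
def sum_pool_from_sources(source_states, source, target, num_targets):
    """Stand-in convolution: sum the states of source nodes at each target node."""
    dim = len(source_states[0])
    pooled = [[0.0] * dim for _ in range(num_targets)]
    for s, t in zip(source, target):
        for i in range(dim):
            pooled[t][i] += source_states[s][i]
    return pooled


def dense(x, weights, bias):
    """Linear layer without activation: one weight row per output unit."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def convolve(states, edge_sets, receiver, weights, bias):
    """One message-passing round for one receiver node set:
    new_state = dense(concat(old_state, pooled_1, pooled_2, ...)),
    with one pooled vector per edge set that points into `receiver`."""
    num_nodes = len(states[receiver])
    incoming = [
        sum_pool_from_sources(states[src], s_idx, t_idx, num_nodes)
        for src, tgt, s_idx, t_idx in edge_sets
        if tgt == receiver
    ]
    new_states = []
    for v in range(num_nodes):
        concat = list(states[receiver][v])
        for pooled in incoming:
            concat += pooled[v]
        new_states.append(dense(concat, weights, bias))
    return new_states


# Hypothetical toy graph with 1-dimensional states for every node set.
states = {"movie": [[1.0], [2.0], [3.0]], "genre": [[10.0], [20.0]], "user": [[0.5], [0.0]]}
edge_sets = [
    ("movie", "user", [0, 1, 2], [0, 0, 1]),  # movie -> user
    ("genre", "user", [0, 1], [1, 1]),        # genre -> user
]
# Dense input width 3 = old user state (1) + pooled movies (1) + pooled genres (1).
print(convolve(states, edge_sets, "user", weights=[[1.0, 0.5, 0.25]], bias=[0.0]))
# [[2.0], [9.0]]
```

In the model above, the genre states would first be refreshed by `gnn.Convolve({'genre'})`; this sketch shows only the second, user-facing step.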

The code above works well, but sometimes we may want a more powerful, custom model architecture for our GNN. In our movie example, for instance, we might want certain movies or genres to carry more weight when we make a recommendation. The following snippet defines a custom graph convolution that uses weighted edges: WeightedSumConvolution broadcasts each source node's state onto its outgoing edges, scales it by the edge's 'weight' feature, and sums the weighted messages at each target node:

    import tensorflow as tf
    import tensorflow_gnn as tfgnn


    class WeightedSumConvolution(tf.keras.layers.Layer):
      """Weighted sum of source node states, pooled at each target node."""

      def call(self, graph: tfgnn.GraphTensor,
               edge_set_name: tfgnn.EdgeSetName) -> tfgnn.Field:
        # Copy each edge's source node state onto the edge.
        messages = tfgnn.broadcast_node_to_edges(
            graph,
            edge_set_name,
            tfgnn.SOURCE,
            feature_name=tfgnn.DEFAULT_STATE_NAME)
        # Scale each message by the edge's (scalar) 'weight' feature.
        weights = graph.edge_sets[edge_set_name]['weight']
        weighted_messages = tf.expand_dims(weights, -1) * messages
        # Sum the weighted messages arriving at each target node.
        pooled_messages = tfgnn.pool_edges_to_node(
            graph,
            edge_set_name,
            tfgnn.TARGET,
            reduce_type='sum',
            feature_value=weighted_messages)
        return pooled_messages
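To check intuition, the computation in `call` can be written out for a single edge set in plain Python: the new value at target node v is the sum, over all edges e that end at v, of weight(e) times the state of e's source node, and target nodes with no incoming edges get zeros. This is a hypothetical reference function on toy data (it assumes one scalar weight per edge), not part of TF-GNN.

```python
def weighted_sum_convolution(source_states, source, target, weights, num_targets):
    """Plain-Python reference: for each target node v, sum
    weight(e) * state(source(e)) over all edges e that end at v."""
    dim = len(source_states[0])
    pooled = [[0.0] * dim for _ in range(num_targets)]
    for s, t, w in zip(source, target, weights):
        # broadcast_node_to_edges(..., SOURCE) picks state(source(e));
        # expand_dims(weights, -1) * messages scales it by the edge weight;
        # pool_edges_to_node(..., TARGET, reduce_type='sum') adds it to the target.
        for i in range(dim):
            pooled[t][i] += w * source_states[s][i]
    return pooled


# Hypothetical movie states (2-dimensional) and weighted movie -> user edges.
movie_states = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(weighted_sum_convolution(
    movie_states, source=[0, 1, 2, 0], target=[0, 0, 1, 1],
    weights=[1.0, 0.5, 0.25, 2.0], num_targets=2))
# [[1.0, 0.5], [2.5, 0.5]]
```

Under these assumptions, the Keras layer should produce the same values for the same toy graph, provided the movie states are stored as the default `hidden_state` feature.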

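Note that even though the convolution was written with only a single source and target node set in mind, TF-GNN applies it seamlessly to heterogeneous graphs (with various types of nodes and edges): the builder's convolution factory receives each edge set's name, and the layer is called with that name, so the same code handles every combination of node and edge types.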

Next steps

You can check out the TF-GNN GitHub repo for more information. To stay up to date, you can read the TensorFlow blog, join the TensorFlow Forum at discuss.tensorflow.org, follow twitter.com/tensorflow, or subscribe to youtube.com/tensorflow. If you’ve built something you’d like to share, please submit it for our Community Spotlight at goo.gle/TFCS. For feedback, please file an issue on GitHub. Thank you!

Acknowledgments

The work described here was a research collaboration between Oleksandr Ferludin, Martin Blais, Jan Pfeifer, Arno Eigenwillig, Dustin Zelle, Bryan Perozzi and Da-Cheng Juan of Google, and Sibon Li, Alvaro Sanchez-Gonzalez, Peter Battaglia, Kevin Villela, Jennifer She and David Wong of DeepMind.