arXiv:2308.03291

SynJax: Structured Probability Distributions for JAX

Published on Aug 7, 2023 · Submitted by akhaliq on Aug 8, 2023
Authors: Miloš Stanojević, Laurent Sartran

Abstract

The development of deep learning software libraries has enabled significant progress in the field by allowing users to focus on modeling while the library takes care of the tedious and time-consuming task of optimizing execution for modern hardware accelerators. However, this has benefited only particular types of deep learning models, such as Transformers, whose primitives map easily to vectorized computation. Models that explicitly account for structured objects, such as trees and segmentations, have not benefited equally, because they require custom algorithms that are difficult to implement in vectorized form. SynJax directly addresses this problem by providing an efficient vectorized implementation of inference algorithms for structured distributions, covering alignment, tagging, segmentation, constituency trees and spanning trees. With SynJax we can build large-scale differentiable models that explicitly model structure in the data. The code is available at https://github.com/deepmind/synjax.
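To make the abstract's claim concrete, here is a minimal sketch of the kind of vectorized, differentiable structured inference it describes: the forward algorithm for a linear-chain CRF, the distribution underlying the tagging case. This is not SynJax's API; the function name crf_log_partition and the random inputs are illustrative assumptions. The point is that the log-partition function can be written with standard JAX primitives (jax.lax.scan, logsumexp) and that marginal probabilities then follow from automatic differentiation.

import jax

def crf_log_partition(transitions, emissions):
    # transitions: [num_tags, num_tags] score of moving from tag i to tag j.
    # emissions:   [seq_len, num_tags] per-position tag scores.
    # Returns the scalar log-partition log Z, differentiable w.r.t. both inputs.
    def step(alpha, emit):
        # alpha[i] = log-sum of scores of all prefixes ending in tag i.
        alpha = jax.nn.logsumexp(alpha[:, None] + transitions, axis=0) + emit
        return alpha, None
    alpha, _ = jax.lax.scan(step, emissions[0], emissions[1:])
    return jax.nn.logsumexp(alpha)

# Per-position tag marginals are the gradient of log Z w.r.t. the emission scores.
key = jax.random.PRNGKey(0)
emissions = jax.random.normal(key, (10, 5))   # sequence of length 10, 5 tags
transitions = jax.random.normal(key, (5, 5))
marginals = jax.grad(crf_log_partition, argnums=1)(transitions, emissions)

A library such as SynJax would wrap routines like this behind distribution objects for the structures listed above; the sketch only shows the underlying vectorized computation.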

