|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Optional, Tuple, Union |
|
|
|
import flax |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze |
|
from flax.linen import combine_masks, make_causal_mask |
|
from flax.linen.attention import dot_product_attention_weights |
|
from flax.traverse_util import flatten_dict, unflatten_dict |
|
from jax import lax |
|
|
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling |
|
from ...modeling_flax_utils import ( |
|
ACT2FN, |
|
FlaxPreTrainedModel, |
|
append_replace_return_docstrings, |
|
overwrite_call_docstring, |
|
) |
|
from ...utils import ModelOutput, add_start_docstrings, logging |
|
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
# Generic class-level documentation text for the Flax CLIP models (Markdown with doc-builder link syntax).
# NOTE(review): presumably attached to the model classes via `add_start_docstrings` (imported above); the call
# sites are not visible in this part of the file -- confirm.
CLIP_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""
|
|
|
# Argument documentation for callables that take text inputs only.
CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for callables that take image inputs only.
CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for callables that take both text and image inputs.
CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
|
@flax.struct.dataclass
class FlaxCLIPTextModelOutput(ModelOutput):
    """
    Output container for a text model: projected text embeddings plus the usual encoder outputs.

    Args:
        text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
            Text embeddings produced by applying the projection layer to the pooled output of
            [`FlaxCLIPTextModel`].
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states emitted by the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One `jnp.ndarray` for the embedding output followed by one per layer, each of shape
            `(batch_size, sequence_length, hidden_size)`.

            These are the hidden states after every layer, preceded by the initial embedding output.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            One `jnp.ndarray` per layer, each of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Post-softmax attention weights, i.e. the coefficients of the weighted average computed in the
            self-attention heads.
    """

    # Field order defines the positional layout when the output is converted to a tuple; do not reorder.
    text_embeds: jnp.ndarray = None
    last_hidden_state: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
|
|
|
|
|
@flax.struct.dataclass
class FlaxCLIPOutput(ModelOutput):
    """
    Output container for the joint text/image model.

    Args:
        logits_per_image (`jnp.ndarray` of shape `(image_batch_size, text_batch_size)`):
            Scaled dot-product scores between `image_embeds` and `text_embeds`, i.e. the image-text similarity
            scores.
        logits_per_text (`jnp.ndarray` of shape `(text_batch_size, image_batch_size)`):
            Scaled dot-product scores between `text_embeds` and `image_embeds`, i.e. the text-image similarity
            scores.
        text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
            Text embeddings produced by applying the projection layer to the pooled output of
            [`FlaxCLIPTextModel`].
        image_embeds (`jnp.ndarray` of shape `(batch_size, output_dim)`):
            Image embeddings produced by applying the projection layer to the pooled output of
            [`FlaxCLIPVisionModel`].
        text_model_output (`FlaxBaseModelOutputWithPooling`):
            The output of the [`FlaxCLIPTextModel`].
        vision_model_output (`FlaxBaseModelOutputWithPooling`):
            The output of the [`FlaxCLIPVisionModel`].
    """

    logits_per_image: jnp.ndarray = None
    logits_per_text: jnp.ndarray = None
    text_embeds: jnp.ndarray = None
    image_embeds: jnp.ndarray = None
    text_model_output: FlaxBaseModelOutputWithPooling = None
    vision_model_output: FlaxBaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        """
        Flatten this output into a tuple, in key order.

        The two nested sub-model outputs are themselves converted with their own `to_tuple()`; every other entry is
        taken as-is.
        """
        nested_keys = ("text_model_output", "vision_model_output")
        flattened = []
        for key in self.keys():
            if key in nested_keys:
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
|
|
|
|
|
class FlaxCLIPVisionEmbeddings(nn.Module):
    """
    Patch + class-token + position embeddings for the vision tower.

    A strided, bias-free convolution cuts the image into non-overlapping `patch_size x patch_size` patches and
    projects each to `hidden_size`; a learned class embedding is prepended and learned position embeddings are added.
    """

    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the class embedding parameter, the patch convolution and the position table."""
        hidden_dim = self.config.hidden_size
        image_side = self.config.image_size
        patch_side = self.config.patch_size

        self.class_embedding = self.param("class_embedding", jax.nn.initializers.normal(stddev=0.02), (hidden_dim,))

        # kernel == stride and "VALID" padding: every patch is seen exactly once, with no overlap.
        self.patch_embedding = nn.Conv(
            hidden_dim,
            kernel_size=(patch_side, patch_side),
            strides=(patch_side, patch_side),
            padding="VALID",
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(),
        )

        self.num_patches = (image_side // patch_side) ** 2
        # One extra position for the class token.
        position_count = self.num_patches + 1
        self.position_embedding = nn.Embed(position_count, hidden_dim, embedding_init=jax.nn.initializers.normal())
        self.position_ids = jnp.expand_dims(jnp.arange(0, position_count, dtype="i4"), axis=0)

    def __call__(self, pixel_values):
        """
        Embed `pixel_values` into a token sequence.

        The convolution output is unpacked as `(batch, height, width, channels)`, so the input is expected
        channels-last. Returns an array of shape `(batch, height * width + 1, channels)`.
        """
        patch_grid = self.patch_embedding(pixel_values)
        n_batch, grid_h, grid_w, width = patch_grid.shape
        patch_tokens = jnp.reshape(patch_grid, (n_batch, grid_h * grid_w, width))

        # (embed_dim,) -> (1, 1, embed_dim) -> (batch, 1, embed_dim)
        cls_token = jnp.expand_dims(self.class_embedding, axis=(0, 1))
        cls_tokens = jnp.tile(cls_token, (n_batch, 1, 1))

        tokens = jnp.concatenate([cls_tokens, patch_tokens], axis=1)
        return tokens + self.position_embedding(self.position_ids)
|
|
|
|
|
class FlaxCLIPTextEmbeddings(nn.Module):
    """Sum of learned token embeddings and learned position embeddings for the text tower."""

    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the token table, the position table and a default `position_ids` array."""
        hidden_dim = self.config.hidden_size
        max_positions = self.config.max_position_embeddings

        self.token_embedding = nn.Embed(self.config.vocab_size, hidden_dim, embedding_init=jax.nn.initializers.normal())
        self.position_embedding = nn.Embed(max_positions, hidden_dim, embedding_init=jax.nn.initializers.normal())
        # Kept as an attribute; `__call__` below always uses the `position_ids` it is given.
        self.position_ids = jnp.expand_dims(jnp.arange(0, max_positions, dtype="i4"), axis=(0, 1))

    def __call__(self, input_ids, position_ids):
        """
        Look up both tables (indices are cast to int32 first) and return their element-wise sum.
        """
        token_part = self.token_embedding(input_ids.astype("i4"))
        position_part = self.position_embedding(position_ids.astype("i4"))
        return token_part + position_part
|
|
|
|
|
class FlaxCLIPAttention(nn.Module):
    """
    Multi-head self-attention shared by the text and vision towers.

    When the config is a [`CLIPTextConfig`] the attention is causal: a causal mask sized for
    `config.max_position_embeddings` is built once in `setup` and sliced per call.
    """

    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Validate the head layout and create the four projection layers (and the causal mask for text)."""
        self.embed_dim = self.config.hidden_size
        self.num_heads = self.config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = self.config.attention_dropout

        self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
        self.v_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
        self.q_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
        self.out_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))

        # Only the text tower attends causally.
        self.causal = isinstance(self.config, CLIPTextConfig)
        if self.causal:
            self.causal_mask = make_causal_mask(jnp.ones((1, self.config.max_position_embeddings), dtype="i4"))

    def _split_heads(self, hidden_states):
        """Reshape `(d0, d1, embed_dim)` into `(d0, d1, num_heads, head_dim)`."""
        leading_dims = hidden_states.shape[:2]
        return hidden_states.reshape(leading_dims + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of `_split_heads`: collapse the trailing head axes back into `embed_dim`."""
        leading_dims = hidden_states.shape[:2]
        return hidden_states.reshape(leading_dims + (self.embed_dim,))

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: input activations; axis 1 is treated as the sequence axis and the last axis must be
                `embed_dim`.
            attention_mask: optional mask; entries `> 0` are attended to, everything else receives the most
                negative finite value of `self.dtype` as additive bias. It is expanded with two singleton axes
                before the last axis and, for the text tower, combined with the causal mask.
            deterministic: when `False` and the configured dropout rate is positive, a `"dropout"` RNG is drawn for
                attention dropout.
            output_attentions: also return the attention weights.

        Returns:
            `(attn_output,)` or `(attn_output, attn_weights)`.
        """
        query = self._split_heads(self.q_proj(hidden_states))
        key = self._split_heads(self.k_proj(hidden_states))
        value = self._split_heads(self.v_proj(hidden_states))

        causal_mask = None
        if self.causal:
            q_len, k_len = query.shape[1], key.shape[1]
            causal_mask = self.causal_mask[:, :, k_len - q_len : k_len, :k_len]

        # Resolve the effective mask from whichever of (padding mask, causal mask) are present.
        if attention_mask is None:
            attention_mask = causal_mask
        else:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            if causal_mask is not None:
                attention_mask = combine_masks(attention_mask, causal_mask, dtype="i4")

        # Turn the 0/1 mask into an additive bias: 0 where visible, dtype-min where hidden.
        attention_bias = None
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )

        use_dropout = not deterministic and self.dropout > 0.0
        dropout_rng = self.make_rng("dropout") if use_dropout else None

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        context = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self.out_proj(self._merge_heads(context))

        if output_attentions:
            return (attn_output, attn_weights)
        return (attn_output,)
|
|
|
|
|
class FlaxCLIPMLP(nn.Module):
    """
    Position-wise feed-forward block: `hidden_size -> intermediate_size -> hidden_size` with the activation named
    by `config.hidden_act` in between.
    """

    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Look up the activation in `ACT2FN` and create the two dense layers."""
        dense_init = jax.nn.initializers.normal(0.01)
        self.activation_fn = ACT2FN[self.config.hidden_act]
        self.fc1 = nn.Dense(self.config.intermediate_size, dtype=self.dtype, kernel_init=dense_init)
        self.fc2 = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))

    def __call__(self, hidden_states):
        """Apply `fc1`, the activation, then `fc2`."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
|
|
|
|
|
class FlaxCLIPEncoderLayer(nn.Module):
    """
    One pre-norm transformer layer: `x + Attention(LayerNorm(x))` followed by `x + MLP(LayerNorm(x))`.
    """

    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the attention block, the MLP and their two layer norms."""
        self.self_attn = FlaxCLIPAttention(self.config, dtype=self.dtype)
        self.layer_norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.mlp = FlaxCLIPMLP(self.config, dtype=self.dtype)
        self.layer_norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        """
        Apply the layer.

        Returns:
            `(hidden_states,)`, with the attention weights appended when `output_attentions` is set.
        """
        # Attention sub-block with residual connection.
        attn_outputs = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_outputs[0]

        # Feed-forward sub-block with residual connection.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states,) + attn_outputs[1:]
        return (hidden_states,)
|
|
|
|
|
class FlaxCLIPLayerCollection(nn.Module):
    """
    Ordered stack of `config.num_hidden_layers` [`FlaxCLIPEncoderLayer`] modules, applied one after another.

    Each layer is registered under its index as its name (`"0"`, `"1"`, ...).
    """

    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate the encoder layers."""
        self.layers = [
            FlaxCLIPEncoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run `hidden_states` through every layer.

        Args:
            hidden_states: input to the first layer.
            attention_mask: forwarded unchanged to every layer.
            deterministic: forwarded unchanged to every layer.
            output_attentions: collect the attention weights of every layer.
            output_hidden_states: collect the input of every layer plus the final output.
            return_dict: return a [`FlaxBaseModelOutput`] instead of a plain tuple.

        Returns:
            A [`FlaxBaseModelOutput`], or (when `return_dict=False`) a tuple
            `(last_hidden_state[, all_hidden_states][, all_attentions])` where the optional entries are present
            only if they were requested.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer in self.layers:
            if output_hidden_states:
                # Record the input of each layer; the final output is appended after the loop.
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            # Include the collected extras so the tuple form carries the same information as the dict form.
            # Callers in this file slice `encoder_outputs[1:]` to forward them; previously only
            # `(hidden_states,)` was packed, so requested hidden states/attentions were silently dropped.
            outputs = (hidden_states, all_hidden_states, all_attentions)
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
|
|
|
|
|
class FlaxCLIPEncoder(nn.Module):
    """
    Thin wrapper that owns a [`FlaxCLIPLayerCollection`] under the attribute name `layers` and delegates to it.
    """

    config: Union[CLIPTextConfig, CLIPVisionConfig]
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the wrapped layer collection."""
        self.layers = FlaxCLIPLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        inputs_embeds,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Forward all arguments to the layer collection (`inputs_embeds` becomes its `hidden_states`)."""
        forwarded = {
            "hidden_states": inputs_embeds,
            "attention_mask": attention_mask,
            "deterministic": deterministic,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.layers(**forwarded)
|
|
|
|
|
class FlaxCLIPTextTransformer(nn.Module):
    """
    Text tower: embeddings -> encoder -> final layer norm, plus a pooled vector taken from one position per
    sequence.
    """

    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the sub-modules and cache `config.eos_token_id` for pooling."""
        self.embeddings = FlaxCLIPTextEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

        # Used in `__call__` to decide which position is pooled.
        self.eos_token_id = self.config.eos_token_id

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Encode `input_ids`.

        Flags passed as `None` fall back to the corresponding config value.

        Returns:
            A [`FlaxBaseModelOutputWithPooling`], or when `return_dict` is falsy the tuple
            `(last_hidden_state, pooled_output) + encoder_outputs[1:]`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        embedded = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        if self.eos_token_id == 2:
            # Pool at the position holding the largest token id of each sequence.
            # NOTE(review): looks like a compatibility path for configs whose `eos_token_id` is 2 while the
            # real end-of-text token has the highest id -- confirm before changing.
            pooled_positions = input_ids.argmax(axis=-1)
        else:
            # `argmax` over the boolean match returns the first position equal to `eos_token_id`.
            pooled_positions = (input_ids == self.eos_token_id).argmax(axis=-1)

        batch_index = jnp.arange(last_hidden_state.shape[0])
        pooled_output = last_hidden_state[batch_index, pooled_positions]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
|
class FlaxCLIPVisionTransformer(nn.Module):
    """
    Vision tower: embeddings -> pre layer norm -> encoder, with the pooled vector taken from position 0 of the
    encoder output and passed through the post layer norm.
    """

    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the sub-modules."""
        self.embeddings = FlaxCLIPVisionEmbeddings(self.config, dtype=self.dtype)
        # The attribute name (including its spelling) is the parameter-tree key; it must stay as is.
        self.pre_layrnorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
        self.post_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict: bool = True,
    ):
        """
        Encode `pixel_values`.

        Flags passed as `None` fall back to the corresponding config value.

        Returns:
            A [`FlaxBaseModelOutputWithPooling`], or when `return_dict` is falsy the tuple
            `(last_hidden_state, pooled_output) + encoder_outputs[1:]`. Note that `last_hidden_state` is the raw
            encoder output; only the pooled vector goes through `post_layernorm`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Position 0 holds the class token prepended by the embeddings module.
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
|
class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
    """
    Base wrapper for text-only CLIP models: builds `module_class`, initialises weights and prepares inputs for
    `module.apply`. Subclasses are expected to set `module_class`.
    """

    config_class = CLIPTextConfig
    module_class: nn.Module = None

    def __init__(
        self,
        config: CLIPTextConfig,
        input_shape=(1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        """Instantiate `module_class` with the config/dtype (plus `kwargs`) and hand it to the base class."""
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialise parameters from dummy inputs of shape `input_shape`.

        If `params` is given, only the entries listed in `self._missing_keys` are filled in from the freshly
        initialised values (and `_missing_keys` is then cleared); otherwise the fresh parameters are returned.
        """
        # Dummy inputs: all-zero ids, positions 0..seq_len-1, everything attended.
        input_ids = jnp.zeros(input_shape, dtype="i4")
        seq_len = jnp.atleast_2d(input_ids).shape[-1]
        position_ids = jnp.broadcast_to(jnp.arange(seq_len), input_shape)
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]

        if params is None:
            return random_params

        flat_random = flatten_dict(unfreeze(random_params))
        flat_params = flatten_dict(unfreeze(params))
        for key in self._missing_keys:
            flat_params[key] = flat_random[key]
        self._missing_keys = set()
        return freeze(unflatten_dict(flat_params))

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Apply the wrapped module.

        Missing `position_ids` default to `0..seq_len-1` broadcast to `input_ids.shape`, a missing
        `attention_mask` defaults to all ones, and output flags left as `None` fall back to the config. `params`
        overrides `self.params` when truthy; `dropout_rng` is passed as the `"dropout"` RNG when provided.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        if position_ids is None:
            seq_len = jnp.atleast_2d(input_ids).shape[-1]
            position_ids = jnp.broadcast_to(jnp.arange(seq_len), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Only register a dropout RNG stream when the caller supplied one.
        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
|
|
|
|
|
class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
    """
    Base wrapper for vision-only CLIP models: builds `module_class`, initialises weights and prepares inputs for
    `module.apply`. Subclasses are expected to set `module_class`.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: CLIPVisionConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        """
        Instantiate `module_class` and hand it to the base class.

        `input_shape` defaults to `(1, image_size, image_size, 3)`, i.e. a single channels-last image.
        """
        if input_shape is None:
            side = config.image_size
            input_shape = (1, side, side, 3)
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialise parameters from a random dummy image of shape `input_shape`.

        If `params` is given, only the entries listed in `self._missing_keys` are filled in from the freshly
        initialised values (and `_missing_keys` is then cleared); otherwise the fresh parameters are returned.
        """
        # Dummy input, drawn from `rng` before it is split below.
        pixel_values = jax.random.normal(rng, input_shape)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, pixel_values)["params"]

        if params is None:
            return random_params

        flat_random = flatten_dict(unfreeze(random_params))
        flat_params = flatten_dict(unfreeze(params))
        for key in self._missing_keys:
            flat_params[key] = flat_random[key]
        self._missing_keys = set()
        return freeze(unflatten_dict(flat_params))

    def __call__(
        self,
        pixel_values,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Apply the wrapped module.

        `pixel_values` is permuted with axes `(0, 2, 3, 1)` (channels-first to channels-last) and cast to float32
        before the call. Output flags left as `None` fall back to the config; `params` overrides `self.params`
        when truthy; `dropout_rng` is passed as the `"dropout"` RNG when provided.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        channels_last = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Only register a dropout RNG stream when the caller supplied one.
        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(channels_last, dtype=jnp.float32),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
|
|
|
|
|
class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): |
|
config_class = CLIPConfig |
|
module_class: nn.Module = None |
|
|
|
def __init__( |
|
self, |
|
config: CLIPConfig, |
|
input_shape: Optional[Tuple] = None, |
|
seed: int = 0, |
|
dtype: jnp.dtype = jnp.float32, |
|
_do_init: bool = True, |
|
**kwargs, |
|
): |
|
if input_shape is None: |
|
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3)) |
|
module = self.module_class(config=config, dtype=dtype, **kwargs) |
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) |
|
|
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: |
|
|
|
input_ids = jnp.zeros(input_shape[0], dtype="i4") |
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]) |
|
attention_mask = jnp.ones_like(input_ids) |
|
|
|
pixel_values = jax.random.normal(rng, input_shape[1]) |
|
|
|
params_rng, dropout_rng = jax.random.split(rng) |
|
rngs = {"params": params_rng, "dropout": dropout_rng} |
|
|
|
random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"] |
|
|
|
if params is not None: |
|
random_params = flatten_dict(unfreeze(random_params)) |
|
params = flatten_dict(unfreeze(params)) |
|
for missing_key in self._missing_keys: |
|
params[missing_key] = random_params[missing_key] |
|
self._missing_keys = set() |
|
return freeze(unflatten_dict(params)) |
|
else: |
|
return random_params |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
pixel_values, |
|
attention_mask=None, |
|
position_ids=None, |
|
params: dict = None, |
|
dropout_rng: jax.random.PRNGKey = None, |
|
train: bool = False, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
): |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
if position_ids is None: |
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) |
|
|
|
if attention_mask is None: |
|
attention_mask = jnp.ones_like(input_ids) |
|
|
|
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
jnp.array(input_ids, dtype="i4"), |
|
jnp.array(pixel_values, dtype=jnp.float32), |
|
jnp.array(attention_mask, dtype="i4"), |
|
jnp.array(position_ids, dtype="i4"), |
|
not train, |
|
output_attentions, |
|
output_hidden_states, |
|
return_dict, |
|
rngs=rngs, |
|
) |
|
|
|
def get_text_features(
    self,
    input_ids,
    attention_mask=None,
    position_ids=None,
    params: dict = None,
    dropout_rng: jax.random.PRNGKey = None,
    train=False,
):
    r"""
    Run only the text model and its projection layer, returning the projected text features.

    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
            provide it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (*optional*):
            Mask forwarded to the text model. Defaults to ones with the same shape as `input_ids`.
        position_ids (*optional*):
            Positions forwarded to the text model. Defaults to `0 .. sequence_length - 1` broadcast to the shape
            of `input_ids`.
        params (`dict`, *optional*):
            Parameters used instead of `self.params`. A falsy value also selects `self.params`.
        dropout_rng (*optional*):
            PRNG key passed to the module as the `"dropout"` RNG stream.
        train (`bool`, defaults to `False`):
            The text model is called with `deterministic=not train`.

    Returns:
        text_features (`jnp.ndarray` of shape `(batch_size, output_dim)`): The text embeddings obtained by
        applying the projection layer to the pooled output of [`FlaxCLIPTextModel`].

    Examples:

    ```python
    >>> from transformers import AutoTokenizer, FlaxCLIPModel

    >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
    >>> text_features = model.get_text_features(**inputs)
    ```"""
    if position_ids is None:
        seq_len = jnp.atleast_2d(input_ids).shape[-1]
        position_ids = jnp.broadcast_to(jnp.arange(seq_len), input_ids.shape)

    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)

    # Handle any PRNG if needed
    rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

    def _get_features(module, input_ids, attention_mask, position_ids, deterministic):
        # Executed by `apply` in place of the module's `__call__`: text model, then projection.
        encoder_outputs = module.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
        )
        # Element 1 of the text model output is the pooled representation.
        return module.text_projection(encoder_outputs[1])

    variables = {"params": params or self.params}
    return self.module.apply(
        variables,
        jnp.array(input_ids, dtype="i4"),
        jnp.array(attention_mask, dtype="i4"),
        jnp.array(position_ids, dtype="i4"),
        not train,
        method=_get_features,
        rngs=rngs,
    )
|
|
|
def get_image_features(
    self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
    r"""
    Run only the vision model and its projection layer, returning the projected image features.

    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
            using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        params (`dict`, *optional*):
            Parameters used instead of `self.params`. A falsy value also selects `self.params`.
        dropout_rng (*optional*):
            PRNG key passed to the module as the `"dropout"` RNG stream.
        train (`bool`, defaults to `False`):
            The vision model is called with `deterministic=not train`.

    Returns:
        image_features (`jnp.ndarray` of shape `(batch_size, output_dim)`): The image embeddings obtained by
        applying the projection layer to the pooled output of [`FlaxCLIPVisionModel`]

    Examples:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, FlaxCLIPModel

    >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, return_tensors="np")

    >>> image_features = model.get_image_features(**inputs)
    ```"""
    # (batch, channels, height, width) -> (batch, height, width, channels)
    pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

    # Handle any PRNG if needed
    rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

    def _get_features(module, pixel_values, deterministic):
        # Executed by `apply` in place of the module's `__call__`: vision model, then projection.
        encoder_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
        # Element 1 of the vision model output is the pooled representation.
        return module.visual_projection(encoder_outputs[1])

    variables = {"params": params or self.params}
    return self.module.apply(
        variables,
        jnp.array(pixel_values, dtype=jnp.float32),
        not train,
        method=_get_features,
        rngs=rngs,
    )
|
|
|
|
|
class FlaxCLIPTextModule(nn.Module):
    """Flax module holding a single `FlaxCLIPTextTransformer` and forwarding every call to it."""

    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the text transformer from this module's config and dtype."""
        self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Pass all arguments through to `self.text_model` by keyword and return its output unchanged."""
        encoder_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return encoder_outputs
|
|
|
|
|
class FlaxCLIPTextModel(FlaxCLIPTextPreTrainedModel):
    # Flax module class associated with this model wrapper.
    module_class = FlaxCLIPTextModule
|
|
|
|
|
# Return-section placeholder and usage example for `FlaxCLIPTextModel.__call__`; concatenated with the
# shared text-input documentation below. The literal must stay free of stray table characters because
# it is used verbatim as docstring text.
FLAX_CLIP_TEXT_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxCLIPTextModel

    >>> model = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> last_hidden_state = outputs.last_hidden_state
    >>> pooler_output = outputs.pooler_output  # pooled (EOS token) states
    ```
"""

# Install the combined text as the call docstring of the model class.
overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING)
# NOTE(review): presumably expands the empty `Returns:` section from `output_type` / `config_class`;
# confirm against `modeling_flax_utils`.
append_replace_return_docstrings(
    FlaxCLIPTextModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPTextConfig
)
|
|
|
|
|
class FlaxCLIPTextModelWithProjectionModule(nn.Module):
    """Text transformer followed by a bias-free dense projection of its pooled output."""

    config: CLIPTextConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the text transformer and the projection layer of width `config.projection_dim`."""
        self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
        self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Encode the text and project the pooled output.

        Returns:
            A `FlaxCLIPTextModelOutput` when `return_dict` is truthy. Otherwise a tuple that starts with the
            projected embeddings and element 0 of the text model output, followed by the text model output
            from index 2 onwards.
        """
        encoder_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of the text model output is the pooled representation.
        projected = self.text_projection(encoder_outputs[1])

        if return_dict:
            return FlaxCLIPTextModelOutput(
                text_embeds=projected,
                last_hidden_state=encoder_outputs.last_hidden_state,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the projected embeddings take the place of the pooled output.
        return (projected, encoder_outputs[0]) + encoder_outputs[2:]
|
|
|
|
|
class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel):
    # Flax module class associated with this model wrapper.
    module_class = FlaxCLIPTextModelWithProjectionModule
|
|
|
|
|
# Return-section placeholder and usage example for `FlaxCLIPTextModelWithProjection.__call__`;
# concatenated with the shared text-input documentation below. The literal must stay free of stray
# table characters because it is used verbatim as docstring text.
FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection

    >>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> text_embeds = outputs.text_embeds
    ```
"""

# Install the combined text as the call docstring of the model class.
overwrite_call_docstring(
    FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING
)
# NOTE(review): presumably expands the empty `Returns:` section from `output_type` / `config_class`;
# confirm against `modeling_flax_utils`.
append_replace_return_docstrings(
    FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig
)
|
|
|
|
|
class FlaxCLIPVisionModule(nn.Module):
    """Flax module holding a single `FlaxCLIPVisionTransformer` and forwarding every call to it."""

    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the vision transformer from this module's config and dtype."""
        self.vision_model = FlaxCLIPVisionTransformer(self.config, dtype=self.dtype)

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Pass all arguments through to `self.vision_model` by keyword and return its output unchanged."""
        encoder_outputs = self.vision_model(
            pixel_values=pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return encoder_outputs
|
|
|
|
|
class FlaxCLIPVisionModel(FlaxCLIPVisionPreTrainedModel):
    # Flax module class associated with this model wrapper.
    module_class = FlaxCLIPVisionModule
|
|
|
|
|
# Return-section placeholder and usage example for `FlaxCLIPVisionModel.__call__`; concatenated with
# the shared vision-input documentation below. The literal must stay free of stray table characters
# because it is used verbatim as docstring text.
FLAX_CLIP_VISION_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, FlaxCLIPVisionModel

    >>> model = FlaxCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> last_hidden_state = outputs.last_hidden_state
    >>> pooler_output = outputs.pooler_output  # pooled CLS states
    ```
"""

# Install the combined text as the call docstring of the model class.
overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING)
# NOTE(review): presumably expands the empty `Returns:` section from `output_type` / `config_class`;
# confirm against `modeling_flax_utils`.
append_replace_return_docstrings(
    FlaxCLIPVisionModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPVisionConfig
)
|
|
|
|
|
class FlaxCLIPModule(nn.Module):
    """
    Joint text/vision module: encodes both inputs, projects each pooled output to `config.projection_dim`,
    normalises the projections and returns their scaled pairwise similarities.
    """

    config: CLIPConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create both encoders, the two bias-free projection layers and the `logit_scale` parameter."""
        text_config = self.config.text_config
        vision_config = self.config.vision_config

        self.projection_dim = self.config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = FlaxCLIPTextTransformer(text_config, dtype=self.dtype)
        self.vision_model = FlaxCLIPVisionTransformer(vision_config, dtype=self.dtype)

        # Both projection heads are configured identically.
        projection_kwargs = {
            "dtype": self.dtype,
            "kernel_init": jax.nn.initializers.normal(0.02),
            "use_bias": False,
        }
        self.visual_projection = nn.Dense(self.projection_dim, **projection_kwargs)
        self.text_projection = nn.Dense(self.projection_dim, **projection_kwargs)

        def _init_logit_scale(_rng, shape):
            # Constant initialiser: the PRNG key is ignored on purpose.
            return jnp.ones(shape) * self.config.logit_scale_init_value

        # An empty shape list makes `logit_scale` a scalar parameter.
        self.logit_scale = self.param("logit_scale", _init_logit_scale, [])

    def __call__(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Encode images and texts and compute their similarity logits.

        `return_dict` falls back to `self.config.return_dict` when `None`. The remaining flags are forwarded to
        both encoders unchanged.

        Returns:
            A `FlaxCLIPOutput` when `return_dict` is truthy, otherwise the tuple `(logits_per_image,
            logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)`.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        # Flags shared by the two encoder calls.
        shared_kwargs = {
            "deterministic": deterministic,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

        vision_outputs = self.vision_model(pixel_values=pixel_values, **shared_kwargs)
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **shared_kwargs,
        )

        # Element 1 of each encoder output is projected into the shared embedding space.
        image_embeds = self.visual_projection(vision_outputs[1])
        text_embeds = self.text_projection(text_outputs[1])

        def _unit_norm(embeds):
            # Divide each vector along the last axis by its L2 norm.
            return embeds / jnp.linalg.norm(embeds, axis=-1, keepdims=True)

        # normalized features
        image_embeds = _unit_norm(image_embeds)
        text_embeds = _unit_norm(text_embeds)

        # cosine similarity as logits; the stored parameter is exponentiated before use
        logit_scale = jnp.exp(self.logit_scale)
        logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
        logits_per_image = logits_per_text.T

        if return_dict:
            return FlaxCLIPOutput(
                logits_per_image=logits_per_image,
                logits_per_text=logits_per_text,
                text_embeds=text_embeds,
                image_embeds=image_embeds,
                text_model_output=text_outputs,
                vision_model_output=vision_outputs,
            )

        return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
|
|
|
|
@add_start_docstrings(CLIP_START_DOCSTRING)
class FlaxCLIPModel(FlaxCLIPPreTrainedModel):
    # Flax module class associated with this model wrapper.
    module_class = FlaxCLIPModule
|
|
|
|
|
# Return-section placeholder and usage example for `FlaxCLIPModel.__call__`; concatenated with the
# shared input documentation below. The literal must stay free of stray table characters because it is
# used verbatim as docstring text.
FLAX_CLIP_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> import jax
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, FlaxCLIPModel

    >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(
    ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="np", padding=True
    ... )

    >>> outputs = model(**inputs)
    >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
    >>> probs = jax.nn.softmax(logits_per_image, axis=1)  # we can take the softmax to get the label probabilities
    ```
"""

# Install the combined text as the call docstring of the model class.
overwrite_call_docstring(FlaxCLIPModel, CLIP_INPUTS_DOCSTRING + FLAX_CLIP_MODEL_DOCSTRING)
# NOTE(review): presumably expands the empty `Returns:` section from `output_type` / `config_class`;
# confirm against `modeling_flax_utils`.
append_replace_return_docstrings(FlaxCLIPModel, output_type=FlaxCLIPOutput, config_class=CLIPConfig)
|
|