Jamba

Jamba Overview

Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It is the first production-scale Mamba implementation, which opens up interesting research and application opportunities. While this initial experimentation shows encouraging gains, we expect these to be further enhanced with future optimizations and explorations.

For full details of this model, please read the release blog post.

Model Details

Jamba is a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length, and can fit up to 140K tokens on a single 80GB GPU.

Jamba’s architecture features a blocks-and-layers approach that integrates the Transformer and Mamba architectures. Each Jamba block contains either an attention or a Mamba layer, followed by a multi-layer perceptron (MLP), producing an overall ratio of one Transformer layer out of every eight total layers.
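
To make the layout concrete, here is a small sketch (not part of the original documentation) that reconstructs the per-layer schedule from the default attn_layer_period/attn_layer_offset and expert_layer_period/expert_layer_offset values listed in the JambaConfig section further down; the modular-arithmetic interpretation of those parameters is an assumption made for illustration.

# Sketch: reconstructing Jamba's layer layout from the JambaConfig defaults shown
# further down this page (num_hidden_layers=32, attn_layer_period=8,
# attn_layer_offset=4, expert_layer_period=2, expert_layer_offset=1).
# The period/offset arithmetic below is an assumption made for illustration.
num_hidden_layers = 32
attn_layer_period, attn_layer_offset = 8, 4
expert_layer_period, expert_layer_offset = 2, 1

for i in range(num_hidden_layers):
    block = "attention" if i % attn_layer_period == attn_layer_offset else "mamba"
    mlp = "MoE MLP" if i % expert_layer_period == expert_layer_offset else "dense MLP"
    print(f"layer {i:2d}: {block:9s} + {mlp}")
# 4 of the 32 layers come out as attention (the 1:8 ratio described above),
# and every other layer uses the MoE MLP.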

Usage Prerequisites

Jamba requires you to use transformers version 4.39.0 or higher:

pip install "transformers>=4.39.0"

In order to run optimized Mamba implementations, you first need to install mamba-ssm and causal-conv1d:

pip install mamba-ssm "causal-conv1d>=1.2.0"

The model must also be on a CUDA device to use these kernels.

You can run the model without the optimized Mamba kernels, but this is not recommended, as it results in significantly higher latency. To do so, specify use_mamba_kernels=False when loading the model.
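
For example, a minimal sketch of disabling the kernels when loading the checkpoint used throughout this page:

from transformers import AutoModelForCausalLM

# Fall back to the pure-PyTorch Mamba implementation; slower, but it removes the
# mamba-ssm / causal-conv1d and CUDA requirements for the Mamba layers.
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", use_mamba_kernels=False)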

Run the model

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

outputs = model.generate(input_ids, max_new_tokens=216)

print(tokenizer.batch_decode(outputs))

Loading the model in half precision

The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you need to specify torch_dtype:

from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16)

When using half precision, you can enable the FlashAttention2 implementation of the attention blocks. To use it, the model also needs to be on a CUDA device. Since in this precision the model is too big to fit on a single 80GB GPU, you’ll also need to parallelize it using accelerate:

from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             device_map="auto")

Load the model in 8-bit

Using 8-bit precision, it is possible to fit sequence lengths of up to 140K tokens on a single 80GB GPU. You can easily quantize the model to 8-bit using bitsandbytes. To avoid degrading model quality, we recommend excluding the Mamba blocks from quantization:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", quantization_config=quantization_config
)

JambaConfig class transformers.JambaConfig < source >

( vocab_size = 65536 tie_word_embeddings = False hidden_size = 4096 intermediate_size = 14336 num_hidden_layers = 32 num_attention_heads = 32 num_key_value_heads = 8 hidden_act = 'silu' initializer_range = 0.02 rms_norm_eps = 1e-06 use_cache = True num_logits_to_keep = 1 output_router_logits = False router_aux_loss_coef = 0.001 pad_token_id = 0 bos_token_id = 1 eos_token_id = 2 sliding_window = None max_position_embeddings = 262144 attention_dropout = 0.0 num_experts_per_tok = 2 num_experts = 16 expert_layer_period = 2 expert_layer_offset = 1 attn_layer_period = 8 attn_layer_offset = 4 use_mamba_kernels = True mamba_d_state = 16 mamba_d_conv = 4 mamba_expand = 2 mamba_dt_rank = 'auto' mamba_conv_bias = True mamba_proj_bias = False **kwargs )

Parameters

This is the configuration class to store the configuration of a JambaModel. It is used to instantiate a Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the ai21labs/Jamba-v0.1 model.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
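
A minimal usage sketch, following the usual pattern for transformers configuration classes (not taken verbatim from the original page):

>>> from transformers import JambaConfig, JambaModel

>>> # Initializing a Jamba configuration with default values
>>> configuration = JambaConfig()

>>> # Initializing a (randomly initialized) model from the configuration
>>> model = JambaModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config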

JambaModel class transformers.JambaModel < source >

( config: JambaConfig )

Parameters

The bare Jamba Model outputting raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a JambaDecoderLayer.

forward < source >

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None output_router_logits: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None )

Parameters

The JambaModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
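
As a hedged sketch (not from the original page), calling the bare model returns the final hidden states rather than logits; the checkpoint and dtype choices below mirror the examples earlier on this page:

from transformers import AutoTokenizer, JambaModel
import torch

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = JambaModel.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("Jamba mixes attention and Mamba layers.", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)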

JambaForCausalLM class transformers.JambaForCausalLM < source >

( config: JambaConfig )

forward < source >

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None output_router_logits: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None logits_to_keep: typing.Union[int, torch.Tensor] = 0 **loss_kwargs ) transformers.modeling_outputs.MoeCausalLMOutputWithPast or tuple(torch.FloatTensor)

Parameters

Returns

transformers.modeling_outputs.MoeCausalLMOutputWithPast or tuple(torch.FloatTensor)

A transformers.modeling_outputs.MoeCausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (JambaConfig) and inputs.

The JambaForCausalLM forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoTokenizer, JambaForCausalLM

>>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
JambaForSequenceClassification class transformers.JambaForSequenceClassification < source >

( config )

Parameters

The Jamba Model with a sequence classification head on top (linear layer).

JambaForSequenceClassification uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do.

Since it does classification on the last token, it needs to know the position of the last token. If a pad_token_id is defined in the configuration, it finds the last token that is not a padding token in each row. If no pad_token_id is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when inputs_embeds are passed instead of input_ids, it does the same (takes the last value in each row of the batch).

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
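
A hedged sketch of using this class (not from the original page): the released ai21labs/Jamba-v0.1 checkpoint has no fine-tuned classification head, so the head below is newly initialized and num_labels is an illustrative assumption; it would need fine-tuning before its predictions are meaningful.

from transformers import AutoTokenizer, JambaForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
# num_labels=2 is an assumption for illustration; the classification head is randomly initialized.
model = JambaForSequenceClassification.from_pretrained(
    "ai21labs/Jamba-v0.1", num_labels=2, torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = tokenizer("Jamba handles very long contexts.", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits  # classification is based on the last non-padding token
print(logits.shape)  # (batch_size, num_labels)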

forward < source >

( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None )

Parameters

The JambaForSequenceClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
