---
library_name: transformers
datasets:
- erfanzar/MoD-Prompts
- erfanzar/GPT-4-Prompts
language:
- en
pipeline_tag: text-generation
---

# Raven Fine-Tuned Gemma-2B

Raven is a fine-tuned version of google/gemma-2b that uses the same prompting style as gemma-2b-it. It was trained on a TPU VM v4-64 with [EasyDeL](https://github.com/erfanzar/EasyDeL); both the fine-tuning and serving code are available. Using the JAX/EasyDeL Gemma implementation is recommended, since the HF-Gemma implementation is incorrect.

### Serving and Using Raven

```python
from typing import List, Union

import jax
from EasyDel import JAXServer, JAXServerConfig
from EasyDel.serve.prompters import (
    ChatMLPrompter,
    GemmaPrompter,
    Llama2Prompter,
    OpenChatPrompter,
)
from EasyDel.serve.prompters.base_prompter import BasePrompter
from fjformer import get_dtype

max_sequence_length = 8192
max_compile_tokens = 256
max_new_tokens_ratio = 25
dtype = "fp16"
prompter_type = "gemma"
sharding_axis_dims = (1, 1, 1, -1)  # (dp, fsdp, tp, sp); -1 takes the remaining devices
pretrained_model_name_or_path = "erfanzar/Raven-v0.1"
attn_mechanism = "normal"
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True
scan_ring_attention = True
block_k = 128
block_q = 128
use_sharded_kv_caching = False

server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,
    dtype=dtype,
    pre_compile=False,
    eos_token_id=107  # Gemma's <end_of_turn> token
)

prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "chatml": ChatMLPrompter()
}

prompter: BasePrompter = prompters[prompter_type]


class JAXServerC(JAXServer):
    """JAXServer subclass that routes prompt formatting through the selected prompter."""

    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )


# Load the PyTorch checkpoint from the Hub and convert it to JAX parameters.
server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=scan_mlp_chunk_size,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)

# Simple terminal chat loop with streamed output.
history = []
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(
        history,
        user_prompt,
        "You are an AI assistant. Be respectful and explain detailed questions step by step."
    )
    past_response_length = 0
    # server.sample yields the cumulative response so far, so print only the new suffix.
    for response, used_tokens in server.sample(
            model_prompt,
            greedy=False
    ):
        print(response[past_response_length:], end="")
        past_response_length = len(response)
    print()
    history.append([user_prompt, response])
```

A Gradio UI is also available via `server.gradio_inference().launch()`.
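If you prefer the web UI to the terminal loop above, replace the `while True:` block with the Gradio launcher. This is only a sketch building on the `server` object from the script above; it assumes `gradio_inference()` returns a regular Gradio app, so the standard Gradio `launch()` options (`server_name`, `share`) shown below apply:

```python
# Sketch: serve Raven behind EasyDeL's built-in Gradio chat UI.
# Assumes gradio_inference() returns a standard Gradio app object.
app = server.gradio_inference()
app.launch(
    server_name="0.0.0.0",  # listen on all interfaces (standard Gradio option)
    share=False             # set True for a temporary public Gradio link
)
```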
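The `format_instruct` hook defined on `JAXServerC` (but unused in the chat loop) covers single-turn, history-free use. A minimal sketch, again building on the `server` object above; the system and instruction strings are only illustrations:

```python
# One-shot instruction-following call, no chat history.
instruct_prompt = server.format_instruct(
    system="You are an AI assistant. Be respectful and explain detailed questions step by step.",
    instruction="Explain what sharding_axis_dims controls, in two sentences."
)

final_response = ""
for response, used_tokens in server.sample(instruct_prompt, greedy=True):
    final_response = response  # server.sample streams the cumulative text
print(final_response)
```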
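### Prompt Format

Raven follows the gemma-2b-it chat template, which is why `eos_token_id` is set to 107 (Gemma's `<end_of_turn>` token) in the server config. The sketch below shows the prompt string `GemmaPrompter` presumably produces, assuming it follows the standard Gemma template; since Gemma has no dedicated system role, prepending the system message to the first user turn is an assumption here, not confirmed EasyDeL behavior:

```python
def format_gemma_chat(history, prompt, system=None):
    """Hypothetical re-implementation of the gemma-2b-it chat template.

    Assumption: GemmaPrompter emits the standard Gemma format; how it
    injects the system message may differ from what is shown here.
    """
    turns = []
    for i, (user_turn, model_turn) in enumerate(history):
        if i == 0 and system:
            user_turn = f"{system}\n{user_turn}"
        turns.append(f"<start_of_turn>user\n{user_turn}<end_of_turn>")
        turns.append(f"<start_of_turn>model\n{model_turn}<end_of_turn>")
    if not history and system:
        prompt = f"{system}\n{prompt}"
    turns.append(f"<start_of_turn>user\n{prompt}<end_of_turn>")
    turns.append("<start_of_turn>model\n")  # generation continues from here
    return "\n".join(turns)


print(format_gemma_chat([], "What is EasyDeL?", system="You are a helpful assistant."))
```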