laurentiubp committed
Commit de77006
1 Parent(s): 10c1dee

Update app.py

Files changed (1)
  1. app.py +81 -10
app.py CHANGED
@@ -1,14 +1,82 @@
-import spaces
+import os
+from threading import Thread
+from typing import Iterator
+
 import gradio as gr
-from huggingface_hub import InferenceClient
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+
+DESCRIPTION = """\
+# CataLlama-v0.1 Chat
+This Space demonstrates model [CataLlama-v0.1-Instruct-DPO](https://huggingface.co/catallama/CataLlama-v0.1-Instruct-DPO).

+CataLlama is a fine-tune of Llama-3 8B to enhance its proficiency in the Catalan language.
 """
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+
+if torch.cuda.is_available():
+    model_id = "catallama/CataLlama-v0.1-Instruct-DPO"
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
+
 """
 client = InferenceClient("catallama/CataLlama-v0.1-Instruct-DPO")

-
-@spaces.GPU(duration=120)
 def respond(
     message,
     history: list[tuple[str, str]],
@@ -40,12 +108,11 @@ def respond(

         response += token
         yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
-demo = gr.ChatInterface(
-    respond,
+
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
     additional_inputs=[
         gr.Textbox(value="Ets un chatbot amigable.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
@@ -60,6 +127,10 @@ demo = gr.ChatInterface(
     ],
 )

+with gr.Blocks() as demo:
+    gr.Markdown(DESCRIPTION)
+    chat_interface.render()
+

 if __name__ == "__main__":
     demo.launch()
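
The core change in this commit is moving from the hosted InferenceClient to local generation with transformers, streamed token by token: model.generate runs on a background thread while a TextIteratorStreamer hands decoded text back to the Gradio callback as it is produced. Below is a minimal sketch of that same Thread + TextIteratorStreamer pattern in isolation, not the code of the Space itself; the tiny stand-in model id (sshleifer/tiny-gpt2) is only an assumption chosen so the snippet runs quickly on CPU.

# Minimal sketch of the Thread + TextIteratorStreamer streaming pattern used by generate().
# The model id below is a lightweight stand-in for illustration, not the model of this Space.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"  # assumption: any small causal LM works for the demo
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Hola, com estàs?", return_tensors="pt").input_ids

# skip_prompt=True keeps the echoed prompt out of the stream; the timeout guards
# against blocking forever if generation stalls.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# model.generate blocks, so it runs on a worker thread while the main thread
# consumes decoded chunks from the streamer as they become available.
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=32, do_sample=False),
)
thread.start()

pieces = []
for text in streamer:           # yields decoded text chunks as they are generated
    pieces.append(text)
    print("".join(pieces))      # in the Space, this growing string is yielded to gr.ChatInterface
thread.join()

On the UI side, gr.ChatInterface(fn=generate, additional_inputs=[...]) passes each widget's value as an extra positional argument after message and history, which is how the system-message textbox and the sliders in the diff above map onto generate()'s keyword parameters.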