ari9dam committed
Commit 5eb35d4 • 1 Parent(s): bc393d0

gpu 8bit inference

Files changed (2):
  1. app.py +8 -5
  2. requirements.txt +4 -2
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 from threading import Thread
 from typing import Iterator
-
+import spaces
 import gradio as gr
 import torch
 import transformers
@@ -11,8 +11,9 @@ MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
+
 model_id = "microsoft/Orca-2-13b"
-model = transformers.AutoModelForCausalLM.from_pretrained(model_id)
+model = transformers.AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
 
 tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=False)
 
@@ -21,15 +22,17 @@ user_message = "How can you determine if a restaurant is popular among locals or
 
 DESCRIPTION = """
 # Orca-2 13B
-This Space demonstrates model [Orca-2-13B](https://huggingface.co/microsoft/Orca-2-13B) by Microsoft, a Llama 2 derivative with 13B parameters fine-tuned for single-turn instructions. This Space is running on Inference Endpoints using the text-generation-inference library. If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://ui.endpoints.huggingface.co/).
+This Space demonstrates model [Orca-2-13B](https://huggingface.co/microsoft/Orca-2-13B) by Microsoft, a Llama 2 derivative with 13B parameters fine-tuned for single-turn instructions. This Space is <b>running 8-bit inference with greedy decoding</b>.
 
 The system message is set to be the cautious system message:
 You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.
-Feel free to modify it in the additional input section. The demo uses greedy decoding.
+Feel free to modify it in the additional input section.
 
 🔎 For more details about the Orca family of models take a look [at our blog post](https://msft.it/6042iGtzK).
 🔨 Looking for lighter versions of Orca-2? 🐇 Check out the [7B Chat model](https://huggingface.co/spaces/huggingface-projects/Orca-2-7b). Note: Orca 2 is licensed under the [Microsoft Research License](LICENSE). Llama 2 is licensed under the [LLAMA 2 Community License](https://ai.meta.com/llama/license/).
 """
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
 # Function to combine system message and user
 def to_prompt(conversations):
@@ -43,7 +46,7 @@ def to_prompt(conversations):
     inputs = tokenizer(prompt, return_tensors='pt').input_ids
     return inputs
 
-
+@spaces.GPU
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
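Taken together, the app.py changes load the model with bitsandbytes 8-bit quantization and gate generation behind the ZeroGPU decorator. Below is a minimal sketch of how these pieces fit; the explicit `BitsAndBytesConfig`, the `device_map` setting, the reconstructed `to_prompt` body (based on the ChatML-style template documented on the Orca-2 model card), and the simplified `generate` are illustrative assumptions, not the Space's actual code.

```python
import spaces
import torch
import transformers

model_id = "microsoft/Orca-2-13b"

# Explicit form of the commit's load_in_8bit=True shortcut (assumed
# equivalent): weights are quantized to int8 by bitsandbytes on the GPU.
quant_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    device_map="auto",  # requires accelerate; maps layers onto the GPU
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=False)


def to_prompt(conversations: list[dict[str, str]]) -> torch.Tensor:
    # Hypothetical reconstruction of the truncated helper, using the
    # ChatML-style template from the Orca-2 model card.
    prompt = ""
    for turn in conversations:
        prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
    prompt += "<|im_start|>assistant"
    return tokenizer(prompt, return_tensors="pt").input_ids


@spaces.GPU  # on ZeroGPU Spaces, a GPU is attached only while this runs
def generate(message: str, system_message: str, max_new_tokens: int = 256) -> str:
    input_ids = to_prompt(
        [
            {"role": "system", "content": system_message},
            {"role": "user", "content": message},
        ]
    ).to(model.device)
    # Greedy decoding (do_sample=False), matching the Space's description.
    output_ids = model.generate(
        input_ids, max_new_tokens=max_new_tokens, do_sample=False
    )
    return tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
    )
```

At int8 the 13B weights take roughly 13 GB instead of about 26 GB in fp16, which is what lets the demo fit on a single mid-range inference GPU.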
requirements.txt CHANGED
@@ -4,6 +4,7 @@ altair==5.1.2
 annotated-types==0.6.0
 anyio==3.7.1
 attrs==23.1.0
+bitsandbytes==0.41.1
 certifi==2023.11.17
 charset-normalizer==3.3.2
 click==8.1.7
@@ -56,6 +57,7 @@ requests==2.31.0
 rich==13.7.0
 rpds-py==0.13.1
 safetensors==0.4.0
+scipy==1.11.4
 semantic-version==2.10.0
 sentencepiece==0.1.99
 shellingham==1.5.4
@@ -67,7 +69,7 @@ sympy==1.12
 tokenizers==0.13.3
 tomlkit==0.12.0
 toolz==0.12.0
-torch
+torch --index-url https://download.pytorch.org/whl/cu118
 tqdm==4.66.1
 transformers==4.33.1
 triton==2.1.0
@@ -77,4 +79,4 @@ tzdata==2023.3
 urllib3==2.1.0
 uvicorn==0.24.0.post1
 websockets==11.0.3
-zipp==3.17.0
+zipp==3.17.0
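The dependency changes track the code: bitsandbytes provides the int8 kernels behind `load_in_8bit=True`, scipy appears to be pinned because bitsandbytes at this version imports it, and torch is now pulled from the CUDA 11.8 wheel index. A hypothetical sanity check (not part of the commit) for verifying the stack on the target machine:

```python
import torch
import bitsandbytes  # bitsandbytes checks the CUDA setup at import time

# The cu118 torch wheel should report CUDA 11.8 on a GPU machine; on CPU,
# the Space shows its "does not work on CPU" notice instead.
print(torch.__version__, torch.version.cuda)      # e.g. "2.1.0+cu118", "11.8"
print("CUDA available:", torch.cuda.is_available())
print("bitsandbytes:", bitsandbytes.__version__)  # pinned to 0.41.1
```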