ari9dam commited on
Commit
4b58010
1 Parent(s): 03165f4

device map

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -13,7 +13,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
 
14
 
15
  model_id = "microsoft/Orca-2-13b"
16
- model = transformers.AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
17
 
18
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=False)
19
 
 
13
 
14
 
15
  model_id = "microsoft/Orca-2-13b"
16
+ model = transformers.AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
17
 
18
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=False)
19