Spaces:
Sleeping
Sleeping
FantasticGNU
commited on
Commit
•
94753a2
1
Parent(s):
f0a1818
Update model/openllama.py
Browse files- model/openllama.py +2 -2
model/openllama.py
CHANGED
@@ -215,7 +215,7 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
215 |
# # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
|
216 |
# # self.llama_model.to(torch.float16)
|
217 |
# # try:
|
218 |
-
self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.
|
219 |
# # except:
|
220 |
# pass
|
221 |
# finally:
|
@@ -223,7 +223,7 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
223 |
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
224 |
self.llama_model.print_trainable_parameters()
|
225 |
|
226 |
-
self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.
|
227 |
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
|
228 |
self.llama_tokenizer.padding_side = "right"
|
229 |
print ('Language decoder initialized.')
|
|
|
215 |
# # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
|
216 |
# # self.llama_model.to(torch.float16)
|
217 |
# # try:
|
218 |
+
self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.bfloat16, device_map=device_map, offload_folder="offload", offload_state_dict = True)
|
219 |
# # except:
|
220 |
# pass
|
221 |
# finally:
|
|
|
223 |
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
224 |
self.llama_model.print_trainable_parameters()
|
225 |
|
226 |
+
self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload", offload_state_dict = True)
|
227 |
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
|
228 |
self.llama_tokenizer.padding_side = "right"
|
229 |
print ('Language decoder initialized.')
|