mcysqrd commited on
Commit
b171526
1 Parent(s): 0901a59

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +29 -13
README.md CHANGED
@@ -19,17 +19,33 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
19
 
20
  device = "cuda" # the device to load the model onto
21
 
22
- model = AutoModelForCausalLM.from_pretrained("mcysqrd/MODULARMOJO_Mistral_V1")
23
- tokenizer = AutoTokenizer.from_pretrained("mcysqrd/MODULARMOJO_Mistral_V1")
24
-
25
- message = "What can you tell me about mojo roadmap for Scoping and mutability of statement variables?"
26
-
27
- encodeds = tokenizer.apply_chat_template(message, return_tensors="pt")
28
-
29
- model_inputs = encodeds.to(device)
30
- model.to(device)
31
-
32
- generated_ids = model.generate(model_inputs, max_new_tokens=1650, do_sample=True, temperature = 0.01)
33
- decoded = tokenizer.batch_decode(generated_ids)
34
- print(decoded[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ```
 
19
 
20
  device = "cuda" # the device to load the model onto
21
 
22
+ import torch
+ from transformers import TextStreamer
+ model_name = "mcysqrd/MODULARMOJO_Mistral_V1"
23
+ model = AutoModelForCausalLM.from_pretrained(model_name,
24
+ use_flash_attention_2=True,
25
+ max_memory={0: "24GB"},
26
+ device_map="auto",
27
+ trust_remote_code=True,
28
+ low_cpu_mem_usage=True,
29
+ return_dict=True,
30
+ torch_dtype=torch.bfloat16,
31
+ )
32
+
33
+ eval_prompt = """ what can you tell me about MODULAR_MOJO mojo_roadmap Scoping and mutability of statement variables ? """
34
+
35
+ tokenizer = AutoTokenizer.from_pretrained(model_name,add_bos_token=True,trust_remote_code=True)
36
+
37
+ model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")
38
+
39
+ model.config.use_cache = True
40
+
41
+ def stream(user_prompt):
42
+ runtimeFlag = "cuda:0"
43
+ system_prompt = 'The following is an excerpt from MODULAR_MOJO from the section on roadmap.'
44
+ B_INST, E_INST = "[INST]", "[/INST]"
45
+ prompt = f"{system_prompt}{B_INST}{user_prompt.strip()}\n{E_INST}"
46
+ inputs = tokenizer([prompt], return_tensors="pt").to(runtimeFlag)
47
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
48
+ _ = model.generate(**inputs, streamer=streamer, max_new_tokens=200)
49
+
50
+ stream("What can you tell me about MODULAR_MOJO mojo_roadmap Scoping and mutability of statement variables?")
51
  ```