codys12 committed
Commit faf8f3f
1 Parent(s): 8824f88
Files changed (1)
  1. app.py +11 -5
app.py CHANGED
@@ -19,8 +19,9 @@ DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = 4096
 
 if torch.cuda.is_available():
-    model_id = "mistralai/Mistral-7B-Instruct-v0.1"
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+    model_id = "codys12/MergeLlama-7b"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
+    model.cuda()
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
@@ -35,11 +36,16 @@ def generate(
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = []
+    current_input = ""
     for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
+        input += user
+        input += assistant
+
+    current_input += message
+
+    device = "cuda:0"
+    inputs_ids = tokenizer(message, return_tensors="pt").to(device)
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to("cuda")
     if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
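The first hunk swaps the Mistral checkpoint for codys12/MergeLlama-7b and replaces Accelerate's device_map="auto" placement with an explicit whole-model move to the GPU. A minimal sketch of the resulting load path, assuming a single-GPU Space:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

if torch.cuda.is_available():
    model_id = "codys12/MergeLlama-7b"
    # float16 weights are materialized on CPU first, then moved to the GPU in one
    # step; the removed device_map="auto" would instead let Accelerate choose
    # placement (and shard across devices) at load time.
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    model.cuda()  # equivalent to model.to("cuda"), i.e. the default CUDA device
    tokenizer = AutoTokenizer.from_pretrained(model_id)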
 
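As committed, the second hunk appears to contain a few slips: the loop accumulates into the built-in name input rather than current_input; the tokenizer output is bound to inputs_ids while the length check below still reads input_ids; only the latest message, not the accumulated history, is tokenized; len() on the returned BatchEncoding counts its keys rather than tokens; and the warning string is missing its f-prefix, so {MAX_INPUT_TOKEN_LENGTH} is never interpolated. A minimal sketch of what the change appears to intend, with those slips corrected (build_inputs is a hypothetical helper name, not part of the commit):

import gradio as gr

MAX_INPUT_TOKEN_LENGTH = 4096  # module-level constant in app.py

def build_inputs(message, chat_history, tokenizer, device="cuda:0"):
    # Flatten the (user, assistant) pairs plus the new message into one string,
    # which is what the committed code seems to be building toward.
    current_input = ""
    for user, assistant in chat_history:
        current_input += user
        current_input += assistant
    current_input += message

    # tokenizer(...) returns a BatchEncoding; take .input_ids (shape [1, seq_len])
    # before moving the tensor to the GPU.
    input_ids = tokenizer(current_input, return_tensors="pt").input_ids.to(device)

    # Trim from the left along the sequence dimension, keeping the newest tokens.
    if input_ids.shape[-1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return input_ids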