goendalf666 commited on
Commit
9f0d82e
1 Parent(s): 0bf64ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -3
app.py CHANGED
@@ -1,7 +1,43 @@
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
+ # Initialize the model and tokenizer
6
+ cuda = "cuda:0" if torch.cuda.is_available() else "cpu"
7
+ model = AutoModelForCausalLM.from_pretrained("goendalf666/salesGPT_v2").to(cuda)
8
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
9
 
10
+ def interact_with_model(user_input):
11
+ # Construct conversation text for the model
12
+ conversation_text = (
13
+ "You are in the role of a Salesman. "
14
+ "Here is a conversation: "
15
+ f"Customer: {user_input} Salesman: "
16
+ )
17
+
18
+ # Tokenize inputs
19
+ inputs = tokenizer(conversation_text, return_tensors="pt").to(cuda)
20
+
21
+ # Generate response
22
+ outputs = model.generate(**inputs, max_length=512)
23
+ response_text = tokenizer.batch_decode(outputs)[0]
24
+
25
+ # Extract only the newly generated text
26
+ new_text_start = len(conversation_text)
27
+ new_generated_text = response_text[new_text_start:].strip()
28
+
29
+ # Find where the next "Customer:" is, and truncate the text there
30
+ end_index = new_generated_text.find("Customer:")
31
+ if end_index != -1:
32
+ new_generated_text = new_generated_text[:end_index].strip()
33
+
34
+ # Ignore if the model puts "Salesman: " itself at the beginning
35
+ if new_generated_text.startswith("Salesman:"):
36
+ new_generated_text = new_generated_text[len("Salesman:"):].strip()
37
+
38
+ # Return the model's response
39
+ return new_generated_text
40
+
41
+ # Create Gradio Interface and launch it
42
+ iface = gr.Interface(fn=interact_with_model, inputs="text", outputs="text")
43
  iface.launch()