temporary0-0name committed on
Commit
fa9c2df
1 Parent(s): f9eafdb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from gpt_class import GPTConfig, GPT
5
+ import tiktoken
6
+
7
# Select the compute device: CUDA when PyTorch reports it available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the checkpoint and rebuild the model from it.
# The checkpoint holds a 'config' entry next to the 'model' weights
# (presumably a pickled GPTConfig instance -- TODO confirm), so it needs full
# unpickling. PyTorch >= 2.6 defaults to weights_only=True, which rejects
# arbitrary pickled objects; request the full load explicitly.
# SECURITY: full unpickling can execute arbitrary code. Only load checkpoint
# files that come from a trusted source.
state_dict = torch.load(
    'log/model_51999.pt',
    map_location=device,
    weights_only=False,
)
config = state_dict['config']
model = GPT(config)
model.load_state_dict(state_dict['model'])
model.to(device)
model.eval()  # inference only

# Seed the global CPU and CUDA random number generators.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

# GPT-2 byte-pair-encoding tokenizer from tiktoken.
tokenizer = tiktoken.get_encoding("gpt2")
24
+
25
def _parse_bounded_int(raw, default, low, high):
    """Return *raw* as an int clamped to [low, high], or *default* if unparseable."""
    try:
        value = int(raw.strip()) if isinstance(raw, str) else int(raw)
    except (TypeError, ValueError, OverflowError):
        return default
    return max(low, min(high, value))


def generate_text(example, num_return_sequences='4', max_length='64'):
    """Sample continuations of a prompt with top-k (k=50) sampling.

    Args:
        example: Prompt text to continue.
        num_return_sequences: Number of continuations to sample, as a string
            or number. Clamped to the range 1-4 advertised in the UI; falls
            back to 4 when it cannot be parsed as an integer.
        max_length: Total length of each returned sequence in tokens (prompt
            included), as a string or number. Clamped to the range 32-128
            advertised in the UI; falls back to 64 when it cannot be parsed.

    Returns:
        The decoded sequences joined by blank lines. Each sequence is cut to
        at most ``max_length`` tokens, so a longer prompt is returned
        truncated. An empty prompt yields an empty string.
    """
    # Enforce the documented limits so a single request cannot ask for an
    # arbitrarily large batch or sequence length.
    num_return_sequences = _parse_bounded_int(num_return_sequences, 4, 1, 4)
    max_length = _parse_bounded_int(max_length, 64, 32, 128)

    # disallowed_special=() makes tiktoken encode text such as "<|endoftext|>"
    # as ordinary characters instead of raising on user input.
    prompt_tokens = tokenizer.encode(example or "", disallowed_special=())
    if not prompt_tokens:
        # With no tokens there is no last position to read logits from.
        return ""

    model.eval()
    xgen = (
        torch.tensor(prompt_tokens, dtype=torch.long)
        .unsqueeze(0)
        .repeat(num_return_sequences, 1)
        .to(device)
    )
    # NOTE(review): the generator is never seeded explicitly -- confirm
    # whether identical output for repeated identical requests is intended.
    sample_rng = torch.Generator(device=device)

    with torch.no_grad():
        while xgen.size(1) < max_length:
            # Only the forward pass runs under autocast; the sampling math
            # below runs outside it.
            with torch.autocast(device_type=device):
                logits, _ = model(xgen)  # Assumes model returns logits and optional loss
            logits = logits[:, -1, :]  # logits at the last position
            probs = F.softmax(logits, dim=-1)
            # Keep the 50 most likely tokens and sample among them only.
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
            # Map the sampled top-k positions back to vocabulary ids.
            xcol = torch.gather(topk_indices, -1, ix)
            xgen = torch.cat((xgen, xcol), dim=1)

    return "\n\n".join(
        tokenizer.decode(row[:max_length].tolist()) for row in xgen
    )
53
+
54
# Gradio UI: three text inputs are passed to generate_text, and its return
# value is shown in a single text output.
_input_boxes = [
    gr.components.Textbox(label=caption)
    for caption in (
        "Prompt",
        "Number of Sequences [1-4]",
        "Maximum Length [32-128]",
    )
]

# Each example row is (prompt, number of sequences, maximum length).
_example_rows = [
    [prompt, "2", "64"]
    for prompt in (
        "It is raining and my family",
        "We entered into the forest and",
        "I sat for doing my homework",
    )
]

iface = gr.Interface(
    fn=generate_text,
    inputs=_input_boxes,
    outputs=gr.components.Textbox(label="Generated Text"),
    title="Text Generator",
    description="Enter a prompt to generate text using a GPT model. Adjust the number of sequences and the maximum length as needed.",
    examples=_example_rows,
)

# Start serving the interface.
iface.launch(share=True)