msaifee commited on
Commit
bf86fa8
1 Parent(s): bc37ab3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -22
app.py CHANGED
@@ -14,25 +14,29 @@ model_name = "meta-llama/Llama-3.2-1B"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
15
  model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token)
16
 
17
- # Define the inference function
18
- def generate_text(prompt, max_length, temperature):
19
- inputs = tokenizer(prompt, return_tensors="pt")
20
- output = model.generate(inputs['input_ids'], max_length=max_length, temperature=temperature)
21
- return tokenizer.decode(output[0], skip_special_tokens=True)
22
-
23
-
24
- # Create the Gradio interface
25
- iface = gr.Interface(
26
- fn=generate_text,
27
- inputs=[
28
- gr.Textbox(label="Enter your prompt", placeholder="Start typing..."),
29
- gr.Slider(minimum=50, maximum=200, label="Max Length", value=100),
30
- gr.Slider(minimum=0.1, maximum=1.0, label="Temperature", value=0.7),
31
- ],
32
- outputs="text",
33
- title="LLaMA 3.2 Text Generator",
34
- description="Enter a prompt to generate text using the LLaMA 3.2 model.",
35
- )
36
-
37
- # Launch the Gradio app
38
- iface.launch()
 
 
 
 
 
14
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
15
  model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token)
16
 
17
+ pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16, device_map="auto")
18
+
19
+ pipe("How are you doing?")
20
+
21
+ # # Define the inference function
22
+ # def generate_text(prompt, max_length, temperature):
23
+ # inputs = tokenizer(prompt, return_tensors="pt")
24
+ # output = model.generate(inputs['input_ids'], max_length=max_length, temperature=temperature)
25
+ # return tokenizer.decode(output[0], skip_special_tokens=True)
26
+
27
+
28
+ # # Create the Gradio interface
29
+ # iface = gr.Interface(
30
+ # fn=generate_text,
31
+ # inputs=[
32
+ # gr.Textbox(label="Enter your prompt", placeholder="Start typing..."),
33
+ # gr.Slider(minimum=50, maximum=200, label="Max Length", value=100),
34
+ # gr.Slider(minimum=0.1, maximum=1.0, label="Temperature", value=0.7),
35
+ # ],
36
+ # outputs="text",
37
+ # title="LLaMA 3.2 Text Generator",
38
+ # description="Enter a prompt to generate text using the LLaMA 3.2 model.",
39
+ # )
40
+
41
+ # # Launch the Gradio app
42
+ # iface.launch()