binxu commited on
Commit
27e2c26
1 Parent(s): a1e2e84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -7,9 +7,14 @@ from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
7
  # https://huggingface.co/docs/hub/spaces-sdks-gradio
8
  # model = GPT2LMHeadModel.from_pretrained("binxu/Ziyue-GPT2-deep")
9
  # generator = pipeline('text-generation', model=model, tokenizer='bert-base-chinese')
 
 
 
 
 
10
  tokenizer = T5Tokenizer.from_pretrained("Langboat/mengzi-t5-base")
11
  model = T5ForConditionalGeneration.from_pretrained("binxu/mengzi-t5-base-finetuned-punctuation")
12
- text2text_generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
13
 
14
  def generate(prompt, ):
15
  torch.manual_seed(42)
 
7
  # https://huggingface.co/docs/hub/spaces-sdks-gradio
8
  # model = GPT2LMHeadModel.from_pretrained("binxu/Ziyue-GPT2-deep")
9
  # generator = pipeline('text-generation', model=model, tokenizer='bert-base-chinese')
10
+ if torch.cuda.is_available():
11
+ device = 0
12
+ else:
13
+ device = -1
14
+
15
  tokenizer = T5Tokenizer.from_pretrained("Langboat/mengzi-t5-base")
16
  model = T5ForConditionalGeneration.from_pretrained("binxu/mengzi-t5-base-finetuned-punctuation")
17
+ text2text_generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=device)
18
 
19
  def generate(prompt, ):
20
  torch.manual_seed(42)