Tonic commited on
Commit
27b96e1
1 Parent(s): 6bb5d1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -122,17 +122,14 @@ def compute_embeddings(selected_task, input_text):
122
  print(f"Selected task not found: {selected_task}")
123
  return f"Error: Task '{selected_task}' not found. Please select a valid task."
124
 
125
- max_length = 2048
126
- processed_texts = [f'Instruct: {task_description}\nQuery: {input_text}']
127
 
128
- batch_dict = tokenizer(processed_texts, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
129
- batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
130
- batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
131
- batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
132
- outputs = model(**batch_dict)
133
- embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
134
- embeddings = F.normalize(embeddings, p=2, dim=1)
135
- embeddings_list = embeddings.detach().cpu().numpy().tolist()
136
  clear_cuda_cache()
137
  return embeddings_list
138
 
 
122
  print(f"Selected task not found: {selected_task}")
123
  return f"Error: Task '{selected_task}' not found. Please select a valid task."
124
 
125
+ query_prefix = f"Instruct: {task_description}\nQuery: "
126
+ queries = [input_text]
127
 
128
+ # Get the embeddings
129
+ query_embeddings = model.encode(queries, instruction=query_prefix, max_length=4096)
130
+ # Normalize embeddings
131
+ query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
132
+ embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
 
 
 
133
  clear_cuda_cache()
134
  return embeddings_list
135