Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -122,17 +122,14 @@ def compute_embeddings(selected_task, input_text):
|
|
122 |
print(f"Selected task not found: {selected_task}")
|
123 |
return f"Error: Task '{selected_task}' not found. Please select a valid task."
|
124 |
|
125 |
-
|
126 |
-
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
134 |
-
embeddings = F.normalize(embeddings, p=2, dim=1)
|
135 |
-
embeddings_list = embeddings.detach().cpu().numpy().tolist()
|
136 |
clear_cuda_cache()
|
137 |
return embeddings_list
|
138 |
|
|
|
122 |
print(f"Selected task not found: {selected_task}")
|
123 |
return f"Error: Task '{selected_task}' not found. Please select a valid task."
|
124 |
|
125 |
+
query_prefix = f"Instruct: {task_description}\nQuery: "
|
126 |
+
queries = [input_text]
|
127 |
|
128 |
+
# Get the embeddings
|
129 |
+
query_embeddings = model.encode(queries, instruction=query_prefix, max_length=4096)
|
130 |
+
# Normalize embeddings
|
131 |
+
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
132 |
+
embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
|
|
|
|
|
|
|
133 |
clear_cuda_cache()
|
134 |
return embeddings_list
|
135 |
|