Spaces:
Paused
Paused
lengyue233
commited on
Commit
•
cb728f4
1
Parent(s):
d69caf0
Update tools/llama/generate.py
Browse files- tools/llama/generate.py +1 -15
tools/llama/generate.py
CHANGED
@@ -154,16 +154,11 @@ def decode_one_token_ar_agent(
|
|
154 |
logits = x.logits # [:, -1:]
|
155 |
hidden_states = x.hidden_states # [:, -1:]
|
156 |
|
157 |
-
sampling_kwargs_main = sampling_kwargs.copy()
|
158 |
-
sampling_kwargs_main["temperature"] = 0.1
|
159 |
-
sampling_kwargs_main["top_p"] = 0.1
|
160 |
-
sampling_kwargs_main["repetition_penalty"] = 1.0
|
161 |
-
|
162 |
codebooks = [
|
163 |
sample_agent(
|
164 |
logits,
|
165 |
previous_tokens=None, # Disable repetition penalty for the token codebook
|
166 |
-
**
|
167 |
)[0]
|
168 |
]
|
169 |
|
@@ -194,15 +189,6 @@ def decode_one_token_ar_agent(
|
|
194 |
codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
|
195 |
)
|
196 |
|
197 |
-
# for i in range(codebooks.size(1) - 1):
|
198 |
-
# codebooks[:, i + 1, :] = torch.masked_fill(
|
199 |
-
# codebooks[:, i + 1, :],
|
200 |
-
# codebooks[:, :1, :] != semantic_id,
|
201 |
-
# CODEBOOK_PAD_TOKEN_ID + i * 1024,
|
202 |
-
# )
|
203 |
-
|
204 |
-
# print(codebooks)
|
205 |
-
|
206 |
return codebooks
|
207 |
|
208 |
|
|
|
154 |
logits = x.logits # [:, -1:]
|
155 |
hidden_states = x.hidden_states # [:, -1:]
|
156 |
|
|
|
|
|
|
|
|
|
|
|
157 |
codebooks = [
|
158 |
sample_agent(
|
159 |
logits,
|
160 |
previous_tokens=None, # Disable repetition penalty for the token codebook
|
161 |
+
**sampling_kwargs,
|
162 |
)[0]
|
163 |
]
|
164 |
|
|
|
189 |
codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
|
190 |
)
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
return codebooks
|
193 |
|
194 |
|