ValueError: Shapes (1,10,13,256) and (1,10,13,128) cannot be broadcast.

#2
by depasquale - opened
MLX Community org

I appreciate the work that people are doing to make models available for MLX, but I've found that many of the models uploaded to mlx-community don't work. For example, this is what I get when I run this one:

from mlx_lm import load, generate

tokenizer_config = {
    'eos_token': "<|end|>"
}

model_id = "mlx-community/Phi-3-medium-4k-instruct-4bit"
user_message = "Name a color."
model, tokenizer = load(model_id, tokenizer_config=tokenizer_config)
prompt = f"<s><|user|>\n{user_message}<|end|>\n<|assistant|>\n"
response = generate(model, tokenizer, prompt=prompt, temp=0.5, max_tokens=1000, verbose=True)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 12
     10 model, tokenizer = load(model_id, tokenizer_config=tokenizer_config)
     11 prompt = f"<s><|user|>\n{user_message}<|end|>\n<|assistant|>\n"
---> 12 response = generate(model, tokenizer, prompt=prompt, temp=0.5, max_tokens=1000, verbose=True)

File ~/files/projects/python/test-env/.venv/lib/python3.12/site-packages/mlx_lm/utils.py:247, in generate(model, tokenizer, prompt, temp, max_tokens, verbose, formatter, repetition_penalty, repetition_context_size, top_p, logit_bias)
    244 tic = time.perf_counter()
    245 detokenizer.reset()
--> 247 for (token, prob), n in zip(
    248     generate_step(
    249         prompt_tokens,
    250         model,
    251         temp,
    252         repetition_penalty,
    253         repetition_context_size,
    254         top_p,
    255         logit_bias,
    256     ),
    257     range(max_tokens),
    258 ):
    259     if n == 0:
    260         prompt_time = time.perf_counter() - tic

File ~/files/projects/python/test-env/.venv/lib/python3.12/site-packages/mlx_lm/utils.py:195, in generate_step(prompt, model, temp, repetition_penalty, repetition_context_size, top_p, logit_bias)
    192             repetition_context = repetition_context[-repetition_context_size:]
    193     return y, prob
--> 195 y, p = _step(y)
    197 mx.async_eval(y)
    198 while True:

File ~/files/projects/python/test-env/.venv/lib/python3.12/site-packages/mlx_lm/utils.py:178, in generate_step.<locals>._step(y)
    176 def _step(y):
    177     nonlocal repetition_context
--> 178     logits = model(y[None], cache=cache)
    179     logits = logits[:, -1, :]
    181     if repetition_penalty:

File ~/files/projects/python/test-env/.venv/lib/python3.12/site-packages/mlx_lm/models/phi3.py:183, in Model.__call__(self, inputs, cache)
    178 def __call__(
    179     self,
    180     inputs: mx.array,
    181     cache=None,
    182 ):
--> 183     out = self.model(inputs, cache)
    184     return self.lm_head(out)

File ~/files/projects/python/test-env/.venv/lib/python3.12/site-packages/mlx_lm/models/phi3.py:165, in Phi3Model.__call__(self, inputs, cache)
    162     cache = [None] * len(self.layers)
    164 for layer, c in zip(self.layers, cache):
--> 165     h = layer(h, mask, c)
    167 return self.norm(h)

File ~/files/projects/python/test-env/.venv/lib/python3.12/site-packages/mlx_lm/models/phi3.py:129, in TransformerBlock.__call__(self, x, mask, cache)
    123 def __call__(
    124     self,
    125     x: mx.array,
    126     mask: Optional[mx.array] = None,
    127     cache: Optional[Tuple[mx.array, mx.array]] = None,
    128 ) -> mx.array:
--> 129     r = self.self_attn(self.input_layernorm(x), mask, cache)
    130     h = x + r
    131     r = self.mlp(self.post_attention_layernorm(h))

File ~/files/projects/python/test-env/.venv/lib/python3.12/site-packages/mlx_lm/models/phi3.py:86, in Attention.__call__(self, x, mask, cache)
     84     queries = self.rope(queries, offset=cache.offset)
     85     keys = self.rope(keys, offset=cache.offset)
---> 86     keys, values = cache.update_and_fetch(keys, values)
     87 else:
     88     queries = self.rope(queries)

File ~/files/projects/python/test-env/.venv/lib/python3.12/site-packages/mlx_lm/models/base.py:34, in KVCache.update_and_fetch(self, keys, values)
     31         self.keys, self.values = new_k, new_v
     33 self.offset += keys.shape[2]
---> 34 self.keys[..., prev : self.offset, :] = keys
     35 self.values[..., prev : self.offset, :] = values
     36 return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

ValueError: Shapes (1,10,13,256) and (1,10,13,128) cannot be broadcast.
MLX Community org

Great, thanks!

depasquale changed discussion status to closed
MLX Community org

Most welcome!

Sign up or log in to comment