import functools
import subprocess
import sys
from contextlib import contextmanager, redirect_stdout
from io import StringIO
from pathlib import Path
from unittest import mock
from unittest.mock import Mock, call, ANY

import torch

# Root of the repository (one level above this test file); generate.py lives there.
wd = Path(__file__).parent.parent.absolute()


@functools.lru_cache(maxsize=1)
def load_generate_script():
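    """Import the top-level generate.py script as a module, cached so it only happens once."""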
    # Make the repository root importable before pulling in the script.
    sys.path.append(str(wd))

    import generate

    return generate


def test_generate():
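    """generate() must return the prompt followed by every token drawn from torch.multinomial."""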
    generate = load_generate_script()

    from lit_llama.model import LLaMA, LLaMAConfig

    T = 5
    input_idx = torch.randint(10, size=(T,))

    config = LLaMAConfig(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
    model = LLaMA(config)
    max_new_tokens = 20

    # Wrap torch.multinomial so every sampled token is recorded; the expected
    # output sequence can then be reconstructed exactly.
    multinomial_results = []
    original_multinomial = torch.multinomial

    def multinomial(*args, **kwargs):
        out = original_multinomial(*args, **kwargs)
        multinomial_results.append(out)
        return out

    with mock.patch("torch.multinomial", multinomial):
        out = generate.generate(model, input_idx, max_new_tokens, max_seq_length=10, top_k=4)

    assert out.size(0) == T + max_new_tokens
    multinomial_results = torch.hstack(multinomial_results)
    expected = torch.cat((input_idx, multinomial_results))
    assert out.shape == expected.shape
    torch.testing.assert_close(out, expected)


# Note: mock.patch injects its mock as the first positional test argument.
@mock.patch("torch.cuda.is_bf16_supported", return_value=False)
def test_main(bf16_mock, tmp_path, monkeypatch):
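    """Run main() with every heavy collaborator mocked out and verify the wiring."""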
    generate = load_generate_script()

    # Empty placeholder files: loading is mocked below, the paths just need to exist.
    checkpoint_path = tmp_path / "ckpt"
    checkpoint_path.touch()
    tokenizer_path = tmp_path / "tokenizer"
    tokenizer_path.touch()

    class FabricMock(Mock):
        # Everything is absorbed by Mock except the two members that must
        # return real objects: the device and the init_module() context manager.
        @property
        def device(self):
            return torch.device("cpu")

        @contextmanager
        def init_module(self, empty_init):
            yield
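
    # Stub out every collaborator main() touches: Fabric, the model class,
    # checkpoint loading, the tokenizer, and the generate() function itself.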
    monkeypatch.setattr(generate.L, "Fabric", FabricMock)
    model_mock = Mock()
    monkeypatch.setattr(generate.LLaMA, "from_name", model_mock)
    lookup_mock = Mock(return_value="1T")
    monkeypatch.setattr(generate, "llama_model_lookup", lookup_mock)
    # A single mock stands in for both torch.load and lazy_load: calling it
    # returns itself, and the dunder methods let it act as a context manager.
    load_mock = Mock()
    load_mock.return_value = load_mock
    load_mock.__enter__ = Mock()
    load_mock.__exit__ = Mock(return_value=False)  # False so exceptions propagate
    monkeypatch.setattr(generate.torch, "load", load_mock)
    monkeypatch.setattr(generate, "lazy_load", load_mock)
    tokenizer_mock = Mock()
    tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
    tokenizer_mock.return_value.decode.return_value = "foo bar baz"
    monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
    generate_mock = Mock()
    generate_mock.return_value = torch.tensor([[3, 2, 1]])
    monkeypatch.setattr(generate, "generate", generate_mock)

    num_samples = 2
    out = StringIO()
    with redirect_stdout(out):
        generate.main(
            checkpoint_path=checkpoint_path,
            tokenizer_path=tokenizer_path,
            temperature=2.0,
            top_k=2,
            num_samples=num_samples,
        )

    model_mock.assert_called_once_with("1T")
    load_mock.assert_called_once_with(checkpoint_path)
    tokenizer_mock.assert_called_once_with(tokenizer_path)
    # decode() runs once per sample, each time on the mocked generate() output.
    assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
    assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
    # main() forwards its default max_new_tokens (50) plus the sampling options.
    assert generate_mock.mock_calls == [call(ANY, ANY, 50, temperature=2.0, top_k=2)] * num_samples

    assert out.getvalue() == "foo bar baz\n" * num_samples


def test_cli():
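    """The script should be runnable from the command line and print its help text."""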
    cli_path = wd / "generate.py"
    output = subprocess.check_output([sys.executable, cli_path, "-h"])
    output = output.decode()
    assert "Generates text samples" in output
|