|
|
|
|
|
|
|
|
|
from torch import Tensor |
|
from dataclasses import dataclass, field |
|
from typing import Optional |
|
|
|
|
|
|
|
@dataclass
class InferenceParams:
    """Context that the main model carries across decoding steps.

    An instance is passed to the model during inference so that it can
    efficiently calculate and store the context (e.g. cached key/value
    entries) between calls.
    """

    # Upper bound on sequence length; reassigned by ``reset``.
    max_seqlen: int
    # Upper bound on batch size; reassigned by ``reset``.
    max_batch_size: int
    # Current offset along the sequence axis; ``reset`` sets it back to 0.
    seqlen_offset: int = 0
    # Offset along the batch axis. Note: ``reset`` does NOT touch this field.
    batch_size_offset: int = 0
    # Per-instance storage for cached context; a fresh dict per instance.
    # Presumably keyed by layer index -- confirm against callers.
    key_value_memory_dict: dict = field(default_factory=dict)
    # Optional tensor of per-sample lengths; zeroed in place by ``reset``.
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Prepare this object for a new round of inference.

        Args:
            max_seqlen: New value for ``self.max_seqlen``.
            max_batch_size: New value for ``self.max_batch_size``.

        ``seqlen_offset`` is rewound to 0, and ``lengths_per_sample`` (when
        present) is zeroed in place, so the same tensor object is kept.
        ``key_value_memory_dict`` and ``batch_size_offset`` are left as-is.
        """
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        self.seqlen_offset = 0

        per_sample = self.lengths_per_sample
        if per_sample is None:
            return
        # In-place zeroing keeps the existing tensor (and its storage) alive.
        per_sample.zero_()
|
|
|
|
|
@dataclass
class RecurrentInferenceParams:
    """Inference parameters passed to blocks with recurrent mode."""

    # Length of the FIR filter whose cached state lives in ``fir_state_dict``.
    fir_filter_length: int = 3
    # Size of the recurrent state whose cache lives in ``state_dict``.
    state_dim: int = 16
    # Current offset along the sequence; ``reset`` sets it back to 0.
    seqlen_offset: int = 0
    # Per-instance caches (fresh dict per instance). Presumably keyed per
    # layer -- confirm against the blocks that populate them.
    fir_state_dict: dict = field(default_factory=dict)
    state_dict: dict = field(default_factory=dict)

    def reset(self, fir_filter_length: int = 3, state_dim: int = 16):
        """Reset the scalar inference parameters.

        Previously the filter length and state size were hard-coded to 3
        and 16, which silently discarded any non-default configuration.
        They are now parameters whose defaults reproduce the old behavior,
        so ``reset()`` with no arguments is unchanged.

        Args:
            fir_filter_length: Value assigned to ``self.fir_filter_length``.
            state_dim: Value assigned to ``self.state_dim``.

        ``seqlen_offset`` is always set to 0. ``fir_state_dict`` and
        ``state_dict`` are intentionally left untouched.
        """
        self.fir_filter_length = fir_filter_length
        self.state_dim = state_dim
        self.seqlen_offset = 0
|
|