from __future__ import print_function

import copy
import os
import pathlib
import typing

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn

str_type_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

class BaichuanModel(nn.Module):
    """Python wrapper around the FasterTransformer BaichuanOp custom op."""

    def __init__(self,
                 head_num,
                 size_per_head,
                 inter_size,
                 vocab_size,
                 rotary_embedding_dim,
                 start_id, end_id, layer_num,
                 max_seq_len: int,
                 layernorm_eps,
                 tensor_para_size: int,
                 pipeline_para_size: int,
                 use_gptj_residual,
                 lib_path: typing.Union[str, pathlib.Path],
                 model_path,
                 memopt_mode: int = 0,
                 inference_data_type: str = "fp16",
                 weights_data_type: typing.Union[str, np.dtype] = np.float32):
        super().__init__()
        self.head_num = head_num
        self.size_per_head = size_per_head
        self.inter_size = inter_size
        self.vocab_size = vocab_size
        self.rotary_embedding_dim = rotary_embedding_dim
        self.start_id = start_id
        self.end_id = end_id
        self.max_seq_len = max_seq_len
        self.layer_num = layer_num
        self.use_gptj_residual = use_gptj_residual
        self.layernorm_eps = layernorm_eps
        self.memopt_mode = memopt_mode

        self.tensor_para_size = tensor_para_size
        self.pipeline_para_size = pipeline_para_size
        self.build_model = False
        self.weights_data_type = weights_data_type
        self.inference_data_type = inference_data_type

        assert torch.cuda.is_available(), "CUDA is required for this model."

        assert head_num % tensor_para_size == 0, "head_num must be a multiple of tensor_para_size."
        assert layer_num % pipeline_para_size == 0, "layer_num must be a multiple of pipeline_para_size."

        # Load the shared library that registers the FasterTransformer custom ops.
        torch.classes.load_library(os.path.abspath(lib_path))

        # Initialize the MPI process group once; if it has already been
        # initialized (e.g. by the launcher), reuse the existing group.
        try:
            dist.init_process_group(backend='mpi')
        except Exception:
            print("[INFO] The process group has already been initialized.")
        self.rank = dist.get_rank()
        self.device_count = torch.cuda.device_count()
        self.device = self.rank % self.device_count
        torch.cuda.set_device(self.device)

        world_size = dist.get_world_size()

        assert world_size == tensor_para_size * pipeline_para_size, "tensor_para_size * pipeline_para_size must be equal to world_size."

        self.tensor_para_rank = self.rank % self.tensor_para_size
        self.pipeline_para_rank = self.rank // self.tensor_para_size

        # Instantiate the custom BaichuanOp that runs decoding on the GPU.
        self.model = torch.classes.FasterTransformer.BaichuanOp(
            self.head_num, self.size_per_head, self.inter_size,
            self.layer_num,
            self.vocab_size,
            self.rotary_embedding_dim,
            self.layernorm_eps,
            self.start_id, self.end_id,
            self.tensor_para_size, self.pipeline_para_size,
            self.max_seq_len,
            self.use_gptj_residual,
            self.memopt_mode,
            model_path,
            self.weights_data_type,
            self.inference_data_type)

        self.build_model = True
        torch.cuda.empty_cache()

    def forward(self,
                start_ids: torch.Tensor,
                start_lengths: torch.Tensor,
                output_len,
                beam_width=1,
                top_k: torch.Tensor = None,
                top_p: torch.Tensor = None,
                beam_search_diversity_rate: torch.Tensor = None,
                temperature: torch.Tensor = None,
                len_penalty: torch.Tensor = None,
                repetition_penalty: torch.Tensor = None,
                random_seed: torch.Tensor = None,
                return_output_length=False,
                return_cum_log_probs=0):
        input_len = start_ids.size(1)
        assert input_len > 0, "input_len must be larger than zero. For an unconditional case, use start_id as the first token."

        # Move the prompt tokens and their lengths to this rank's GPU.
        input_ids = start_ids.cuda(self.device)
        input_lengths = start_lengths.cuda(self.device)

        outputs = self.model.forward(input_ids,
                                     input_lengths,
                                     output_len,
                                     beam_width,
                                     top_k,
                                     top_p,
                                     beam_search_diversity_rate,
                                     temperature,
                                     len_penalty,
                                     repetition_penalty,
                                     random_seed,
                                     return_cum_log_probs)

        if return_cum_log_probs == 0:
            output_ids, output_lengths = outputs
        else:
            output_ids, output_lengths, output_cum_log_probs = outputs

        if return_output_length:
            if return_cum_log_probs > 0:
                return output_ids, output_lengths, output_cum_log_probs
            else:
                return output_ids, output_lengths
        else:
            return output_ids

    def set_input_tensor(self, input_tensor):
        """Set an input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism, the input to this stage comes from
        communication with the previous stage, not from the forward arguments,
        so the model's forward_step_func won't have it. This function is
        therefore used by internal code to bypass the input provided by
        forward_step_func.
        """
        self.input_tensor = input_tensor
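

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of constructing the wrapper and generating tokens on a
# single GPU. The library path, model path, and all hyperparameters below are
# placeholder assumptions and must match your compiled FasterTransformer op
# library and the converted Baichuan checkpoint; token ids would normally come
# from the Baichuan tokenizer.
if __name__ == "__main__":
    model = BaichuanModel(head_num=32,
                          size_per_head=128,
                          inter_size=11008,
                          vocab_size=64000,
                          rotary_embedding_dim=128,
                          start_id=1,
                          end_id=2,
                          layer_num=32,
                          max_seq_len=2048,
                          layernorm_eps=1e-6,
                          tensor_para_size=1,
                          pipeline_para_size=1,
                          use_gptj_residual=False,
                          lib_path="lib/libth_transformer.so",  # placeholder path
                          model_path="models/baichuan/1-gpu",   # placeholder path
                          inference_data_type="fp16")

    # A single already-tokenized prompt (int32 ids), batch size 1.
    start_ids = torch.IntTensor([[1, 100, 200, 300]])
    start_lengths = torch.IntTensor([start_ids.size(1)])

    output_ids = model(start_ids, start_lengths, output_len=32, beam_width=1)
    print(output_ids)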