|
|
|
import subprocess |
|
import yaml |
|
|
|
|
|
import torch.nn as nn |
|
from transformers import AutoConfig, RobertaForMaskedLM, PretrainedConfig |
|
|
|
|
|
class DNikudModel(nn.Module):
    """Hebrew diacritics classifier on top of a frozen RoBERTa encoder.

    The pretrained encoder is frozen; only the BiLSTM stack, the dense
    projection, and the three per-task output heads (nikud, dagesh, sin)
    receive gradients.
    """

    def __init__(self, config, nikud_size, dagesh_size, sin_size, pretrain_model=None, device='cpu'):
        """
        Args:
            config: transformers config providing ``hidden_size``; also used
                to build a randomly initialized encoder when no pretrained
                checkpoint is given.
            nikud_size: number of nikud classes.
            dagesh_size: number of dagesh classes.
            sin_size: number of sin classes.
            pretrain_model: optional name/path of a pretrained RoBERTa MLM
                checkpoint; ``None`` builds the encoder from ``config``.
            device: device the encoder is moved to.
        """
        super().__init__()

        if pretrain_model is not None:
            model_base = RobertaForMaskedLM.from_pretrained(pretrain_model).to(device)
        else:
            model_base = RobertaForMaskedLM(config=config).to(device)

        # Keep only the encoder (drop the MLM head) and freeze all its weights.
        self.model = model_base.roberta
        for param in self.model.parameters():
            param.requires_grad = False

        # NOTE: the LSTM ``dropout`` option only applies *between* stacked
        # layers, so on these single-layer LSTMs the original ``dropout=0.1``
        # had no effect and merely raised a UserWarning; it is omitted here.
        self.lstm1 = nn.LSTM(config.hidden_size, config.hidden_size, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(2 * config.hidden_size, config.hidden_size, bidirectional=True, batch_first=True)
        self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
        self.out_n = nn.Linear(config.hidden_size, nikud_size)
        self.out_d = nn.Linear(config.hidden_size, dagesh_size)
        self.out_s = nn.Linear(config.hidden_size, sin_size)

    def forward(self, input_ids, attention_mask):
        """Return per-token logits as a tuple ``(nikud, dagesh, sin)``.

        Args:
            input_ids: token id tensor passed to the encoder.
            attention_mask: attention mask passed to the encoder.
        """
        last_hidden_state = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
        lstm1, _ = self.lstm1(last_hidden_state)
        lstm2, _ = self.lstm2(lstm1)
        dense = self.dense(lstm2)

        nikud = self.out_n(dense)
        dagesh = self.out_d(dense)
        sin = self.out_s(dense)

        return nikud, dagesh, sin
|
|
|
|
|
def get_git_commit_hash():
    """Return the current Git HEAD commit hash, or ``None`` if unavailable.

    Returns ``None`` both when the working directory is not inside a Git
    repository (``git`` exits non-zero) and when the ``git`` executable is
    not installed at all — the original only handled the former and let
    ``FileNotFoundError`` escape in the latter case.
    """
    try:
        commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
        return commit_hash
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Not inside a Git repository!")
        return None
|
|
|
|
|
class ModelConfig(PretrainedConfig):
    """Serializable configuration for the D-Nikud model.

    Either bootstraps itself from the ``tau/tavbert-he`` base config
    (recording ``max_length`` and the current git commit hash), or restores
    all attributes verbatim from a previously dumped attribute mapping.
    """

    def __init__(self, max_length=None, dict=None):
        """
        Args:
            max_length: maximum sequence length to record on a fresh config.
            dict: previously saved attribute mapping; when given, it is
                applied verbatim and ``max_length`` is ignored.
                (Name kept for backward compatibility with existing callers.)
        """
        super().__init__()
        if dict is not None:
            # Restore every attribute exactly as it was saved.
            self.__dict__.update(dict)
        else:
            # Start from the base tavbert-he config, then record run metadata.
            base_config = AutoConfig.from_pretrained("tau/tavbert-he")
            self.__dict__.update(base_config.__dict__)
            self.max_length = max_length
            self._commit_hash = get_git_commit_hash()

    def print(self):
        """Dump every attribute to stdout."""
        print(self.__dict__)

    def save_to_file(self, file_path):
        """Serialize the full attribute dict to ``file_path`` as YAML."""
        with open(file_path, "w") as yaml_file:
            yaml.dump(self.__dict__, yaml_file, default_flow_style=False)

    @classmethod
    def load_from_file(cls, file_path):
        """Rebuild a config from a YAML file written by ``save_to_file``."""
        with open(file_path, "r") as yaml_file:
            config_dict = yaml.safe_load(yaml_file)
        return cls(dict=config_dict)
|
|