BusinessDev committed
Commit
0f18d6d
1 Parent(s): da16864

Update train.py

Files changed (1)
  1. train.py +37 -36
train.py CHANGED
@@ -2,11 +2,6 @@ from transformers import MBartForSequenceClassification, MBart50Tokenizer, Train
 from datasets import Dataset
 
 
-# Load the model and tokenizer
-model_name = "LocalDoc/mbart_large_qa_azerbaijan"  # Replace with your model name if different
-tokenizer = MBart50Tokenizer.from_pretrained(model_name)
-model = MBartForSequenceClassification.from_pretrained(model_name)
-chunk_size = 512
 
 # Prepare the dataset (simplified)
 def prepare_text_dataset(data):
@@ -29,34 +24,40 @@ def prepare_text_dataset(data):
 
     return formatted_dataset
 
-
-# Load the plain text (replace with your actual loading logic)
-with open("constitution.txt", "r", encoding="utf-8") as f:
-    constitution_text = f.read()
-
-# Prepare the dataset
-train_dataset = prepare_text_dataset(constitution_text)
-
-# Define training arguments
-training_args = TrainingArguments(
-    output_dir="./results",  # Adjust output directory
-    overwrite_output_dir=True,
-    num_train_epochs=3,  # Adjust training epochs
-    per_device_train_batch_size=1,  # Adjust batch size based on your GPU memory
-    save_steps=500,
-    save_total_limit=2,
-)
-
-# Create the Trainer
-trainer = Trainer(
-    model=model,
-    args=training_args,
-    train_dataset=train_dataset,
-)
-
-# Start training
-trainer.train()
-
-# Save the fine-tuned model
-model.save_pretrained("./fine-tuned_model")
-tokenizer.save_pretrained("./fine-tuned_model")
+def init():
+    # Load the model and tokenizer
+    model_name = "LocalDoc/mbart_large_qa_azerbaijan"  # Replace with your model name if different
+    tokenizer = MBart50Tokenizer.from_pretrained(model_name)
+    model = MBartForSequenceClassification.from_pretrained(model_name)
+    chunk_size = 512
+
+    # Load the plain text (replace with your actual loading logic)
+    with open("constitution.txt", "r", encoding="utf-8") as f:
+        constitution_text = f.read()
+
+    # Prepare the dataset
+    train_dataset = prepare_text_dataset(constitution_text)
+
+    # Define training arguments
+    training_args = TrainingArguments(
+        output_dir="./results",  # Adjust output directory
+        overwrite_output_dir=True,
+        num_train_epochs=3,  # Adjust training epochs
+        per_device_train_batch_size=1,  # Adjust batch size based on your GPU memory
+        save_steps=500,
+        save_total_limit=2,
+    )
+
+    # Create the Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+    )
+
+    # Start training
+    trainer.train()
+
+    # Save the fine-tuned model
+    model.save_pretrained("./fine-tuned_model")
+    tokenizer.save_pretrained("./fine-tuned_model")
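Note (not part of the commit): after this change, train.py only defines init() and never calls it, so executing the script no longer starts training, and the model, tokenizer, and trainer remain local to the function. A minimal sketch of a caller is shown below; the __main__ guard is an assumption for illustration, not something this diff adds.

if __name__ == "__main__":  # hypothetical entry point, not in the commit
    init()  # runs the full load / train / save pipeline defined above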