import os import platform from dotenv import load_dotenv import torch from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling from datasets import load_dataset, concatenate_datasets from huggingface_hub import login import time import uvicorn from fastapi import FastAPI import threading import logging import warnings # Ignorar advertencias específicas si lo deseas (opcional) warnings.filterwarnings("ignore", category=FutureWarning) # Configurar logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler("training.log"), logging.StreamHandler() ] ) # Cargar las variables de entorno load_dotenv() huggingface_token = os.getenv('HUGGINGFACE_TOKEN') if huggingface_token is None: raise ValueError("HUGGINGFACE_TOKEN no encontrado en las variables de entorno.") # Iniciar sesión en Hugging Face login(token=huggingface_token) # Definir la aplicación FastAPI app = FastAPI() @app.get("/") async def root(): return {"message": "Modelo entrenado y en ejecución."} def load_and_train(): model_name = 'gpt2' tokenizer = GPT2Tokenizer.from_pretrained(model_name) model = GPT2LMHeadModel.from_pretrained(model_name, return_dict=True) # Asignar el pad_token al eos_token tokenizer.pad_token = tokenizer.eos_token # Redimensionar las embeddings del modelo para incluir el pad_token model.resize_token_embeddings(len(tokenizer)) # Verificar dispositivo device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) logging.info(f"Entrenando en: {device}") # Determinar cache_dir if platform.system() == "Linux": cache_dir = '/dev/shm' else: cache_dir = './cache' # Crear el directorio de caché si no existe os.makedirs(cache_dir, exist_ok=True) # Intentar cargar los datasets con manejo de errores try: dataset_humanizado = load_dataset('daily_dialog', split='train', cache_dir=cache_dir, trust_remote_code=True) dataset_codigo = load_dataset('code_search_net', split='train', cache_dir=cache_dir, trust_remote_code=True) except Exception as e: logging.error(f"Error al cargar los datasets: {e}") # Intentar cargar un dataset alternativo time.sleep(60) # Esperar 60 segundos antes de reintentar try: dataset_humanizado = load_dataset('alternative_dataset', split='train', cache_dir=cache_dir, trust_remote_code=True) dataset_codigo = load_dataset('alternative_code_dataset', split='train', cache_dir=cache_dir, trust_remote_code=True) except Exception as e: logging.error(f"Error al cargar el dataset alternativo: {e}") return logging.info("Daily Dialog columnas: %s", dataset_humanizado.column_names) logging.info("Code Search Net columnas: %s", dataset_codigo.column_names) # Combinar los datasets en memoria combined_dataset = concatenate_datasets([dataset_humanizado, dataset_codigo]) logging.info("Dataset combinado columnas: %s", combined_dataset.column_names) # Función para crear un campo 'text' estandarizado def concatenate_text_fields(examples): """ Crea un nuevo campo 'text' concatenando los campos de texto disponibles en cada ejemplo. Prioriza 'dialog', luego 'whole_func_string', y luego 'func_documentation_string'. Si ninguno está presente, asigna una cadena vacía. Args: examples (dict): Diccionario con listas de valores para cada columna. Returns: dict: Diccionario con el nuevo campo 'text'. """ texts = [] # Determinar el tamaño del lote num_examples = len(next(iter(examples.values()))) # Obtener el tamaño del lote for i in range(num_examples): text = '' # Procesar 'dialog' if 'dialog' in examples and i < len(examples['dialog']) and isinstance(examples['dialog'][i], str) and examples['dialog'][i]: text = examples['dialog'][i] # Procesar 'whole_func_string' elif 'whole_func_string' in examples and i < len(examples['whole_func_string']) and isinstance(examples['whole_func_string'][i], str) and examples['whole_func_string'][i]: text = examples['whole_func_string'][i] # Procesar 'func_documentation_string' elif 'func_documentation_string' in examples and i < len(examples['func_documentation_string']) and isinstance(examples['func_documentation_string'][i], str) and examples['func_documentation_string'][i]: text = examples['func_documentation_string'][i] # Puedes añadir más campos si es necesario texts.append(text) examples['text'] = texts return examples # Crear el campo 'text' combined_dataset = combined_dataset.map(concatenate_text_fields, batched=True) # Función de tokenización basada en el campo 'text' def tokenize_function(examples): tokenized = tokenizer( examples['text'], truncation=True, padding='max_length', max_length=512 ) tokenized['labels'] = tokenized['input_ids'].copy() return tokenized # Tokenizar el dataset tokenized_dataset = combined_dataset.map( tokenize_function, batched=True ) # Configurar el Data Collator data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False # Para modelado de lenguaje causal ) # Configurar argumentos de entrenamiento training_args = TrainingArguments( output_dir=os.path.join(cache_dir, 'results'), # Almacenar temporalmente en RAM per_device_train_batch_size=4, per_device_eval_batch_size=4, num_train_epochs=1, learning_rate=1e-5, logging_steps=100, save_total_limit=1, seed=42, weight_decay=0.01, warmup_ratio=0.1, evaluation_strategy="epoch", lr_scheduler_type="linear", save_strategy="epoch", # Guardar solo al final de cada epoch logging_dir=os.path.join(cache_dir, 'logs'), # Directorio de logs ) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset, data_collator=data_collator, ) while True: try: trainer.train() # Subir el modelo a Hugging Face desde la RAM model.push_to_hub( 'Yhhxhfh/nombre_de_tu_modelo', commit_message="Actualización del modelo", add_to_git_credential=False # Desactivar la configuración automática de credenciales de Git ) tokenizer.push_to_hub( 'Yhhxhfh/nombre_de_tu_modelo', commit_message="Actualización del tokenizador", add_to_git_credential=False # Desactivar la configuración automática de credenciales de Git ) logging.info("Modelo y tokenizador subidos exitosamente.") time.sleep(0) # Esperar 0 segundos antes de la siguiente iteración except Exception as e: logging.error(f"Error durante el entrenamiento: {e}. Reiniciando el proceso de entrenamiento...") time.sleep(0) # Esperar 0 segundos antes de reintentar if __name__ == "__main__": # Correr FastAPI en un hilo separado threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860), daemon=True).start() load_and_train()