---
license: apache-2.0
datasets:
- KDAI-NLP/traffy-fondue-type-only
language:
- th
metrics:
- f1
tags:
- roberta
---

# Traffy Complaint Classification

This model automatically classifies the types of traffic complaints written in Thai, with the aim of reducing the need for manual classification.

### Model Details

- Model Name: KDAI-NLP/wangchanberta-traffy-multi
- Tokenizer: airesearch/wangchanberta-base-att-spm-uncased
- License: Apache License 2.0

### How to Use

```python
# Requires: pip install torch transformers sentencepiece
# (in a notebook, prefix the command with "!")
import json

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Target labels, in the order of the model's output logits
target_list = [
    'ความสะอาด', 'สายไฟ', 'สะพาน', 'ถนน', 'น้ำท่วม', 'ร้องเรียน',
    'ท่อระบายน้ำ', 'ความปลอดภัย', 'คลอง', 'แสงสว่าง', 'ทางเท้า', 'จราจร',
    'กีดขวาง', 'การเดินทาง', 'เสียงรบกวน', 'ต้นไม้', 'สัตว์จรจัด', 'เสนอแนะ',
    'คนจรจัด', 'ห้องน้ำ', 'ป้ายจราจร', 'สอบถาม', 'ป้าย', 'PM2.5'
]

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
model = AutoModelForSequenceClassification.from_pretrained("KDAI-NLP/wangchanberta-traffy-multi")
model.eval()

# Example text to classify. Rough English translation: "Help, the road is
# flooded again and a tree has fallen across it. I can't get home."
text = "ช่วยด้วยครับถนนน้ำท่วมอีกแล้ว ต้นไม้ก็ล้มขวางทาง กลับบ้านไม่ได้"

# Encode the text using the tokenizer
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)

# Get model predictions (logits)
with torch.no_grad():
    logits = model(**inputs).logits

# Apply the sigmoid function to convert logits to independent per-label probabilities
probabilities = torch.sigmoid(logits).squeeze(0).tolist()

# Map probabilities to their labels. A list is used instead of a bare zip
# iterator so the pairs can be iterated more than once.
label_probabilities = list(zip(target_list, probabilities))

# Print labels with probabilities
for label, probability in label_probabilities:
    print(f"{label}: {probability:.4f}")

# Or, as JSON: build a dictionary of labels and probabilities
results_dict = dict(label_probabilities)

# Convert the dictionary to a JSON string (ensure_ascii=False keeps Thai readable)
results_json = json.dumps(results_dict, ensure_ascii=False, indent=4)

# Print the JSON string
print(results_json)
```

## Training Details

The model was fine-tuned from the airesearch/wangchanberta-base-att-spm-uncased base model on traffic complaint data retrieved via API, with stopwords included. This is a multi-label classification task with a total of 24 classes.

## Training Scores

| Model                              | Stopword | Epoch | Training Loss | Validation Loss | F1     | Accuracy |
| ---------------------------------- | -------- | ----- | ------------- | --------------- | ------ | -------- |
| wangchanberta-base-att-spm-uncased | Included | 0     | 0.0322        | 0.034822        | 0.7015 | 0.7569   |
| wangchanberta-base-att-spm-uncased | Included | 2     | 0.0207        | 0.026364        | 0.8405 | 0.7821   |
| wangchanberta-base-att-spm-uncased | Included | 4     | 0.0165        | 0.025142        | 0.8458 | 0.7934   |
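
### Selecting Predicted Labels (sketch)

Because this is a multi-label task, the usage example yields one independent probability per class rather than a single winning class. One common way to turn those scores into a set of predicted labels is to keep every label whose probability clears a threshold. The sketch below illustrates that idea with the standard library only. The helper name `select_labels`, the 0.5 threshold, and the demo logits are illustrative assumptions; this model card does not state a recommended threshold, so it would need tuning on validation data.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function mapping a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def select_labels(logits, labels, threshold=0.5):
    """Return (label, probability) pairs at or above the threshold.

    Hypothetical helper, not part of the model or of transformers.
    The 0.5 default is an assumption, not a documented setting.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    scored = [(label, sigmoid(z)) for label, z in zip(labels, logits)]
    picked = [(label, p) for label, p in scored if p >= threshold]
    # Highest-probability labels first
    return sorted(picked, key=lambda pair: pair[1], reverse=True)


# Made-up logits for four of the 24 classes, for illustration only
demo_labels = ["ถนน", "น้ำท่วม", "ต้นไม้", "PM2.5"]
demo_logits = [2.0, 3.0, 0.5, -4.0]

predicted = select_labels(demo_logits, demo_labels)
for label, probability in predicted:
    print(f"{label}: {probability:.4f}")
```

With real model output, the `logits` tensor from the usage example could be passed in as `logits.squeeze(0).tolist()` together with `target_list`. Raising the threshold trades recall for precision, and a per-class threshold may work better than a single global one.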