import csv
import hashlib
import json
import os
from pathlib import Path

import torch
from fastai.vision.all import load_learner, PILImage

# Directory holding the exported learners and their tag -> synonym mapping CSVs.
MODEL_DIR = Path("./model/cardtagger")
# Directory checked for previously computed results, one JSON file per image MD5.
CACHE_DIR = Path("./tmp/")


def _load_synonyms(mapping_file):
    """Read a tag mapping CSV and return ``{tag: [alternative, ...]}``.

    The CSV needs ``tag`` and ``alternatives`` columns; ``alternatives`` is a
    comma-separated string that is split on ``","`` (no whitespace stripping).
    Rows whose ``alternatives`` value equals the ``tag`` itself are skipped.
    If a tag appears in several rows, the last row wins.
    """
    # newline="" is what the csv module expects for files it parses; the
    # context manager guarantees the handle is closed.
    with open(mapping_file, newline="", encoding="utf-8") as fh:
        return {
            row["tag"]: row["alternatives"].split(",")
            for row in csv.DictReader(fh)
            if row["tag"] != row["alternatives"]
        }


def get_preds(obj, learn, model_name='tags', thresh=15):
    """Turn per-class probabilities into a list of labelled predictions.

    Args:
        obj: Iterable of per-class probabilities in the same order as
            ``learn.dls.vocab``; every element must support ``.item()``
            (e.g. the probability tensor returned by ``Learner.predict``).
        learn: fastai ``Learner`` whose ``dls.vocab`` supplies the class names.
        model_name: ``'life-event'`` selects ``mapping-life-event.csv``;
            any other value selects ``mapping.csv``.
        thresh: A class is returned only if its rounded percentage is
            strictly greater than this value.

    Returns:
        ``{"predictions": [{"label", "probability", "synonyms"}, ...]}`` where
        ``probability`` is the class probability in percent rounded to one
        decimal and ``synonyms`` is the mapping-CSV list for that label
        (``[]`` when the label has no mapping row).

    Raises:
        IndexError: if ``obj`` has more elements than the vocab has classes.
    """
    labels = list(learn.dls.vocab)
    if model_name == 'life-event':
        mapping_file = MODEL_DIR / "mapping-life-event.csv"
    else:
        mapping_file = MODEL_DIR / "mapping.csv"
    synonyms_by_tag = _load_synonyms(mapping_file)

    predictions = []
    for idx, prob in enumerate(obj):
        # Convert to percent and round *before* the threshold comparison.
        acc = round(prob.item() * 100, 1)
        if acc > thresh:
            label = labels[idx]
            predictions.append({
                "label": label,
                "probability": acc,
                # Direct dict lookup instead of scanning every mapping key.
                "synonyms": synonyms_by_tag.get(label, []),
            })
    return {"predictions": predictions}


def _predict_with(model_file, img, model_name, thresh):
    """Load the learner at ``model_file``, run it on ``img`` and filter via ``get_preds``."""
    learner = load_learner(model_file)
    _, _, probs = learner.predict(img)
    return get_preds(probs, learner, model_name, thresh)


def cardtagger(image):
    """Predict tags and life-event labels for a card image.

    The image is resized to 128x128, and the MD5 of the resized pixel bytes is
    used as a cache key. If ``./tmp/<md5>`` exists, its JSON content is
    returned as-is. Otherwise the tag learner (threshold 15%) and the
    life-event learner (threshold 30%) are loaded and run. Their filtered
    predictions are concatenated, tags first.

    Args:
        image: Anything ``PILImage.create`` accepts (e.g. an image file path).

    Returns:
        ``{"predictions": [...]}`` in the format produced by ``get_preds``.
    """
    img = PILImage(PILImage.create(image).resize((128, 128)))

    # Reuse a previously stored result for an identical (resized) image.
    cache_file = CACHE_DIR / hashlib.md5(img.tobytes()).hexdigest()
    if cache_file.exists():
        with open(cache_file, encoding="utf-8") as fh:
            return json.load(fh)

    result_tags = _predict_with(MODEL_DIR / "tags.pkl", img, 'tags', 15)
    result_life = _predict_with(
        MODEL_DIR / "life-event-2.pkl", img, 'life-event', 30
    )

    # NOTE(review): results are not written back to CACHE_DIR (the write was
    # left disabled), so only files already present there are ever reused.
    return {"predictions": result_tags['predictions'] + result_life['predictions']}