Update engine.py
Browse files
engine.py
CHANGED
@@ -108,6 +108,7 @@ def predict_fn(data_loader, model, device, extract_features=False):
|
|
108 |
mask=mask,
|
109 |
token_type_ids=token_type_ids
|
110 |
).cpu().detach().numpy().tolist())
|
|
|
111 |
print("1",torch.argmax(outputs, dim=1))
|
112 |
print("2",torch.argmax(outputs, dim=1).cpu())
|
113 |
print("3",torch.argmax(outputs, dim=1).cpu().numpy())
|
|
|
108 |
mask=mask,
|
109 |
token_type_ids=token_type_ids
|
110 |
).cpu().detach().numpy().tolist())
|
111 |
+
print("0",outputs)
|
112 |
print("1",torch.argmax(outputs, dim=1))
|
113 |
print("2",torch.argmax(outputs, dim=1).cpu())
|
114 |
print("3",torch.argmax(outputs, dim=1).cpu().numpy())
|