Device incorrect when importing functions from Torch

#85
by catwell - opened

For instance:

```python
import spaces
import torch
from torch import ones

@spaces.GPU
def foo():
  print(torch.ones(3, device="cuda"))
  print(ones(3, device="cuda"))

foo()
```

This prints:

```
tensor([1., 1., 1.], device='cuda:0')
tensor([1., 1., 1.])
```

This is pretty confusing and will cause problems in many existing codebases: the second call silently returns a CPU tensor even though device="cuda" was requested. I think it would be a good idea to fix this.
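The difference between the two calls above is consistent with ordinary Python name binding, assuming (not confirmed from the spaces source) that spaces replaces attributes such as torch.ones on the torch module at runtime. `from torch import ones` copies a reference to whatever torch.ones was at import time, while `torch.ones(...)` looks the attribute up again on every call. The sketch below reproduces this with a stand-in module. fake_torch, cpu_ones and gpu_ones are hypothetical names, and no real torch or GPU is involved.

```python
import types

# Hypothetical stand-in for the torch module (no real torch needed).
fake_torch = types.ModuleType("fake_torch")

def cpu_ones(n):
    # Version present at import time: ignores the requested device.
    return f"ones({n}) on cpu"

def gpu_ones(n):
    # Version assumed to be swapped in inside the GPU context.
    return f"ones({n}) on cuda:0"

fake_torch.ones = cpu_ones

# Equivalent of `from torch import ones`: binds the current function object once.
ones = fake_torch.ones

# Simulate the decorator replacing the module attribute at runtime.
fake_torch.ones = gpu_ones

print(fake_torch.ones(3))  # attribute looked up now: ones(3) on cuda:0
print(ones(3))             # object bound at import: ones(3) on cpu
```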

If you need a workaround: at the top of app.py, after importing spaces and torch, add the following:

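```python
# Forwarder that looks up torch.ones at call time instead of at import time,
# so it uses whichever version is installed when it is actually called.
def my_ones(*args, **kwargs):
    return torch.ones(*args, **kwargs)

# Must run before any `from torch import ones`, so those imports bind my_ones.
torch.ones = my_ones
```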

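I know it looks weird, but it works, provided you only call ones inside functions decorated with @spaces.GPU. (Outside them, torch.ones may still be my_ones itself, in which case the call recurses forever.)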

(I used ones as an example, but the same thing happens with other Torch functions such as arange.)
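Since the same problem affects other functions such as arange, the my_ones workaround above can be generalized with a small loop. This is a sketch, not part of spaces: install_forwarders and _make_forwarder are hypothetical helpers, and the demo uses a SimpleNamespace in place of torch so it runs without a GPU.

```python
import types

def _make_forwarder(module, name):
    # Returns a wrapper that resolves module.<name> on every call (late binding).
    def forwarder(*args, **kwargs):
        # If module.<name> is still this wrapper when called, this recurses forever.
        return getattr(module, name)(*args, **kwargs)
    forwarder.__name__ = name
    return forwarder

def install_forwarders(module, names):
    # Hypothetical helper: the my_ones trick applied to several functions at once.
    for name in names:
        setattr(module, name, _make_forwarder(module, name))

# Demo with a stand-in namespace instead of torch.
demo = types.SimpleNamespace(ones=lambda *shape: "import-time version")
install_forwarders(demo, ["ones"])
ones = demo.ones                              # like `from torch import ones`
demo.ones = lambda *shape: "runtime version"  # simulates a later replacement
print(ones(3))  # runtime version
```

Under the same assumptions as the original workaround (untested here), app.py could call `install_forwarders(torch, ["ones", "arange"])` right after `import spaces` and `import torch`, before any `from torch import ...` lines, and the wrapped functions should still only be called inside @spaces.GPU functions.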
