Device incorrect when importing functions from Torch
#85 by catwell, opened
For instance:
import spaces
import torch
from torch import ones

@spaces.GPU
def foo():
    print(torch.ones(3, device="cuda"))
    print(ones(3, device="cuda"))

foo()
This prints:
tensor([1., 1., 1.], device='cuda:0')
tensor([1., 1., 1.])
This is confusing and will cause issues with the many existing codebases that import functions directly from torch, so I think it would be a good idea to fix this.
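One plausible explanation, assuming (this is a guess, not taken from the spaces source) that @spaces.GPU temporarily swaps attributes on the torch module while the decorated function runs: `from torch import ones` binds the function object once, at import time, whereas `torch.ones` is looked up again on every call. The sketch below simulates that with a stand-in module; `fake_torch`, `cpu_ones`, `gpu_ones` and `fake_gpu` are hypothetical names, not part of spaces or Torch.

```python
import functools
import types

# Hypothetical stand-in for the torch module (not the real library).
fake_torch = types.ModuleType("fake_torch")


def cpu_ones(n, device="cpu"):
    # Models the observed symptom: the requested device is ignored.
    return f"tensor(ones={n}, device='cpu')"


def gpu_ones(n, device="cpu"):
    # Models the version that honors device="cuda".
    return f"tensor(ones={n}, device='{device}')"


fake_torch.ones = cpu_ones


def fake_gpu(fn):
    # Assumed model of @spaces.GPU: swap the module attribute for the
    # duration of the call, then restore the previous value.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        saved = fake_torch.ones
        fake_torch.ones = gpu_ones
        try:
            return fn(*args, **kwargs)
        finally:
            fake_torch.ones = saved
    return wrapper


ones = fake_torch.ones  # like `from torch import ones`: binds cpu_ones right now


@fake_gpu
def foo():
    via_module = fake_torch.ones(3, device="cuda")  # looked up at call time
    via_name = ones(3, device="cuda")  # still the object bound at import
    return via_module, via_name


print(foo())
```

In this model the two calls diverge exactly as in the output above: the attribute lookup sees the swapped-in function, the imported name does not.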
If you need a workaround, add this at the top of the app.py file, after importing spaces and torch:
def my_ones(*args, **kwargs):
    # Look up torch.ones on every call instead of binding it once.
    return torch.ones(*args, **kwargs)

torch.ones = my_ones
I know it looks weird, but it works (provided you never call ones outside a function decorated with @spaces.GPU).
(I used ones as an example, but the same thing happens with all other Torch functions, such as arange.)