Spaces:
Running
on
L40S
Running
on
L40S
File size: 536 Bytes
4f6613a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import click
import torch
from loguru import logger
@click.command()
@click.argument("model_path")
@click.argument("output_path")
def main(model_path, output_path):
if model_path == output_path:
logger.error("Model path and output path are the same")
return
logger.info(f"Loading model from {model_path}")
state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
torch.save(state_dict, output_path)
logger.info(f"Model saved to {output_path}")
if __name__ == "__main__":
main()
|