File size: 536 Bytes
4f6613a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import click
import torch
from loguru import logger


@click.command()
@click.argument("model_path")
@click.argument("output_path")
def main(model_path, output_path):
    """Extract the "state_dict" entry of a checkpoint and save it by itself.

    MODEL_PATH is the checkpoint file to read; it is loaded onto the CPU.

    OUTPUT_PATH is where the extracted state dict is written. It must not
    refer to the same file as MODEL_PATH; if it does, an error is logged and
    nothing is written.
    """
    import os

    # Compare the resolved paths rather than the raw strings so that different
    # spellings of the same file ("model.ckpt" vs "./model.ckpt", relative vs
    # absolute, symlinks) cannot slip past the guard and overwrite the input.
    if os.path.realpath(model_path) == os.path.realpath(output_path):
        logger.error("Model path and output path are the same")
        return

    logger.info(f"Loading model from {model_path}")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location="cpu")
    # A checkpoint mapping without a "state_dict" key raises KeyError here.
    state_dict = checkpoint["state_dict"]
    torch.save(state_dict, output_path)
    logger.info(f"Model saved to {output_path}")


# Script entry point: invoke the click command only when this file is executed
# directly. click supplies model_path and output_path from the command line.
if __name__ == "__main__":
    main()