Ruicheng's picture
first commit
ec0c8fa
raw
history blame
3.01 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from enum import Enum
import os
from pathlib import Path
from typing import Any, Dict, Optional
class ClusterType(Enum):
AWS = "aws"
FAIR = "fair"
RSC = "rsc"
def _guess_cluster_type() -> ClusterType:
uname = os.uname()
if uname.sysname == "Linux":
if uname.release.endswith("-aws"):
# Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
return ClusterType.AWS
elif uname.nodename.startswith("rsc"):
# Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
return ClusterType.RSC
return ClusterType.FAIR
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
if cluster_type is None:
return _guess_cluster_type()
return cluster_type
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
cluster_type = get_cluster_type(cluster_type)
if cluster_type is None:
return None
CHECKPOINT_DIRNAMES = {
ClusterType.AWS: "checkpoints",
ClusterType.FAIR: "checkpoint",
ClusterType.RSC: "checkpoint/dino",
}
return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
checkpoint_path = get_checkpoint_path(cluster_type)
if checkpoint_path is None:
return None
username = os.environ.get("USER")
assert username is not None
return checkpoint_path / username
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
cluster_type = get_cluster_type(cluster_type)
if cluster_type is None:
return None
SLURM_PARTITIONS = {
ClusterType.AWS: "learnlab",
ClusterType.FAIR: "learnlab",
ClusterType.RSC: "learn",
}
return SLURM_PARTITIONS[cluster_type]
def get_slurm_executor_parameters(
nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
# create default parameters
params = {
"mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
"gpus_per_node": num_gpus_per_node,
"tasks_per_node": num_gpus_per_node, # one task per GPU
"cpus_per_task": 10,
"nodes": nodes,
"slurm_partition": get_slurm_partition(cluster_type),
}
# apply cluster-specific adjustments
cluster_type = get_cluster_type(cluster_type)
if cluster_type == ClusterType.AWS:
params["cpus_per_task"] = 12
del params["mem_gb"]
elif cluster_type == ClusterType.RSC:
params["cpus_per_task"] = 12
# set additional parameters / apply overrides
params.update(kwargs)
return params