New_R3gm / demucs /parser.py
r3gm's picture
Upload 288 files
7bc29af
raw
history blame
10.8 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
from pathlib import Path
def get_parser():
    """
    Build and return the `argparse.ArgumentParser` for the "demucs" command line.

    The defaults of `--raw` and `--musdb` are read from the `DEMUCS_RAW` and
    `DEMUCS_MUSDB` environment variables when those are set, and are `None`
    otherwise. `--no_raw` resets `raw` back to `None` from the command line.
    """
    env = os.environ
    raw_from_env = Path(env['DEMUCS_RAW']) if 'DEMUCS_RAW' in env else None
    musdb_from_env = Path(env['DEMUCS_MUSDB']) if 'DEMUCS_MUSDB' in env else None

    cli = argparse.ArgumentParser(prog="demucs", description="Train and evaluate Demucs.")
    # Short alias: every option below is registered on the same parser.
    add = cli.add_argument

    # Dataset and audio options
    add("--raw", type=Path, default=raw_from_env,
        help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
    # Shares dest="raw" with --raw so that it overrides any value with None.
    add("--no_raw", action="store_const", const=None, dest="raw")
    add("-m", "--musdb", type=Path, default=musdb_from_env, help="Path to musdb root")
    add("--is_wav", action="store_true",
        help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
    add("--metadata", type=Path, default=Path("metadata/"),
        help="Folder where metadata information is stored.")
    add("--wav", type=Path,
        help="Path to a wav dataset. This should contain a 'train' and a 'valid' subfolder.")
    add("--samplerate", type=int, default=44100)
    add("--audio_channels", type=int, default=2)
    add("--samples", type=int, default=44100 * 10, help="number of samples to feed in")
    add("--data_stride", type=int, default=44100,
        help="Stride for chunks, shorter = longer epochs")

    # Workers, devices and run modes
    add("-w", "--workers", type=int, default=10, help="Loader workers")
    add("--eval_workers", type=int, default=2, help="Final evaluation workers")
    add("-d", "--device", help="Device to train on, default is cuda if available else cpu")
    add("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
    add("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
    add("--test",
        help="Just run the test pipeline + one validation. "
             "This should be a filename relative to the models/ folder.")
    add("--test_pretrained",
        help="Just run the test pipeline + one validation, on a pretrained model. ")

    # Distributed setup
    add("--rank", type=int, default=0)
    add("--world_size", type=int, default=1)
    add("--master")

    # Output folders
    add("--checkpoints", type=Path, default=Path("checkpoints"),
        help="Folder where to store checkpoints etc")
    add("--evals", type=Path, default=Path("evals"),
        help="Folder where to store evals and waveforms")
    add("--save", action="store_true", help="Save estimated for the test set waveforms")
    add("--logs", type=Path, default=Path("logs"), help="Folder where to store logs")
    add("--models", type=Path, default=Path("models"),
        help="Folder where to store trained models")

    # Training options
    add("-R", "--restart", action="store_true",
        help="Restart training, ignoring previous run")
    add("--seed", type=int, default=42)
    add("-e", "--epochs", type=int, default=180, help="Number of epochs")
    add("-r", "--repeat", type=int, default=2, help="Repeat the train set, longer epochs")
    add("-b", "--batch_size", type=int, default=64)
    add("--lr", type=float, default=3e-4)
    add("--mse", action="store_true", help="Use MSE instead of L1")
    add("--init", help="Initialize from a pre-trained model.")

    # Augmentation options
    add("--no_augment", action="store_false", dest="augment", default=True,
        help="No basic data augmentation.")
    add("--repitch", type=float, default=0.2, help="Probability to do tempo/pitch change")
    # "%%" is an escaped percent sign: argparse %-formats help strings.
    add("--max_tempo", type=float, default=12,
        help="Maximum relative tempo change in %% when using repitch.")
    add("--remix_group_size", type=int, default=4,
        help="Shuffle sources using group of this size. Useful to somewhat "
             "replicate multi-gpu training on less GPUs.")
    add("--shifts", type=int, default=10,
        help="Number of random shifts used for the shift trick.")
    add("--overlap", type=float, default=0.25, help="Overlap when --split_valid is passed.")

    # See model.py for doc
    add("--growth", type=float, default=2.,
        help="Number of channels between two layers will increase by this factor")
    add("--depth", type=int, default=6, help="Number of layers for the encoder and decoder")
    add("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
    add("--channels", type=int, default=64,
        help="Number of channels for the first encoder layer")
    add("--kernel_size", type=int, default=8,
        help="Kernel size for the (transposed) convolutions")
    add("--conv_stride", type=int, default=4,
        help="Stride for the (transposed) convolutions")
    add("--context", type=int, default=3,
        help="Context size for the decoder convolutions before the transposed convolutions")
    add("--rescale", type=float, default=0.1, help="Initial weight rescale reference")
    add("--no_resample", action="store_false", default=True, dest="resample",
        help="No Resampling of the input/output x2")
    add("--no_glu", action="store_false", default=True, dest="glu",
        help="Replace all GLUs by ReLUs")
    add("--no_rewrite", action="store_false", default=True, dest="rewrite",
        help="No 1x1 rewrite convolutions")
    add("--normalize", action="store_true")
    add("--no_norm_wav", action="store_false", dest="norm_wav", default=True)

    # Tasnet options
    add("--tasnet", action="store_true")
    add("--split_valid", action="store_true",
        help="Predict chunks by chunks for valid and test. Required for tasnet")
    add("--X", type=int, default=8)

    # Other options
    add("--show", action="store_true", help="Show model architecture, size and exit")
    add("--save_model", action="store_true",
        help="Skip traning, just save final model for the current checkpoint value.")
    add("--save_state",
        help="Skip training, just save state for the current checkpoint value. "
             "You should provide a model name as argument.")

    # Quantization options
    add("--q-min-size", type=float, default=1,
        help="Only quantize layers over this size (in MB)")
    add("--qat", type=int, help="If provided, use QAT training with that many bits.")
    add("--diffq", type=float, default=0)
    add("--ms-target", type=float, default=162,
        help="Model size target in MB, when using DiffQ. Best model will be kept "
             "only if it is smaller than this target.")

    return cli
def get_name(parser, args):
    """
    Derive an experiment name from the parsed arguments.

    Every attribute of `args` whose value differs from the default registered
    on `parser` contributes a `key=value` fragment, in the order the attributes
    appear in `args`. `Path` values are shortened to their final component.
    Options listed in `skipped` below (for instance --workers) never
    contribute, as they do not impact the final result.

    Returns the fragments joined by single spaces, or "default" when no
    non-ignored argument differs from its default.
    """
    skipped = {
        "checkpoints",
        "deterministic",
        "eval",
        "evals",
        "eval_cpu",
        "eval_workers",
        "logs",
        "master",
        "rank",
        "restart",
        "save",
        "save_model",
        "save_state",
        "show",
        "workers",
        "world_size",
    }
    # Iterate over a copy of the namespace's attributes, as a snapshot.
    overrides = [
        f"{key}={val.name if isinstance(val, Path) else val}"
        for key, val in dict(vars(args)).items()
        if key not in skipped and val != parser.get_default(key)
    ]
    return " ".join(overrides) if overrides else "default"