{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tulasiram58827/TTS_TFLite/blob/main/Parallel_WaveGAN_TFLite.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qu_1y5_ZDxpU"
},
"source": [
"This notebook contains code to convert TensorFlow ParallelWaveGAN to TFLite"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1KQie-EQDzEL"
},
"source": [
"## Acknowledgments"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h-qWgadcDzCW"
},
"source": [
"- Pretrained model(in PyTorch) downloaded from [Parallel WaveGAN Repository](https://github.com/kan-bayashi/ParallelWaveGAN#results)\n",
"\n",
"- Converted PyTorch weights to Tensorflow Compatible using [Tensorflow TTS Repository](https://github.com/TensorSpeech/TensorFlowTTS) with this [Notebook](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/examples/parallel_wavegan/convert_pwgan_from_pytorch_to_tensorflow.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pBE0GfYwEwoT"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PpGT8_mm8vs-"
},
"outputs": [],
"source": [
"!git clone https://github.com/TensorSpeech/TensorFlowTTS.git\n",
"!cd TensorFlowTTS\n",
"!pip install /content/TensorFlowTTS/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2tI8NSz_886Z"
},
"outputs": [],
"source": [
"!pip install parallel_wavegan"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6iQq9Gkn9MYT"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import torch\n",
"import sys\n",
"sys.path.append('/content/TensorFlowTTS')\n",
"from tensorflow_tts.models import TFParallelWaveGANGenerator\n",
"from tensorflow_tts.configs import ParallelWaveGANGeneratorConfig\n",
"\n",
"from parallel_wavegan.models import ParallelWaveGANGenerator\n",
"import numpy as np\n",
"\n",
"from IPython.display import Audio"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BIr9zN74E3PU"
},
"source": [
"## Intialize Model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "UoFriagU9NBx"
},
"outputs": [],
"source": [
"tf_model = TFParallelWaveGANGenerator(config=ParallelWaveGANGeneratorConfig(), name=\"parallel_wavegan_generator\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "mayxwoLp9fiR"
},
"outputs": [],
"source": [
"tf_model._build()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P__OyD23E8jN"
},
"source": [
"## Load PyTorch Checkpoints"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gZq9ibuzHbI9",
"outputId": "660ebfd7-6ed9-49e2-b3b1-9c26894694f4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1wPwO9K-0Yq-GYcXbHseaqt8kUpa_ojJf\n",
"To: /content/checkpoint-400000steps.pkl\n",
"\r",
"0.00B [00:00, ?B/s]\r",
"17.5MB [00:00, 154MB/s]\n"
]
}
],
"source": [
"!gdown --id 1wPwO9K-0Yq-GYcXbHseaqt8kUpa_ojJf -O checkpoint-400000steps.pkl"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "GoeX-YLQ9kaf"
},
"outputs": [],
"source": [
"torch_checkpoints = torch.load(\"checkpoint-400000steps.pkl\", map_location=torch.device('cpu'))\n",
"torch_generator_weights = torch_checkpoints[\"model\"][\"generator\"]\n",
"torch_model = ParallelWaveGANGenerator()\n",
"torch_model.load_state_dict(torch_checkpoints[\"model\"][\"generator\"])\n",
"torch_model.remove_weight_norm()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3NSfX33w99WW",
"outputId": "436460ea-2969-4f89-ba4a-801f0b60abff"
},
"outputs": [
{
"data": {
"text/plain": [
"1334309"
]
},
"execution_count": 9,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"model_parameters = filter(lambda p: p.requires_grad, torch_model.parameters())\n",
"params = sum([np.prod(p.size()) for p in model_parameters])\n",
"params"
]
},
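{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, the TF generator should expose the same parameter count (a quick sketch; it just sums the shapes of `tf_model.trainable_variables`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: should match the ~1.33M PyTorch parameter count above.\n",
"int(np.sum([np.prod(v.shape.as_list()) for v in tf_model.trainable_variables]))"
]
},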
{
"cell_type": "markdown",
"metadata": {
"id": "x7t7hPgiE_pR"
},
"source": [
"## Convert PyTorch weights to TensorFlow"
]
},
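{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before converting, it helps to peek at the generator's `state_dict` key ordering, which the pairing logic below relies on (illustrative only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at the first few parameter names; note bias/weight pairs sharing a layer prefix.\n",
"list(torch_model.state_dict().keys())[:6]"
]
},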
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "Y4vOfByl-ASZ"
},
"outputs": [],
"source": [
"# in pytorch, in convolution layer, the order is bias -> weight, in tf it is weight -> bias. We need re-order.\n",
"\n",
"def convert_weights_pytorch_to_tensorflow(weights_pytorch):\n",
" \"\"\"\n",
" Convert pytorch Conv1d weight variable to tensorflow Conv2D weights.\n",
" 1D: Pytorch (f_output, f_input, kernel_size) -> TF (kernel_size, f_input, 1, f_output)\n",
" 2D: Pytorch (f_output, f_input, kernel_size_h, kernel_size_w) -> TF (kernel_size_w, kernel_size_h, f_input, 1, f_output)\n",
" \"\"\"\n",
" if len(weights_pytorch.shape) == 3: # conv1d-kernel\n",
" weights_tensorflow = np.transpose(weights_pytorch, (0,2,1)) # [f_output, kernel_size, f_input]\n",
" weights_tensorflow = np.transpose(weights_tensorflow, (1,0,2)) # [kernel-size, f_output, f_input]\n",
" weights_tensorflow = np.transpose(weights_tensorflow, (0,2,1)) # [kernel-size, f_input, f_output]\n",
" return weights_tensorflow\n",
" elif len(weights_pytorch.shape) == 1: # conv1d-bias\n",
" return weights_pytorch\n",
" elif len(weights_pytorch.shape) == 4: # conv2d-kernel\n",
" weights_tensorflow = np.transpose(weights_pytorch, (0,2,1,3)) # [f_output, kernel_size_h, f_input, kernel_size_w]\n",
" weights_tensorflow = np.transpose(weights_tensorflow, (1,0,2,3)) # [kernel-size_h, f_output, f_input, kernel-size-w]\n",
" weights_tensorflow = np.transpose(weights_tensorflow, (0,2,1,3)) # [kernel_size_h, f_input, f_output, kernel-size-w]\n",
" weights_tensorflow = np.transpose(weights_tensorflow, (0,1,3,2)) # [kernel_size_h, f_input, kernel-size-w, f_output]\n",
" weights_tensorflow = np.transpose(weights_tensorflow, (0,2,1,3)) # [kernel_size_h, kernel-size-w, f_input, f_output]\n",
" weights_tensorflow = np.transpose(weights_tensorflow, (1,0,2,3)) # [kernel-size_w, kernel_size_h, f_input, f_output]\n",
" return weights_tensorflow\n",
"\n",
"torch_weights = []\n",
"all_keys = list(torch_model.state_dict().keys())\n",
"all_values = list(torch_model.state_dict().values())\n",
"\n",
"idx_already_append = []\n",
"\n",
"for i in range(len(all_keys) -1):\n",
" if i not in idx_already_append:\n",
" if all_keys[i].split(\".\")[0:-1] == all_keys[i + 1].split(\".\")[0:-1]:\n",
" if all_keys[i].split(\".\")[-1] == \"bias\" and all_keys[i + 1].split(\".\")[-1] == \"weight\":\n",
" torch_weights.append(convert_weights_pytorch_to_tensorflow(all_values[i + 1].cpu().detach().numpy()))\n",
" torch_weights.append(convert_weights_pytorch_to_tensorflow(all_values[i].cpu().detach().numpy()))\n",
" idx_already_append.append(i)\n",
" idx_already_append.append(i + 1)\n",
" else:\n",
" if i not in idx_already_append:\n",
" torch_weights.append(convert_weights_pytorch_to_tensorflow(all_values[i].cpu().detach().numpy()))\n",
" idx_already_append.append(i)"
]
},
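{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check of the layout conversion on dummy tensors (the shapes below are arbitrary, for illustration only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Arbitrary dummy kernels, just to confirm the axis permutations.\n",
"dummy_conv1d = np.zeros((64, 80, 3), dtype=np.float32)     # PyTorch (f_output, f_input, kernel_size)\n",
"dummy_conv2d = np.zeros((64, 80, 3, 5), dtype=np.float32)  # PyTorch (f_output, f_input, k_h, k_w)\n",
"print(convert_weights_pytorch_to_tensorflow(dummy_conv1d).shape)  # expect (3, 80, 64)\n",
"print(convert_weights_pytorch_to_tensorflow(dummy_conv2d).shape)  # expect (5, 3, 80, 64)"
]
},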
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "168kydzc-SxJ"
},
"outputs": [],
"source": [
"tf_var = tf_model.trainable_variables\n",
"for i, var in enumerate(tf_var):\n",
" tf.keras.backend.set_value(var, torch_weights[i])"
]
},
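{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal consistency report: every TF variable should have received a converted weight of the matching shape (a sketch based on the variable ordering established above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Report counts and flag any shape mismatch between converted weights and TF variables.\n",
"print(len(torch_weights), 'converted weights /', len(tf_var), 'TF variables')\n",
"for var, w in zip(tf_var, torch_weights):\n",
"    if tuple(var.shape) != tuple(np.shape(w)):\n",
"        print('shape mismatch:', var.name, tuple(var.shape), tuple(np.shape(w)))"
]
},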
{
"cell_type": "markdown",
"metadata": {
"id": "p8D70bCeFOAA"
},
"source": [
"## Convert to TFLite"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "mTnGufeuH3io"
},
"outputs": [],
"source": [
"def convert_to_tflite(quantization):\n",
" pwg_concrete_function = tf_model.inference.get_concrete_function()\n",
" converter = tf.lite.TFLiteConverter.from_concrete_functions([pwg_concrete_function])\n",
" converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
" converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS]\n",
" if quantization == 'float16':\n",
" converter.target_spec.supported_types = [tf.float16]\n",
" tf_lite_model = converter.convert()\n",
" model_name = f'parallel_wavegan_{quantization}.tflite'\n",
" with open(model_name, 'wb') as f:\n",
" f.write(tf_lite_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zSaic3flIJX7"
},
"source": [
"#### Dynamic Range Quantization"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "6STZKNqg-vxS"
},
"outputs": [],
"source": [
"quantization = 'dr' #@param [\"dr\", \"float16\"]\n",
"convert_to_tflite(quantization)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VB5H4bUmIUFR",
"outputId": "e53b77e9-d680-424a-a1e7-5a6abb723866"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.7M\tparallel_wavegan_dr.tflite\n"
]
}
],
"source": [
"!du -sh parallel_wavegan_dr.tflite"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tb_FF8fNINWr"
},
"source": [
"#### Float16 Quantization"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19kqBUnQ_KG3",
"outputId": "4ab46a0f-98cb-44ea-8b8c-a337b5fc8d35"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.2M\tparallel_wavegan_float16.tflite\n"
]
}
],
"source": [
"quantization = 'float16'\n",
"convert_to_tflite(quantization)\n",
"!du -sh parallel_wavegan_float16.tflite"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7Kab76pmFifJ"
},
"source": [
"## Download Sample Output of Tacotron2"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IwuNQ_Z1Fm0d",
"outputId": "a5e18dc0-573d-468e-f215-87d5bc86e67e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1LmU3j8yedwBzXKVDo9tCvozLM4iwkRnP\n",
"To: /content/tac_output.npy\n",
"\r",
" 0% 0.00/36.0k [00:00, ?B/s]\r",
"100% 36.0k/36.0k [00:00<00:00, 59.6MB/s]\n"
]
}
],
"source": [
"!gdown --id 1LmU3j8yedwBzXKVDo9tCvozLM4iwkRnP -O tac_output.npy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZyoeWq2mFRCb"
},
"source": [
"## TFLite Inference"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"id": "jXu30-8g_q88"
},
"outputs": [],
"source": [
"data = np.load('tac_output.npy')\n",
"feats = np.expand_dims(data, 0)\n",
"\n",
"interpreter = tf.lite.Interpreter(model_path='parallel_wavegan_dr.tflite')\n",
"\n",
"input_details = interpreter.get_input_details()\n",
"\n",
"output_details = interpreter.get_output_details()\n",
"\n",
"interpreter.resize_tensor_input(input_details[0]['index'], [1, feats.shape[1], feats.shape[2]], strict=True)\n",
"interpreter.allocate_tensors()\n",
"\n",
"interpreter.set_tensor(input_details[0]['index'], feats)\n",
"\n",
"interpreter.invoke()\n",
"\n",
"output = interpreter.get_tensor(output_details[0]['index'])"
]
},
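{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, the waveform can be written to disk for listening outside the notebook (a sketch assuming the `soundfile` package is installed, e.g. via `pip install soundfile`; 22050 Hz matches the playback rate used below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the generated waveform, flattened to a 1-D array.\n",
"import soundfile as sf\n",
"sf.write('parallel_wavegan_tflite.wav', output[0, :, 0], 22050)"
]
},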
{
"cell_type": "markdown",
"metadata": {
"id": "lilT7qceIaKZ"
},
"source": [
"## Play Audio"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 75
},
"id": "kAo5yluw_6Xw",
"outputId": "d6a20ef6-176d-4c23-f407-a3d0b1275bba"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 31,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"output = output[0, :, 0]\n",
"\n",
"Audio(output, rate=22050)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Tensorflow_TTS_PWGAN.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}