Interacting with Jukebox
{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Interacting with Jukebox","provenance":[{"file_id":"1mic-GVcMXstBbhOQcttE7m1t4Pu2rrlK","timestamp":1605666889999},{"file_id":"https://github.com/SMarioMan/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb","timestamp":1605653238719}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"}},"cells":[{"cell_type":"markdown","metadata":{"id":"uq8uLwZCn0BV"},"source":["IMPORTANT NOTE ON SYSTEM REQUIREMENTS:\n","\n","If you are connecting to a hosted runtime, make sure it has a P100 GPU (optionally run !nvidia-smi to confirm). Go to Edit>Notebook Settings to set this.\n","\n","CoLab may first assign you a lower-memory machine if you are using a hosted runtime. If so, the first time you try to load the 5B model, it will run out of memory, and then you'll be prompted to restart with more memory (then return to the top of this CoLab). If you continue to have memory issues after this (or run into issues on your own home setup), switch to the 1B model.\n","\n","If you are using a local GPU, we recommend V100 or P100 with 16 GB GPU memory for best performance. For GPUs with less memory, we recommend using the 1B model and a smaller batch size throughout. 
\n","\n","https://www.youtube.com/watch?v=mNtmgYW428M&ab_channel=Broccaloo\n","\n"]},{"cell_type":"code","metadata":{"id":"8qEqdj8u0gdN"},"source":["!nvidia-smi -L"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"VAMZK4GNA_PM"},"source":["Mount Google Drive to save sample levels as they are generated."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ZPdMgaH_BPGN","executionInfo":{"status":"ok","timestamp":1605705250799,"user_tz":0,"elapsed":18479,"user":{"displayName":"Lokesh Saini","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgZnsdAWAjb7BpCuaFoEopOgAONRk9EvirLaapB=s64","userId":"04919528447711792423"}},"outputId":"05324f82-63d3-4f88-9a57-f88e229910f0"},"source":["from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Mounted at /content/gdrive\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"Zy4Rehq9ZKv_"},"source":["Prepare the environment."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"sAdFGF-bqVMY","executionInfo":{"status":"ok","timestamp":1605705305379,"user_tz":0,"elapsed":52268,"user":{"displayName":"Lokesh Saini","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgZnsdAWAjb7BpCuaFoEopOgAONRk9EvirLaapB=s64","userId":"04919528447711792423"}},"outputId":"78ed382b-c445-408f-a636-fdeb8fe0a6e9"},"source":["!pip install git+https://github.com/tdunity/fixedjukebox.git"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Collecting git+https://github.com/tdunity/fixedjukebox.git\n"," Cloning https://github.com/tdunity/fixedjukebox.git to /tmp/pip-req-build-i1xv1jxd\n"," Running command git clone -q https://github.com/tdunity/fixedjukebox.git /tmp/pip-req-build-i1xv1jxd\n","Collecting fire==0.1.3\n"," Downloading 
https://files.pythonhosted.org/packages/5a/b7/205702f348aab198baecd1d8344a90748cb68f53bdcd1cc30cbc08e47d3e/fire-0.1.3.tar.gz\n","Collecting tqdm==4.45.0\n","\u001b[?25l Downloading https://files.pythonhosted.org/packages/4a/1c/6359be64e8301b84160f6f6f7936bbfaaa5e9a4eab6cbc681db07600b949/tqdm-4.45.0-py2.py3-none-any.whl (60kB)\n","\u001b[K |████████████████████████████████| 61kB 5.0MB/s \n","\u001b[?25hCollecting soundfile==0.10.3.post1\n"," Downloading https://files.pythonhosted.org/packages/eb/f2/3cbbbf3b96fb9fa91582c438b574cff3f45b29c772f94c400e2c99ef5db9/SoundFile-0.10.3.post1-py2.py3-none-any.whl\n","Collecting unidecode==1.1.1\n","\u001b[?25l Downloading https://files.pythonhosted.org/packages/d0/42/d9edfed04228bacea2d824904cae367ee9efd05e6cce7ceaaedd0b0ad964/Unidecode-1.1.1-py2.py3-none-any.whl (238kB)\n","\u001b[K |████████████████████████████████| 245kB 15.8MB/s \n","\u001b[?25hRequirement already satisfied: numba==0.48.0 in /usr/local/lib/python3.6/dist-packages (from jukebox==1.0) (0.48.0)\n","Collecting librosa==0.7.2\n","\u001b[?25l Downloading https://files.pythonhosted.org/packages/77/b5/1817862d64a7c231afd15419d8418ae1f000742cac275e85c74b219cbccb/librosa-0.7.2.tar.gz (1.6MB)\n","\u001b[K |████████████████████████████████| 1.6MB 16.3MB/s \n","\u001b[?25hCollecting mpi4py>=3.0.0\n","\u001b[?25l Downloading https://files.pythonhosted.org/packages/ec/8f/bbd8de5ba566dd77e408d8136e2bab7fdf2b97ce06cab830ba8b50a2f588/mpi4py-3.0.3.tar.gz (1.4MB)\n","\u001b[K |████████████████████████████████| 1.4MB 27.9MB/s \n","\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from fire==0.1.3->jukebox==1.0) (1.15.0)\n","Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.6/dist-packages (from soundfile==0.10.3.post1->jukebox==1.0) (1.14.3)\n","Requirement already satisfied: llvmlite<0.32.0,>=0.31.0dev0 in /usr/local/lib/python3.6/dist-packages (from numba==0.48.0->jukebox==1.0) (0.31.0)\n","Requirement already satisfied: 
numpy>=1.15 in /usr/local/lib/python3.6/dist-packages (from numba==0.48.0->jukebox==1.0) (1.18.5)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from numba==0.48.0->jukebox==1.0) (50.3.2)\n","Requirement already satisfied: audioread>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from librosa==0.7.2->jukebox==1.0) (2.1.9)\n","Requirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.6/dist-packages (from librosa==0.7.2->jukebox==1.0) (1.4.1)\n","Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /usr/local/lib/python3.6/dist-packages (from librosa==0.7.2->jukebox==1.0) (0.22.2.post1)\n","Requirement already satisfied: joblib>=0.12 in /usr/local/lib/python3.6/dist-packages (from librosa==0.7.2->jukebox==1.0) (0.17.0)\n","Requirement already satisfied: decorator>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from librosa==0.7.2->jukebox==1.0) (4.4.2)\n","Requirement already satisfied: resampy>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from librosa==0.7.2->jukebox==1.0) (0.2.2)\n","Requirement already satisfied: pycparser in /usr/local/lib/python3.6/dist-packages (from cffi>=1.0->soundfile==0.10.3.post1->jukebox==1.0) (2.20)\n","Building wheels for collected packages: jukebox, fire, librosa, mpi4py\n"," Building wheel for jukebox (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for jukebox: filename=jukebox-1.0-cp36-none-any.whl size=196069 sha256=557cee1eb241c928ba9f6c9f4e653f04a51c351d97e116b83c6612ab6ac66d31\n"," Stored in directory: /tmp/pip-ephem-wheel-cache-ma3iyn2k/wheels/1a/40/b2/d6bbe926438562fb7b3657d4928e9a085878244650e2949051\n"," Building wheel for fire (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n"," Created wheel for fire: filename=fire-0.1.3-py2.py3-none-any.whl size=49706 sha256=0bc4da08577cd316270b645623d91084094baedb08121dc5912ec25685668455\n"," Stored in directory: /root/.cache/pip/wheels/2a/1a/4d/6b30377c3051e76559d1185c1dbbfff15aed31f87acdd14c22\n"," Building wheel for librosa (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for librosa: filename=librosa-0.7.2-cp36-none-any.whl size=1612885 sha256=7222e455826a69e76e3ffed732202ebcc82572e8087a51c0422fd50cd22cd2bf\n"," Stored in directory: /root/.cache/pip/wheels/4c/6e/d7/bb93911540d2d1e44d690a1561871e5b6af82b69e80938abef\n"," Building wheel for mpi4py (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for mpi4py: filename=mpi4py-3.0.3-cp36-cp36m-linux_x86_64.whl size=2074472 sha256=073a8b3b17cb05f5a9629d1b055d48fc10b353c8806ea01863f5c9cf5747d532\n"," Stored in directory: /root/.cache/pip/wheels/18/e0/86/2b713dd512199096012ceca61429e12b960888de59818871d6\n","Successfully built jukebox fire librosa mpi4py\n","Installing collected packages: fire, tqdm, soundfile, unidecode, librosa, mpi4py, jukebox\n"," Found existing installation: tqdm 4.41.1\n"," Uninstalling tqdm-4.41.1:\n"," Successfully uninstalled tqdm-4.41.1\n"," Found existing installation: librosa 0.6.3\n"," Uninstalling librosa-0.6.3:\n"," Successfully uninstalled librosa-0.6.3\n","Successfully installed fire-0.1.3 jukebox-1.0 librosa-0.7.2 mpi4py-3.0.3 soundfile-0.10.3.post1 tqdm-4.45.0 unidecode-1.1.1\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"taDHgk1WCC_C","executionInfo":{"status":"ok","timestamp":1605705325671,"user_tz":0,"elapsed":17318,"user":{"displayName":"Lokesh Saini","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgZnsdAWAjb7BpCuaFoEopOgAONRk9EvirLaapB=s64","userId":"04919528447711792423"}},"outputId":"4df58e09-f4f9-4e79-b537-7dfdb7006053"},"source":["import jukebox\n","import torch as t\n","import 
librosa\n","import os\n","from IPython.display import Audio\n","from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model\n","from jukebox.hparams import Hyperparams, setup_hparams\n","from jukebox.sample import sample_single_window, _sample, \\\n"," sample_partial_window, upsample, \\\n"," load_prompts\n","from jukebox.utils.dist_utils import setup_dist_from_mpi\n","from jukebox.utils.torch_utils import empty_cache\n","rank, local_rank, device = setup_dist_from_mpi()"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Using cuda True\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"89FftI5kc-Az"},"source":["# Sample from the 5B or 1B Lyrics Model\n"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"65aR2OZxmfzq","executionInfo":{"status":"ok","timestamp":1605705677440,"user_tz":0,"elapsed":336641,"user":{"displayName":"Lokesh Saini","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgZnsdAWAjb7BpCuaFoEopOgAONRk9EvirLaapB=s64","userId":"04919528447711792423"}},"outputId":"dbdfc166-168d-4b80-e5e2-8b1383f7da66"},"source":["model = '5b_lyrics' # or '5b' or '1b_lyrics'\n","hps = Hyperparams()\n","hps.sr = 44100\n","hps.n_samples = 3 if model in ('5b', '5b_lyrics') else 8\n","# Specifies the directory to save the sample in.\n","# We set this to the Google Drive mount point.\n","hps.name = '/content/gdrive/My Drive/test3'\n","chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n","max_batch_size = 9 if model in ('5b', '5b_lyrics') else 16\n","hps.levels = 3\n","hps.hop_fraction = [.5,.5,.125]\n","\n","vqvae, *priors = MODELS[model]\n","vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)\n","top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Downloading from gce\n","Restored from /root/.cache/jukebox-assets/models/5b/vqvae.pth.tar\n","0: Loading vqvae in eval 
mode\n","Loading artist IDs from /usr/local/lib/python3.6/dist-packages/jukebox/data/ids/v2_artist_ids.txt\n","Loading artist IDs from /usr/local/lib/python3.6/dist-packages/jukebox/data/ids/v2_genre_ids.txt\n","Level:2, Cond downsample:None, Raw to tokens:128, Sample length:1048576\n","0: Converting to fp16 params\n","Downloading from gce\n","Restored from /root/.cache/jukebox-assets/models/5b_lyrics/prior_level_2.pth.tar\n","0: Loading prior in eval mode\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"rvf-5pnjbmI1"},"source":["# Select mode\n","Run one of these cells to select the desired mode."]},{"cell_type":"code","metadata":{"id":"VVOQ3egdj65y"},"source":["# The default mode of operation.\n","# Creates songs based on artist and genre conditioning.\n","mode = 'ancestral'\n","codes_file=None\n","audio_file=None\n","prompt_length_in_seconds=None"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Vqqv2rJKkMXd"},"source":["# Prime song creation using an arbitrary audio sample.\n","mode = 'primed'\n","codes_file=None\n","# Specify an audio file here.\n","audio_file = '/content/gdrive/My Drive/revenge.wav'\n","# Specify how many seconds of audio to prime on.\n","prompt_length_in_seconds=25"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OxZMi-S3cT2b"},"source":["Run this cell to automatically resume from the latest checkpoint file, but only if the checkpoint file exists.\n","This will override the selected mode.\n","We will assume the existence of a checkpoint means generation is complete and it's time for upsampling to occur."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"GjRwyTDhbvf-","executionInfo":{"elapsed":1171,"status":"ok","timestamp":1605619480560,"user":{"displayName":"","photoUrl":"","userId":""},"user_tz":0},"outputId":"c5e812f4-587e-4f14-89fa-94480e2d3948"},"source":["if os.path.exists(hps.name):\n","  # Identify the lowest level generated 
and continue from there.\n","  for level in [1, 2]:\n","    data = f\"{hps.name}/level_{level}/data.pth.tar\"\n","    if os.path.isfile(data):\n","      mode = 'upsample'\n","      codes_file = data\n","      print('Upsampling from level '+str(level))\n","      break\n","print('mode is now '+mode)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Upsampling from level 1\n","mode is now upsample\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"UA2UhOZ4YfZj"},"source":["Run the cell below regardless of which mode you chose."]},{"cell_type":"code","metadata":{"id":"Jp7nKnCmk1bx"},"source":["sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JYKiwkzy0Iyf"},"source":["Specify your choice of artist, genre, lyrics, and length of musical sample. "]},{"cell_type":"code","metadata":{"id":"-sY9aGHcZP-u"},"source":["sample_length_in_seconds = 90 # Full length of musical sample to generate - we find songs in the 1 to 4 minute\n"," # range work well, with generation time proportional to sample length. 
\n"," # This total length affects how quickly the model \n"," # progresses through lyrics (model also generates differently\n"," # depending on if it thinks it's in the beginning, middle, or end of sample)\n","hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n","assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, f'Please choose a larger sampling rate'"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"qD0qxQeLaTR0"},"source":["# Note: Metas can contain different prompts per sample.\n","# By default, all samples use the same prompt.\n","metas = [dict(artist = \"Johnny Cash\",\n"," genre = \"Rap\",\n"," total_length = hps.sample_length,\n"," offset = 0,\n"," lyrics = \"\"\"Imagine there's no countries\n","It isn't hard to do\n","Nothing to kill or die for\n","And no religion, too\n","Imagine all the people\n","Living life in peace\n","You, you may say I'm a dreamer\n","But I'm not the only one\n","I hope someday you will join us\n","And the world will be as one\n","Imagine no possessions\n","I wonder if you can\n","No need for greed or hunger\n","A brotherhood of man\n","Imagine all the people\n","Sharing all the world\n","You, you may say I'm a dreamer\n","But I'm not the only one\n","I hope someday you will join us\n","And the world will live as one\n","\"\"\",\n"," ),\n"," ] * hps.n_samples\n","labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6PHC1XnEfV4Y"},"source":["Optionally adjust the sampling temperature (we've found .98 or .99 to be our favorite). 
\n"]},{"cell_type":"code","metadata":{"id":"eNwKyqYraTR9"},"source":["sampling_temperature = .98\n","\n","lower_batch_size = 16\n","max_batch_size = 9 if model in ('5b', '5b_lyrics') else 16\n","lower_level_chunk_size = 32\n","chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n","sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,\n"," chunk_size=lower_level_chunk_size),\n"," dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,\n"," chunk_size=lower_level_chunk_size),\n"," dict(temp=sampling_temperature, fp16=True, \n"," max_batch_size=max_batch_size, chunk_size=chunk_size)]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"S3j0gT3HfrRD"},"source":["Now we're ready to sample from the model. We'll generate the top level (2) first, followed by the first upsampling (level 1), and the second upsampling (0). In this CoLab we load the top prior separately from the upsamplers, because of memory concerns on the hosted runtimes. If you are using a local machine, you can also load all models directly with make_models, and then use sample.py's ancestral_sampling to put this all in one step.\n","\n","After each level, we decode to raw audio and save the audio files. 
\n","\n","This next cell will take a while (approximately 10 minutes per 20 seconds of music sample)"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9a1tlvcVlHhN","executionInfo":{"status":"ok","timestamp":1605711460072,"user_tz":0,"elapsed":5435154,"user":{"displayName":"Lokesh Saini","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgZnsdAWAjb7BpCuaFoEopOgAONRk9EvirLaapB=s64","userId":"04919528447711792423"}},"outputId":"28c98d8b-272e-4b2d-bcd1-58a9c2146c25"},"source":["if sample_hps.mode == 'ancestral':\n"," zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(len(priors))]\n"," zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)\n","elif sample_hps.mode == 'upsample':\n"," assert sample_hps.codes_file is not None\n"," # Load codes.\n"," data = t.load(sample_hps.codes_file, map_location='cpu')\n"," zs = [z.cuda() for z in data['zs']]\n"," assert zs[-1].shape[0] == hps.n_samples, f\"Expected bs = {hps.n_samples}, got {zs[-1].shape[0]}\"\n"," del data\n"," print('Falling through to the upsample step later in the notebook.')\n","elif sample_hps.mode == 'primed':\n"," assert sample_hps.audio_file is not None\n"," audio_files = sample_hps.audio_file.split(',')\n"," duration = (int(sample_hps.prompt_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n"," x = load_prompts(audio_files, duration, hps)\n"," zs = top_prior.encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])\n"," zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)\n","else:\n"," raise ValueError(f'Unknown sample mode {sample_hps.mode}.')"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Sampling level 2\n","Sampling 8192 tokens for [0,8192]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [1024,9216]. 
Conditioning on 7589 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","475/475 [01:59<00:00, 3.98it/s]\n","603/603 [00:54<00:00, 11.16it/s]\n","Sampling 8192 tokens for [2048,10240]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.30it/s]\n","Sampling 8192 tokens for [3072,11264]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.32it/s]\n","Sampling 8192 tokens for [4096,12288]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.30it/s]\n","Sampling 8192 tokens for [5120,13312]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.31it/s]\n","Sampling 8192 tokens for [6144,14336]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.31it/s]\n","Sampling 8192 tokens for [7168,15360]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.33it/s]\n","Sampling 8192 tokens for [8192,16384]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.33it/s]\n","Sampling 8192 tokens for [9216,17408]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.33it/s]\n","Sampling 8192 tokens for [10240,18432]. 
Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.33it/s]\n","Sampling 8192 tokens for [11264,19456]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.30it/s]\n","Sampling 8192 tokens for [12288,20480]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.32it/s]\n","Sampling 8192 tokens for [13312,21504]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.05it/s]\n","1024/1024 [01:30<00:00, 11.27it/s]\n","Sampling 8192 tokens for [14336,22528]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.32it/s]\n","Sampling 8192 tokens for [15360,23552]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.32it/s]\n","Sampling 8192 tokens for [16384,24576]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.33it/s]\n","Sampling 8192 tokens for [17408,25600]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.33it/s]\n","Sampling 8192 tokens for [18432,26624]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.34it/s]\n","Sampling 8192 tokens for [19456,27648]. 
Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.33it/s]\n","Sampling 8192 tokens for [20480,28672]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.34it/s]\n","Sampling 8192 tokens for [21504,29696]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.06it/s]\n","1024/1024 [01:30<00:00, 11.31it/s]\n","Sampling 8192 tokens for [22528,30720]. Conditioning on 7168 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","448/448 [01:50<00:00, 4.05it/s]\n","1024/1024 [01:30<00:00, 11.28it/s]\n","Sampling 8192 tokens for [22815,31007]. Conditioning on 7905 tokens\n","Primed sampling 3 samples with temp=0.98, top_k=0, top_p=0.0\n","495/495 [02:06<00:00, 3.92it/s]\n","287/287 [00:25<00:00, 11.08it/s]\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"-gxY9aqHqfLJ"},"source":["Listen to the results from the top level (note this will sound very noisy until we do the upsampling stage). 
You may have more generated samples, depending on the batch size you requested."]},{"cell_type":"code","metadata":{"id":"TPZENDGZqOOb"},"source":["Audio(f'{hps.name}/level_2/item_0.wav')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EJc3bQxmusc6"},"source":["We are now done with the large top_prior model, and instead load the upsamplers."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"W5VLX0zRapIm","executionInfo":{"status":"ok","timestamp":1605712515892,"user_tz":0,"elapsed":92845,"user":{"displayName":"Lokesh Saini","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgZnsdAWAjb7BpCuaFoEopOgAONRk9EvirLaapB=s64","userId":"04919528447711792423"}},"outputId":"f4987f20-b2c4-42a4-eaef-36d36b4962be"},"source":["# Set this False if you are on a local machine that has enough memory (this allows you to do the\n","# lyrics alignment visualization during the upsampling stage). For a hosted runtime, \n","# we'll need to go ahead and delete the top_prior if you are using the 5b_lyrics model.\n","if True:\n"," del top_prior\n"," empty_cache()\n"," top_prior=None\n","upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]\n","labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Conditioning on 1 above level(s)\n","Checkpointing convs\n","Checkpointing convs\n","Loading artist IDs from /usr/local/lib/python3.6/dist-packages/jukebox/data/ids/v2_artist_ids.txt\n","Loading artist IDs from /usr/local/lib/python3.6/dist-packages/jukebox/data/ids/v2_genre_ids.txt\n","Level:0, Cond downsample:4, Raw to tokens:8, Sample length:65536\n","Downloading from gce\n","Restored from /root/.cache/jukebox-assets/models/5b/prior_level_0.pth.tar\n","0: Loading prior in eval mode\n","Conditioning on 1 above level(s)\n","Checkpointing convs\n","Checkpointing convs\n","Loading artist IDs 
from /usr/local/lib/python3.6/dist-packages/jukebox/data/ids/v2_artist_ids.txt\n","Loading artist IDs from /usr/local/lib/python3.6/dist-packages/jukebox/data/ids/v2_genre_ids.txt\n","Level:1, Cond downsample:4, Raw to tokens:32, Sample length:262144\n","Downloading from gce\n","Restored from /root/.cache/jukebox-assets/models/5b/prior_level_1.pth.tar\n","0: Loading prior in eval mode\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"eH_jUhGDprAt"},"source":["Please note: this next upsampling step will take several hours. At the free tier, Google CoLab lets you run for 12 hours. As the upsampling is completed, samples will appear in the Files tab (you can access this at the left of the CoLab), under \"samples\" (or whatever hps.name is currently). Level 1 is the partially upsampled version, and then Level 0 is fully completed."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9lkJgLolpZ6w","outputId":"adcc76ff-cdfa-4241-85c3-f467e4166a01"},"source":["zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Sampling level 1\n","Sampling 8192 tokens for [0,8192]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [4096,12288]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [8192,16384]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [12288,20480]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [16384,24576]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [20480,28672]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [24576,32768]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [28672,36864]. Conditioning on 5780 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","181/181 [00:18<00:00, 10.01it/s]\n","2412/2412 [03:04<00:00, 13.09it/s]\n","Sampling 8192 tokens for [32768,40960]. 
Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.15it/s]\n","4096/4096 [05:15<00:00, 13.00it/s]\n","Sampling 8192 tokens for [36864,45056]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.16it/s]\n","4096/4096 [05:12<00:00, 13.13it/s]\n","Sampling 8192 tokens for [40960,49152]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.69it/s]\n","4096/4096 [05:19<00:00, 12.82it/s]\n","Sampling 8192 tokens for [45056,53248]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.57it/s]\n","4096/4096 [05:23<00:00, 12.68it/s]\n","Sampling 8192 tokens for [49152,57344]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 9.91it/s]\n","4096/4096 [05:23<00:00, 12.66it/s]\n","Sampling 8192 tokens for [53248,61440]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.79it/s]\n","4096/4096 [05:22<00:00, 12.68it/s]\n","Sampling 8192 tokens for [57344,65536]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 9.92it/s]\n","4096/4096 [05:18<00:00, 12.88it/s]\n","Sampling 8192 tokens for [61440,69632]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.00it/s]\n","4096/4096 [05:11<00:00, 13.15it/s]\n","Sampling 8192 tokens for [65536,73728]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.35it/s]\n","4096/4096 [05:21<00:00, 12.73it/s]\n","Sampling 8192 tokens for [69632,77824]. 
Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.82it/s]\n","4096/4096 [05:23<00:00, 12.65it/s]\n","Sampling 8192 tokens for [73728,81920]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.58it/s]\n","4096/4096 [05:28<00:00, 12.46it/s]\n","Sampling 8192 tokens for [77824,86016]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.67it/s]\n","4096/4096 [05:16<00:00, 12.96it/s]\n","Sampling 8192 tokens for [81920,90112]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.12it/s]\n","4096/4096 [05:16<00:00, 12.95it/s]\n","Sampling 8192 tokens for [86016,94208]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.31it/s]\n","4096/4096 [05:15<00:00, 12.99it/s]\n","Sampling 8192 tokens for [90112,98304]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.30it/s]\n","4096/4096 [05:25<00:00, 12.59it/s]\n","Sampling 8192 tokens for [94208,102400]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.81it/s]\n","4096/4096 [05:16<00:00, 12.96it/s]\n","Sampling 8192 tokens for [98304,106496]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.08it/s]\n","4096/4096 [05:10<00:00, 13.20it/s]\n","Sampling 8192 tokens for [102400,110592]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.11it/s]\n","4096/4096 [05:09<00:00, 13.24it/s]\n","Sampling 8192 tokens for [106496,114688]. 
Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.14it/s]\n","4096/4096 [05:05<00:00, 13.40it/s]\n","Sampling 8192 tokens for [110592,118784]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.34it/s]\n","4096/4096 [05:10<00:00, 13.18it/s]\n","Sampling 8192 tokens for [114688,122880]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.04it/s]\n","4096/4096 [05:14<00:00, 13.01it/s]\n","Sampling 8192 tokens for [115836,124028]. Conditioning on 7044 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","221/221 [00:23<00:00, 9.59it/s]\n","1148/1148 [01:28<00:00, 12.99it/s]\n","Sampling level 0\n","Sampling 8192 tokens for [0,8192]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [4096,12288]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [8192,16384]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [12288,20480]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [16384,24576]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [20480,28672]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [24576,32768]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [28672,36864]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [32768,40960]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [36864,45056]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [40960,49152]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [45056,53248]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [49152,57344]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [53248,61440]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [57344,65536]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [61440,69632]. 
Conditioning on 8192 tokens\n","Sampling 8192 tokens for [65536,73728]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [69632,77824]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [73728,81920]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [77824,86016]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [81920,90112]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [86016,94208]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [90112,98304]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [94208,102400]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [98304,106496]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [102400,110592]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [106496,114688]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [110592,118784]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [114688,122880]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [118784,126976]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [122880,131072]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [126976,135168]. Conditioning on 8192 tokens\n","Sampling 8192 tokens for [131072,139264]. Conditioning on 6736 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","211/211 [00:22<00:00, 9.30it/s]\n","1456/1456 [01:52<00:00, 12.98it/s]\n","Sampling 8192 tokens for [135168,143360]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.17it/s]\n","4096/4096 [05:13<00:00, 13.08it/s]\n","Sampling 8192 tokens for [139264,147456]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 10.04it/s]\n","4096/4096 [05:15<00:00, 12.96it/s]\n","Sampling 8192 tokens for [143360,151552]. 
Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.84it/s]\n","4096/4096 [05:12<00:00, 13.09it/s]\n","Sampling 8192 tokens for [147456,155648]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 9.96it/s]\n","4096/4096 [05:15<00:00, 12.98it/s]\n","Sampling 8192 tokens for [151552,159744]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 9.91it/s]\n","4096/4096 [05:23<00:00, 12.68it/s]\n","Sampling 8192 tokens for [155648,163840]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.31it/s]\n","4096/4096 [05:22<00:00, 12.70it/s]\n","Sampling 8192 tokens for [159744,167936]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 9.94it/s]\n","4096/4096 [05:15<00:00, 13.00it/s]\n","Sampling 8192 tokens for [163840,172032]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 9.94it/s]\n","4096/4096 [05:13<00:00, 13.08it/s]\n","Sampling 8192 tokens for [167936,176128]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 9.90it/s]\n","4096/4096 [05:19<00:00, 12.83it/s]\n","Sampling 8192 tokens for [172032,180224]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:12<00:00, 9.88it/s]\n","4096/4096 [05:25<00:00, 12.57it/s]\n","Sampling 8192 tokens for [176128,184320]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.45it/s]\n","4096/4096 [05:25<00:00, 12.57it/s]\n","Sampling 8192 tokens for [180224,188416]. 
Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.68it/s]\n","4096/4096 [05:23<00:00, 12.66it/s]\n","Sampling 8192 tokens for [184320,192512]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.71it/s]\n","4096/4096 [05:25<00:00, 12.57it/s]\n","Sampling 8192 tokens for [188416,196608]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.47it/s]\n","4096/4096 [05:26<00:00, 12.53it/s]\n","Sampling 8192 tokens for [192512,200704]. Conditioning on 4096 tokens\n","Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0\n","128/128 [00:13<00:00, 9.36it/s]\n","1980/4096 [02:39<02:51, 12.32it/s]Buffered data was truncated after reaching the output size limit."],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"3SJgBYJPri55"},"source":["Listen to your final sample!"]},{"cell_type":"code","metadata":{"id":"2ip2PPE0rgAb"},"source":["Audio(f'{hps.name}/level_0/item_0.wav')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"8JAgFxytwrLG"},"source":["del upsamplers\n","empty_cache()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"LpvvFH85bbBC"},"source":["# Co-Composing with the 5B or 1B Lyrics Model"]},{"cell_type":"markdown","metadata":{"id":"nFDROuS7gFQY"},"source":["For more control over the generations, try co-composing with either the 5B or 1B Lyrics Models. Again, specify your artist, genre, and lyrics. However, now instead of generating the entire sample, the model will return 3 short options for the opening of the piece (or up to 16 options if you use the 1B model instead). Choose your favorite, and then continue the loop, for as long as you like. Throughout these steps, you'll be listening to the audio at the top prior level, which means it will sound quite noisy. 
When you are satisfied with your co-creation, continue on through the upsampling section. This will render the piece in higher audio quality.\n","\n","NOTE: CoLab will first assign you a lower memory machine if you are using a hosted runtime. The next cell will run out of memory, and then you'll be prompted to restart with more memory (then return to the top of this CoLab). If you continue to have memory issues after this (or run into issues on your own home setup), switch to the 1B model. "]},{"cell_type":"code","metadata":{"id":"3y-q8ifhGBlU"},"source":["model = \"5b_lyrics\" # or \"1b_lyrics\"\n","hps = Hyperparams()\n","hps.sr = 44100\n","hps.n_samples = 3 if model in ('5b', '5b_lyrics') else 16\n","# Specifies the directory to save the sample in.\n","# We set this to the Google Drive mount point.\n","hps.name = '/content/gdrive/My Drive/co_composer'\n","hps.sample_length = 1048576 if model in ('5b', '5b_lyrics') else 786432 \n","chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n","max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n","hps.hop_fraction = [.5, .5, .125] \n","hps.levels = 3\n","\n","vqvae, *priors = MODELS[model]\n","vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = hps.sample_length)), device)\n","top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uY8QqH1Bil2-"},"source":["# Select mode\n","Run one of these cells to select the desired mode."]},{"cell_type":"code","metadata":{"id":"AIvlWcnNil2-"},"source":["# The default mode of operation.\n","# Creates songs based on artist and genre conditioning.\n","mode = 'ancestral'\n","codes_file=None\n","audio_file=None\n","prompt_length_in_seconds=None"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"mAjCnbQHil3A"},"source":["# Prime song creation using an arbitrary audio sample.\n","mode = 'primed'\n","codes_file=None\n","# Specify an audio file 
here.\n","audio_file = '/content/gdrive/My Drive/primer.wav'\n","# Specify how many seconds of audio to prime on.\n","prompt_length_in_seconds=12"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SgHeyokTil3C"},"source":["Run the cell below regardless of which mode you chose."]},{"cell_type":"code","metadata":{"id":"XS2CnkVcil3C"},"source":["sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"68hz4x7igq0c"},"source":["Specify your choice of artist, genre, lyrics, and length of musical sample. "]},{"cell_type":"code","metadata":{"id":"z1QelEBiil3F"},"source":["sample_length_in_seconds = 71 # Full length of musical sample to generate - we find songs in the 1 to 4 minute\n"," # range work well, with generation time proportional to sample length. \n"," # This total length affects how quickly the model \n"," # progresses through lyrics (model also generates differently\n"," # depending on if it thinks it's in the beginning, middle, or end of sample)\n","hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n","assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, 'Please choose a longer sample length'"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"QDMvH_1zUHo6"},"source":["metas = [dict(artist = \"Zac Brown Band\",\n"," genre = \"Country\",\n"," total_length = hps.sample_length,\n"," offset = 0,\n"," lyrics = \"\"\"I met a traveller from an antique land,\n"," Who said—“Two vast and trunkless legs of stone\n"," Stand in the desert. . . . 
Near them, on the sand,\n"," Half sunk a shattered visage lies, whose frown,\n"," And wrinkled lip, and sneer of cold command,\n"," Tell that its sculptor well those passions read\n"," Which yet survive, stamped on these lifeless things,\n"," The hand that mocked them, and the heart that fed;\n"," And on the pedestal, these words appear:\n"," My name is Ozymandias, King of Kings;\n"," Look on my Works, ye Mighty, and despair!\n"," Nothing beside remains. Round the decay\n"," Of that colossal Wreck, boundless and bare\n"," The lone and level sands stretch far away\n"," \"\"\",\n"," ),\n"," ] * hps.n_samples\n","labels = top_prior.labeller.get_batch_labels(metas, 'cuda')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"B9onZMEXh34f"},"source":["## Generate 3 options for the start of the song\n","\n","Initial generation is set to be 4 seconds long, but feel free to change this."]},{"cell_type":"code","metadata":{"id":"c6peEj8I_HHO"},"source":["def seconds_to_tokens(sec, sr, prior, chunk_size):\n"," tokens = sec * sr // prior.raw_to_tokens\n"," tokens = ((tokens // chunk_size) + 1) * chunk_size\n"," assert tokens <= prior.n_ctx, 'Choose a shorter generation length to stay within the top prior context'\n"," return tokens"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"2gn2GXt3zt3y"},"source":["initial_generation_in_seconds = 4\n","tokens_to_sample = seconds_to_tokens(initial_generation_in_seconds, hps.sr, top_prior, chunk_size)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U0zcWcMoiigl"},"source":["Change the sampling temperature if you like (higher is more random). 
Our favorite is in the range .98 to .995"]},{"cell_type":"code","metadata":{"id":"NHbH68H7VMeO"},"source":["sampling_temperature = .98\n","\n","lower_batch_size = 16\n","max_batch_size = 3 if model in ('5b', '5b_lyrics') else 16\n","lower_level_chunk_size = 32\n","chunk_size = 16 if model in ('5b', '5b_lyrics') else 32\n","sampling_kwargs = dict(temp=sampling_temperature, fp16=True, max_batch_size=lower_batch_size,\n"," chunk_size=lower_level_chunk_size)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"JGZEPe-WTt4g"},"source":["if sample_hps.mode == 'ancestral':\n"," zs=[t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(3)]\n"," zs=sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n","elif sample_hps.mode == 'primed':\n"," assert sample_hps.audio_file is not None\n"," audio_files = sample_hps.audio_file.split(',')\n"," duration = (int(sample_hps.prompt_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n"," x = load_prompts(audio_files, duration, hps)\n"," zs = top_prior.encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])\n"," zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n","x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mveN4Be8jK2J"},"source":["Listen to your generated samples, and then pick a favorite. If you don't like any, go back and rerun the cell above. 
\n","\n","** NOTE this is at the noisy top level, upsample fully (in the next section) to hear the final audio version"]},{"cell_type":"code","metadata":{"id":"LrJSGMhUOhZg"},"source":["for i in range(hps.n_samples):\n"," librosa.output.write_wav(f'noisy_top_level_generation_{i}.wav', x[i], sr=44100)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"rQ4ersQ5OhZr"},"source":["Audio('noisy_top_level_generation_0.wav')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-GdqzrGkOhZv"},"source":["Audio('noisy_top_level_generation_1.wav')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"gE5S8hyZOhZy"},"source":["Audio('noisy_top_level_generation_2.wav')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"t2-mEJaqZfuS"},"source":["If you don't like any of the options, return a few cells back to \"Sample a few options...\" and rerun from there."]},{"cell_type":"markdown","metadata":{"id":"o7CzSiv0MmFP"},"source":["## Choose your favorite sample and request longer generation\n","\n","---\n","\n","(Repeat from here)\n"]},{"cell_type":"code","metadata":{"id":"j_XFtVi99CIY"},"source":["my_choice=0"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pgk3sHHBLYoq"},"source":["zs[2]=zs[2][my_choice].repeat(hps.n_samples,1)\n","t.save(zs, 'zs-checkpoint2.t')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"W8Rd9xxm565S"},"source":["# Set to True to load the previous checkpoint:\n","if False:\n"," zs=t.load('zs-checkpoint2.t') "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"k12xjMgHkRGP"},"source":["Choose the length of the continuation. The 1B model can generate up to 17 second samples and the 5B up to 23 seconds, but you'll want to pick a shorter continuation length so that it will be able to look back at what you've generated already. 
Here we've chosen 4 seconds."]},{"cell_type":"code","metadata":{"id":"h3_-0a07kHHG"},"source":["continue_generation_in_seconds=4\n","tokens_to_sample = seconds_to_tokens(continue_generation_in_seconds, hps.sr, top_prior, chunk_size)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GpPG3Ifqk8ue"},"source":["The next step asks the top prior to generate more of the sample. It'll take up to a few minutes, depending on the sample length you request."]},{"cell_type":"code","metadata":{"id":"YoHkeSTaEyLj"},"source":["zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n","x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ymhUqEdhleEi"},"source":["Now listen to the longer versions of the sample you selected, and again choose a favorite sample. If you don't like any, return back to the cell where you can load the checkpoint, and continue again from there.\n","\n","When the samples start getting long, you might not always want to listen from the start, so change the playback start time later on if you like."]},{"cell_type":"code","metadata":{"id":"2H1LNLTa_R6a"},"source":["playback_start_time_in_seconds = 0 "],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"r4SBGAmsnJtH"},"source":["for i in range(hps.n_samples):\n"," librosa.output.write_wav(f'top_level_continuation_{i}.wav', x[i][playback_start_time_in_seconds*44100:], 
sr=44100)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"2WeyE5Qtnmeo"},"source":["Audio('top_level_continuation_0.wav')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"BKtfEtcaazXE"},"source":["Audio('top_level_continuation_1.wav')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"7yrlS0XwK2S0"},"source":["Audio('top_level_continuation_2.wav')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-OJT704dvnGv"},"source":["To make a longer song, return back to \"Choose your favorite sample\" and loop through that again"]},{"cell_type":"markdown","metadata":{"id":"RzCrkCZJvUcQ"},"source":["# Upsample Co-Composition to Higher Audio Quality"]},{"cell_type":"markdown","metadata":{"id":"4MPgukwMmB0p"},"source":["Choose your favorite sample from your latest group of generations. (If you haven't already gone through the Co-Composition block, make sure to do that first so you have a generation to upsample)."]},{"cell_type":"code","metadata":{"id":"yv-pNNPHBQYC"},"source":["choice = 0\n","select_best_sample = True # Set false if you want to upsample all your samples \n"," # upsampling sometimes yields subtly different results on multiple runs,\n"," # so this way you can choose your favorite upsampling"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"v17cEAqyCgfo"},"source":["if select_best_sample:\n"," zs[2]=zs[2][choice].repeat(zs[2].shape[0],1)\n","\n","t.save(zs, 'zs-top-level-final.t')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0YjK-Ac0tBfu"},"source":["Note: If you are using a CoLab hosted runtime on the free tier, you may want to download this zs-top-level-final.t file, and then restart an instance and load it in the next cell. 
The free tier will last a maximum of 12 hours, and the upsampling stage can take many hours, depending on how long a sample you have generated."]},{"cell_type":"code","metadata":{"id":"qqlR9368s3jJ"},"source":["if False:\n"," zs = t.load('zs-top-level-final.t')\n","\n","assert zs[2].shape[1]>=2048, f'Please first generate at least 2048 tokens at the top level, currently you have {zs[2].shape[1]}'\n","hps.sample_length = zs[2].shape[1]*top_prior.raw_to_tokens"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"jzHwF_iqgIWM"},"source":["# Set this False if you are on a local machine that has enough memory (this allows you to do the\n","# lyrics alignment visualization). For a hosted runtime, we'll need to go ahead and delete the top_prior\n","# if you are using the 5b_lyrics model.\n","if True:\n"," del top_prior\n"," empty_cache()\n"," top_prior=None\n","\n","upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"q22Ier6YSkKS"},"source":["sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=16, chunk_size=32),\n"," dict(temp=0.99, fp16=True, max_batch_size=16, chunk_size=32),\n"," None]\n","\n","if type(labels)==dict:\n"," labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers] + [labels] "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"T1MCa9_jnjpf"},"source":["This next step upsamples 2 levels. The level_1 samples will be available after around one hour (depending on the length of your sample) and are saved under {hps.name}/level_1/item_0.wav, while the fully upsampled level_0 samples will likely take 4-12 hours. 
You can access the wav files down below, or using the \"Files\" panel at the left of this CoLab.\n","\n","(Please note, if you are using this CoLab on Google's free tier, you may want to download intermediate steps as the connection will last for a maximum 12 hours.)"]},{"cell_type":"code","metadata":{"id":"NcNT5qIRMmHq"},"source":["zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"W2jTYLPBc29M"},"source":["Audio(f'{hps.name}/level_0/item_0.wav')"],"execution_count":null,"outputs":[]}]}