{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Build a graph of a PyTorch-based model\n", "\n", "In this tutorial, we walk through the steps of building a graph of a PyTorch model, using the DiT-XL/2 model as the example. Many of these steps mirror the DiT Colab notebook (https://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb). More details on the model are available at https://github.com/facebookresearch/DiT?tab=readme-ov-file\n", "\n", "This notebook assumes that it runs on a machine with a Tenstorrent device and that tt-metal has been built successfully.\n", "\n", "We will follow these steps:\n", "- Clone the library from https://github.com/facebookresearch/DiT.git\n", "- Download the DiT-XL/2 model\n", "- Sample from the pre-trained DiT model and build the graph\n", "- Display the graph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Clone the library from https://github.com/facebookresearch/DiT.git" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-04-05 05:49:00.322 | INFO | ttnn::28 - ttnn: model cache was enabled\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fatal: destination path 'DiT' already exists and is not an empty directory.\n", "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu\n", "Requirement already up-to-date: diffusers in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (0.27.2)\n", "Requirement already up-to-date: timm in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (0.9.16)\n", "Requirement already satisfied, skipping upgrade: filelock in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from diffusers) (3.13.1)\n", "Requirement already satisfied, skipping upgrade: Pillow in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages 
(from diffusers) (9.5.0)\n", "Requirement already satisfied, skipping upgrade: regex!=2019.12.17 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from diffusers) (2023.12.25)\n", "Requirement already satisfied, skipping upgrade: importlib-metadata in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from diffusers) (7.1.0)\n", "Requirement already satisfied, skipping upgrade: numpy in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from diffusers) (1.24.4)\n", "Requirement already satisfied, skipping upgrade: huggingface-hub>=0.20.2 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from diffusers) (0.21.4)\n", "Requirement already satisfied, skipping upgrade: safetensors>=0.3.1 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from diffusers) (0.4.2)\n", "Requirement already satisfied, skipping upgrade: requests in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from diffusers) (2.31.0)\n", "Requirement already satisfied, skipping upgrade: pyyaml in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from timm) (5.3.1)\n", "Requirement already satisfied, skipping upgrade: torchvision in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from timm) (0.17.1+cpu)\n", "Requirement already satisfied, skipping upgrade: torch in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from timm) (2.2.1+cpu)\n", "Requirement already satisfied, skipping upgrade: zipp>=0.5 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from importlib-metadata->diffusers) (3.18.1)\n", "Requirement already satisfied, skipping upgrade: fsspec>=2023.5.0 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from huggingface-hub>=0.20.2->diffusers) (2023.9.2)\n", "Requirement already satisfied, skipping upgrade: typing-extensions>=3.7.4.3 in 
/home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from huggingface-hub>=0.20.2->diffusers) (4.10.0)\n", "Requirement already satisfied, skipping upgrade: packaging>=20.9 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from huggingface-hub>=0.20.2->diffusers) (24.0)\n", "Requirement already satisfied, skipping upgrade: tqdm>=4.42.1 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from huggingface-hub>=0.20.2->diffusers) (4.65.0)\n", "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from requests->diffusers) (2024.2.2)\n", "Requirement already satisfied, skipping upgrade: idna<4,>=2.5 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from requests->diffusers) (3.6)\n", "Requirement already satisfied, skipping upgrade: charset-normalizer<4,>=2 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from requests->diffusers) (3.3.2)\n", "Requirement already satisfied, skipping upgrade: urllib3<3,>=1.21.1 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from requests->diffusers) (2.2.1)\n", "Requirement already satisfied, skipping upgrade: sympy in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from torch->timm) (1.12)\n", "Requirement already satisfied, skipping upgrade: jinja2 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from torch->timm) (3.1.3)\n", "Requirement already satisfied, skipping upgrade: networkx in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from torch->timm) (3.1)\n", "Requirement already satisfied, skipping upgrade: mpmath>=0.19 in /home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from sympy->torch->timm) (1.3.0)\n", "Requirement already satisfied, skipping upgrade: MarkupSafe>=2.0 in 
/home/ubuntu/git/tt-metal/build/python_env/lib/python3.8/site-packages (from jinja2->torch->timm) (2.1.5)\n", "GPU not found. Using CPU instead.\n" ] } ], "source": [ "import ttnn\n", "\n", "!git clone https://github.com/facebookresearch/DiT.git\n", "import DiT, os\n", "os.chdir('DiT')\n", "os.environ['PYTHONPATH'] = '/env/python:/content/DiT'\n", "!pip install diffusers timm --upgrade\n", "# DiT imports:\n", "import torch\n", "from torchvision.utils import save_image\n", "from diffusion import create_diffusion\n", "from diffusers.models import AutoencoderKL\n", "from download import find_model\n", "# Both ttnn and DiT use a top-level module named `models`, which causes a name collision in Python's namespace.\n", "# To avoid it, import DiT's models through the package-qualified name DiT.models.\n", "from DiT.models import DiT_XL_2\n", "from PIL import Image\n", "from IPython.display import display\n", "torch.set_grad_enabled(False)\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "if device == \"cpu\":\n", "    print(\"GPU not found. 
Using CPU instead.\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download DiT-XL/2 Models" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "image_size = 256 #@param [256, 512]\n", "vae_model = \"stabilityai/sd-vae-ft-ema\" #@param [\"stabilityai/sd-vae-ft-mse\", \"stabilityai/sd-vae-ft-ema\"]\n", "latent_size = int(image_size) // 8\n", "# Load model:\n", "model = DiT_XL_2(input_size=latent_size).to(device)\n", "state_dict = find_model(f\"DiT-XL-2-{image_size}x{image_size}.pt\")\n", "model.load_state_dict(state_dict)\n", "model.eval() # important!\n", "vae = AutoencoderKL.from_pretrained(vae_model).to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sample from Pre-trained DiT Models and build the graph\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3117d12d79234600bd29780debc8da25", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/250 [00:00\n", "[Figure: module graph of diffusers.models.autoencoders.vae.Decoder (duration: 9.9 s) for a torch.float32 input tensor of shape (8, 4, 32, 32).]\n", "[Top-level module durations: conv_in 3.0 ms; mid_block 447.2 ms; up_blocks.0 684.9 ms; up_blocks.1 2.8 s; up_blocks.2 3.2 s; up_blocks.3 2.7 s; conv_norm_out 33.6 ms; conv_act 28.3 ms; conv_out 26.7 ms.]\n", 
"\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_25:#0->module_input_27\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_30\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 3.0 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_27:#0->group_norm_30:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_28\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.0.norm2.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_28:#0->group_norm_30:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_29\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.0.norm2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_29:#0->group_norm_30:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_32\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_30:#0->module_input_32\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_33\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 2.4 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_32:#0->silu_33:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_35\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_33:#0->module_input_35\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_36\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 21.0 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", 
"\n", "\n", "\n", "module_input_35:#0->dropout_36:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_38\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_36:#0->module_input_38\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_41\n", "\n", "\n", "torch.conv2d\n", "duration: 64.7 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_38:#0->conv2d_41:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_39\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.0.conv2.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_39:#0->conv2d_41:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_40\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.0.conv2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_40:#0->conv2d_41:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_41:#0->add_43:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_44\n", "\n", "\n", "torch.Tensor.div\n", "duration: 227.2 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "add_43:#0->div_44:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_46\n", "\n", "\n", "hidden_states\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "div_44:#0->module_input_46\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "view_47\n", "\n", "\n", "torch.Tensor.view\n", "duration: 17.6 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "\n", "\n", 
"module_input_46:#0->view_47:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_91\n", "\n", "\n", "torch.Tensor.add\n", "duration: 3.2 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_46:#0->add_91:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "transpose_48\n", "\n", "\n", "torch.Tensor.transpose\n", "duration: 18.4 µs\n", "\n", "Input 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "view_47:#0->transpose_48:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "transpose_49\n", "\n", "\n", "torch.Tensor.transpose\n", "duration: 10.0 µs\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "\n", "\n", "transpose_48:#0->transpose_49:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_50\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "\n", "\n", "transpose_49:#0->module_input_50\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_53\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 643.5 µs\n", "\n", "Input 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_50:#0->group_norm_53:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_51\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.group_norm.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_51:#0->group_norm_53:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_52\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.group_norm.bias\n", "\n", "Output 
0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_52:#0->group_norm_53:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "transpose_55\n", "\n", "\n", "torch.Tensor.transpose\n", "duration: 12.6 µs\n", "\n", "Input 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_53:#0->transpose_55:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_56\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "transpose_55:#0->module_input_56\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_61\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "transpose_55:#0->module_input_61\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_66\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "transpose_55:#0->module_input_66\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "linear_59\n", "\n", "\n", "torch.nn.functional.linear\n", "duration: 14.8 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_56:#0->linear_59:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_57\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.to_q.weight\n", "\n", "Output 0\n", "(512, 512)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_57:#0->linear_59:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_58\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.to_q.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_58:#0->linear_59:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "view_71\n", "\n", "\n", "torch.Tensor.view\n", 
"duration: 17.9 µs\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 1, 512)\n", "torch.float32\n", "\n", "\n", "\n", "linear_59:#0->view_71:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "linear_64\n", "\n", "\n", "torch.nn.functional.linear\n", "duration: 15.2 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Input 2\n", "(512, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_61:#0->linear_64:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_62\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.to_k.weight\n", "\n", "Output 0\n", "(512, 512)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_62:#0->linear_64:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_63\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.to_k.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_63:#0->linear_64:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "view_73\n", "\n", "\n", "torch.Tensor.view\n", "duration: 11.4 µs\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 1, 512)\n", "torch.float32\n", "\n", "\n", "\n", "linear_64:#0->view_73:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "linear_69\n", "\n", "\n", "torch.nn.functional.linear\n", "duration: 16.6 ms\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_66:#0->linear_69:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_67\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.to_v.weight\n", "\n", "Output 0\n", "(512, 512)\n", "torch.float32\n", "\n", "\n", "\n", 
"torch_parameter_67:#0->linear_69:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_68\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.to_v.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_68:#0->linear_69:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "view_75\n", "\n", "\n", "torch.Tensor.view\n", "duration: 10.7 µs\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 1, 512)\n", "torch.float32\n", "\n", "\n", "\n", "linear_69:#0->view_75:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "transpose_72\n", "\n", "\n", "torch.Tensor.transpose\n", "duration: 16.5 µs\n", "\n", "Input 0\n", "(8, 1024, 1, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "view_71:#0->transpose_72:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "scaled_dot_product_attention_77\n", "\n", "\n", "torch.nn.functional.scaled_dot_product_attention\n", "duration: 37.3 ms\n", "\n", "Input 0\n", "(8, 1, 1024, 512)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 1, 1024, 512)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 1, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "transpose_72:#0->scaled_dot_product_attention_77:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "transpose_74\n", "\n", "\n", "torch.Tensor.transpose\n", "duration: 12.9 µs\n", "\n", "Input 0\n", "(8, 1024, 1, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "view_73:#0->transpose_74:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "transpose_74:#0->scaled_dot_product_attention_77:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "transpose_76\n", "\n", "\n", "torch.Tensor.transpose\n", "duration: 11.4 µs\n", "\n", "Input 0\n", "(8, 1024, 1, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1, 1024, 512)\n", "torch.float32\n", 
"\n", "\n", "\n", "view_75:#0->transpose_76:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "transpose_76:#0->scaled_dot_product_attention_77:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "transpose_78\n", "\n", "\n", "torch.Tensor.transpose\n", "duration: 15.3 µs\n", "\n", "Input 0\n", "(8, 1, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 1, 512)\n", "torch.float32\n", "\n", "\n", "\n", "scaled_dot_product_attention_77:#0->transpose_78:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "reshape_79\n", "\n", "\n", "torch.Tensor.reshape\n", "duration: 19.1 µs\n", "\n", "Input 0\n", "(8, 1024, 1, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "transpose_78:#0->reshape_79:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "to_80\n", "\n", "\n", "torch.Tensor.to\n", "duration: 7.2 µs\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "reshape_79:#0->to_80:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_81\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "to_80:#0->module_input_81\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "linear_84\n", "\n", "\n", "torch.nn.functional.linear\n", "duration: 11.5 ms\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_81:#0->linear_84:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_82\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.attentions.0.to_out.0.weight\n", "\n", "Output 0\n", "(512, 512)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_82:#0->linear_84:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_83\n", "\n", "\n", "torch.nn.Parameter\n", 
"mid_block.attentions.0.to_out.0.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_83:#0->linear_84:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_86\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "linear_84:#0->module_input_86\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_87\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 18.6 µs\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_86:#0->dropout_87:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "transpose_89\n", "\n", "\n", "torch.Tensor.transpose\n", "duration: 15.5 µs\n", "\n", "Input 0\n", "(8, 1024, 512)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_87:#0->transpose_89:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "reshape_90\n", "\n", "\n", "torch.Tensor.reshape\n", "duration: 34.1 µs\n", "\n", "Input 0\n", "(8, 512, 1024)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "transpose_89:#0->reshape_90:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "reshape_90:#0->add_91:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "div_92\n", "\n", "\n", "torch.Tensor.div\n", "duration: 366.2 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "add_91:#0->div_92:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_94\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "div_92:#0->module_input_94\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_95\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_94:#0->module_input_95\n", 
"\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_124\n", "\n", "\n", "torch.Tensor.add\n", "duration: 5.6 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_94:#0->add_124:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_98\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 729.8 µs\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_95:#0->group_norm_98:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_96\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.1.norm1.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_96:#0->group_norm_98:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_97\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.1.norm1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_97:#0->group_norm_98:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_100\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_98:#0->module_input_100\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_101\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 1.7 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_100:#0->silu_101:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_103\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_101:#0->module_input_103\n", "\n", "\n", "0 -> 0\n", "\n", 
"\n", "\n", "conv2d_106\n", "\n", "\n", "torch.conv2d\n", "duration: 58.1 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_103:#0->conv2d_106:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_104\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.1.conv1.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_104:#0->conv2d_106:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_105\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.1.conv1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_105:#0->conv2d_106:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_108\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_106:#0->module_input_108\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_111\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 1.9 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_108:#0->group_norm_111:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_109\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.1.norm2.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_109:#0->group_norm_111:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_110\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.1.norm2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", 
"torch_parameter_110:#0->group_norm_111:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_113\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_111:#0->module_input_113\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_114\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 1.3 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_113:#0->silu_114:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_116\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_114:#0->module_input_116\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_117\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 16.9 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_116:#0->dropout_117:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_119\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_117:#0->module_input_119\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_122\n", "\n", "\n", "torch.conv2d\n", "duration: 56.8 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_119:#0->conv2d_122:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_120\n", "\n", "\n", "torch.nn.Parameter\n", "mid_block.resnets.1.conv2.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_120:#0->conv2d_122:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_121\n", "\n", 
"\n", "torch.nn.Parameter\n", "mid_block.resnets.1.conv2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_121:#0->conv2d_122:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_122:#0->add_124:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_125\n", "\n", "\n", "torch.Tensor.div\n", "duration: 1.6 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "add_124:#0->div_125:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "to_128\n", "\n", "\n", "torch.Tensor.to\n", "duration: 7.9 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "div_125:#0->to_128:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_129\n", "\n", "\n", "hidden_states\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "to_128:#0->module_input_129\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_130\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_129:#0->module_input_130\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_131\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_130:#0->module_input_131\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_160\n", "\n", "\n", "torch.Tensor.add\n", "duration: 5.0 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_130:#0->add_160:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_134\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 1.4 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", 
"(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_131:#0->group_norm_134:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_132\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.0.norm1.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_132:#0->group_norm_134:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_133\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.0.norm1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_133:#0->group_norm_134:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_136\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_134:#0->module_input_136\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_137\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 691.7 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_136:#0->silu_137:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_139\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_137:#0->module_input_139\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_142\n", "\n", "\n", "torch.conv2d\n", "duration: 55.0 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_139:#0->conv2d_142:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_140\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.0.conv1.weight\n", "\n", "Output 0\n", "(512, 
512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_140:#0->conv2d_142:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_141\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.0.conv1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_141:#0->conv2d_142:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_144\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_142:#0->module_input_144\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_147\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 1.0 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_144:#0->group_norm_147:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_145\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.0.norm2.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_145:#0->group_norm_147:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_146\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.0.norm2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_146:#0->group_norm_147:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_149\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_147:#0->module_input_149\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_150\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 1.4 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_149:#0->silu_150:$0\n", "\n", "\n", 
"0 -> 0\n", "\n", "\n", "\n", "module_input_152\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_150:#0->module_input_152\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_153\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 17.2 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_152:#0->dropout_153:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_155\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_153:#0->module_input_155\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_158\n", "\n", "\n", "torch.conv2d\n", "duration: 56.1 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_155:#0->conv2d_158:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_156\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.0.conv2.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_156:#0->conv2d_158:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_157\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.0.conv2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_157:#0->conv2d_158:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_158:#0->add_160:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_161\n", "\n", "\n", "torch.Tensor.div\n", "duration: 448.7 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "add_160:#0->div_161:$0\n", "\n", "\n", "0 -> 0\n", 
"\n", "\n", "\n", "module_input_163\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "div_161:#0->module_input_163\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_164\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_163:#0->module_input_164\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_193\n", "\n", "\n", "torch.Tensor.add\n", "duration: 5.0 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_163:#0->add_193:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_167\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 726.5 µs\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_164:#0->group_norm_167:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_165\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.1.norm1.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_165:#0->group_norm_167:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_166\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.1.norm1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_166:#0->group_norm_167:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_169\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_167:#0->module_input_169\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_170\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 478.5 
µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_169:#0->silu_170:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_172\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_170:#0->module_input_172\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_175\n", "\n", "\n", "torch.conv2d\n", "duration: 54.8 ms\n", "\n", "Input 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_172:#0->conv2d_175:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_173\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.1.conv1.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_173:#0->conv2d_175:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_174\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.1.conv1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_174:#0->conv2d_175:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_177\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_175:#0->module_input_177\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_180\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 1.1 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_177:#0->group_norm_180:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_178\n", "\n", "\n", 
"torch.nn.Parameter\n", "up_blocks.0.resnets.1.norm2.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_178:#0->group_norm_180:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_179\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.1.norm2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_179:#0->group_norm_180:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_182\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_180:#0->module_input_182\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_183\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 628.0 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_182:#0->silu_183:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_185\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_183:#0->module_input_185\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_186\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 31.5 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_185:#0->dropout_186:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_188\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_186:#0->module_input_188\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_191\n", "\n", "\n", "torch.conv2d\n", "duration: 56.6 ms\n", "\n", "Input 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", 
"\n", "\n", "\n", "module_input_188:#0->conv2d_191:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_189\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.1.conv2.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_189:#0->conv2d_191:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_190\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.1.conv2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_190:#0->conv2d_191:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_191:#0->add_193:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_194\n", "\n", "\n", "torch.Tensor.div\n", "duration: 3.5 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "add_193:#0->div_194:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_196\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "div_194:#0->module_input_196\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_197\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_196:#0->module_input_197\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_226\n", "\n", "\n", "torch.Tensor.add\n", "duration: 5.2 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_196:#0->add_226:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_200\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 684.0 µs\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 
0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_197:#0->group_norm_200:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_198\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.2.norm1.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_198:#0->group_norm_200:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_199\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.2.norm1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_199:#0->group_norm_200:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_202\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_200:#0->module_input_202\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_203\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 460.6 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_202:#0->silu_203:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_205\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_203:#0->module_input_205\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_208\n", "\n", "\n", "torch.conv2d\n", "duration: 54.3 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_205:#0->conv2d_208:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_206\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.2.conv1.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_206:#0->conv2d_208:$1\n", 
"\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_207\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.2.conv1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_207:#0->conv2d_208:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_210\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_208:#0->module_input_210\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_213\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 2.0 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_210:#0->group_norm_213:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_211\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.2.norm2.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_211:#0->group_norm_213:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_212\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.2.norm2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_212:#0->group_norm_213:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_215\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_213:#0->module_input_215\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_216\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 661.1 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_215:#0->silu_216:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_218\n", "\n", "\n", "input\n", "\n", "Output 0\n", 
"(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "silu_216:#0->module_input_218\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_219\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 16.2 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_218:#0->dropout_219:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_221\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_219:#0->module_input_221\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_224\n", "\n", "\n", "torch.conv2d\n", "duration: 56.1 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_221:#0->conv2d_224:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_222\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.2.conv2.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_222:#0->conv2d_224:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_223\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.resnets.2.conv2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_223:#0->conv2d_224:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_224:#0->add_226:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_227\n", "\n", "\n", "torch.Tensor.div\n", "duration: 540.7 µs\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "add_226:#0->div_227:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_229\n", "\n", "\n", "hidden_states\n", "\n", "Output 0\n", "(8, 
512, 32, 32)\n", "torch.float32\n", "\n", "\n", "\n", "div_227:#0->module_input_229\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "interpolate_230\n", "\n", "\n", "torch.nn.functional.interpolate\n", "duration: 4.1 ms\n", "\n", "Input 0\n", "(8, 512, 32, 32)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_229:#0->interpolate_230:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_231\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "interpolate_230:#0->module_input_231\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_234\n", "\n", "\n", "torch.conv2d\n", "duration: 245.7 ms\n", "\n", "Input 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_231:#0->conv2d_234:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_232\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.upsamplers.0.conv.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_232:#0->conv2d_234:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_233\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.0.upsamplers.0.conv.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_233:#0->conv2d_234:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_238\n", "\n", "\n", "hidden_states\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_234:#0->module_input_238\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_239\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_238:#0->module_input_239\n", "\n", "\n", "0 -> 0\n", "\n", "\n", 
"\n", "module_input_240\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_239:#0->module_input_240\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_269\n", "\n", "\n", "torch.Tensor.add\n", "duration: 9.3 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_239:#0->add_269:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_243\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 13.8 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_240:#0->group_norm_243:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_241\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.0.norm1.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_241:#0->group_norm_243:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_242\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.0.norm1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_242:#0->group_norm_243:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_245\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_243:#0->module_input_245\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_246\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 5.1 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_245:#0->silu_246:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", 
"module_input_248\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "silu_246:#0->module_input_248\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_251\n", "\n", "\n", "torch.conv2d\n", "duration: 243.1 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 2\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_248:#0->conv2d_251:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_249\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.0.conv1.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_249:#0->conv2d_251:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_250\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.0.conv1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_250:#0->conv2d_251:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_253\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_251:#0->module_input_253\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_256\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 28.7 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_253:#0->group_norm_256:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_254\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.0.norm2.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_254:#0->group_norm_256:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", 
"torch_parameter_255\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.0.norm2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_255:#0->group_norm_256:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_258\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_256:#0->module_input_258\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_259\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 15.5 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_258:#0->silu_259:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_261\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "silu_259:#0->module_input_261\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_262\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 19.3 µs\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_261:#0->dropout_262:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_264\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_262:#0->module_input_264\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_267\n", "\n", "\n", "torch.conv2d\n", "duration: 258.5 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_264:#0->conv2d_267:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_265\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.0.conv2.weight\n", "\n", "Output 0\n", 
"(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_265:#0->conv2d_267:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_266\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.0.conv2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_266:#0->conv2d_267:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_267:#0->add_269:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_270\n", "\n", "\n", "torch.Tensor.div\n", "duration: 4.9 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "add_269:#0->div_270:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_272\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "div_270:#0->module_input_272\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_273\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_272:#0->module_input_273\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_302\n", "\n", "\n", "torch.Tensor.add\n", "duration: 11.2 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_272:#0->add_302:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_276\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 6.4 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_273:#0->group_norm_276:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_274\n", "\n", "\n", "torch.nn.Parameter\n", 
"up_blocks.1.resnets.1.norm1.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_274:#0->group_norm_276:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_275\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.1.norm1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_275:#0->group_norm_276:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_278\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_276:#0->module_input_278\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_279\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 5.2 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_278:#0->silu_279:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_281\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "silu_279:#0->module_input_281\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_284\n", "\n", "\n", "torch.conv2d\n", "duration: 254.8 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_281:#0->conv2d_284:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_282\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.1.conv1.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_282:#0->conv2d_284:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_283\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.1.conv1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", 
"torch_parameter_283:#0->conv2d_284:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_286\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_284:#0->module_input_286\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_289\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 12.0 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_286:#0->group_norm_289:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_287\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.1.norm2.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_287:#0->group_norm_289:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_288\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.1.norm2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_288:#0->group_norm_289:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_291\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_289:#0->module_input_291\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_292\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 5.9 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_291:#0->silu_292:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_294\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "silu_292:#0->module_input_294\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_295\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 
19.6 µs\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_294:#0->dropout_295:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_297\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_295:#0->module_input_297\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_300\n", "\n", "\n", "torch.conv2d\n", "duration: 267.0 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_297:#0->conv2d_300:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_298\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.1.conv2.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_298:#0->conv2d_300:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_299\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.1.conv2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_299:#0->conv2d_300:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_300:#0->add_302:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_303\n", "\n", "\n", "torch.Tensor.div\n", "duration: 4.9 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "add_302:#0->div_303:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_305\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "div_303:#0->module_input_305\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_306\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 
64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_305:#0->module_input_306\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_335\n", "\n", "\n", "torch.Tensor.add\n", "duration: 11.1 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_305:#0->add_335:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_309\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 2.9 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_306:#0->group_norm_309:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_307\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.2.norm1.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_307:#0->group_norm_309:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_308\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.2.norm1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_308:#0->group_norm_309:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_311\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_309:#0->module_input_311\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_312\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 5.9 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_311:#0->silu_312:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_314\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", 
"torch.float32\n", "\n", "\n", "\n", "silu_312:#0->module_input_314\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_317\n", "\n", "\n", "torch.conv2d\n", "duration: 258.8 ms\n", "\n", "Input 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_314:#0->conv2d_317:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_315\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.2.conv1.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_315:#0->conv2d_317:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_316\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.2.conv1.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_316:#0->conv2d_317:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_319\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_317:#0->module_input_319\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_322\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 11.0 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(512,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_319:#0->group_norm_322:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_320\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.2.norm2.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_320:#0->group_norm_322:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_321\n", "\n", "\n", "torch.nn.Parameter\n", 
"up_blocks.1.resnets.2.norm2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_321:#0->group_norm_322:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_324\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_322:#0->module_input_324\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_325\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 6.2 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_324:#0->silu_325:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_327\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "silu_325:#0->module_input_327\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_328\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 17.9 µs\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_327:#0->dropout_328:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_330\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_328:#0->module_input_330\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_333\n", "\n", "\n", "torch.conv2d\n", "duration: 262.8 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Input 2\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_330:#0->conv2d_333:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_331\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.2.conv2.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", 
"torch_parameter_331:#0->conv2d_333:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_332\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.resnets.2.conv2.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_332:#0->conv2d_333:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_333:#0->add_335:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_336\n", "\n", "\n", "torch.Tensor.div\n", "duration: 5.0 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "add_335:#0->div_336:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_338\n", "\n", "\n", "hidden_states\n", "\n", "Output 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "\n", "\n", "div_336:#0->module_input_338\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "interpolate_339\n", "\n", "\n", "torch.nn.functional.interpolate\n", "duration: 19.0 ms\n", "\n", "Input 0\n", "(8, 512, 64, 64)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_338:#0->interpolate_339:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_340\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "interpolate_339:#0->module_input_340\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_343\n", "\n", "\n", "torch.conv2d\n", "duration: 905.1 ms\n", "\n", "Input 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(512, 512, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_340:#0->conv2d_343:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_341\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.upsamplers.0.conv.weight\n", "\n", "Output 0\n", "(512, 512, 3, 3)\n", "torch.float32\n", 
"\n", "\n", "\n", "torch_parameter_341:#0->conv2d_343:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_342\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.1.upsamplers.0.conv.bias\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_342:#0->conv2d_343:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_347\n", "\n", "\n", "hidden_states\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_343:#0->module_input_347\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_348\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_347:#0->module_input_348\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_349\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_348:#0->module_input_349\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_378\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_348:#0->module_input_378\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_352\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 34.0 ms\n", "\n", "Input 0\n", "(512,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "Input 2\n", "(512,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_349:#0->group_norm_352:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_350\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.norm1.weight\n", "\n", "Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_350:#0->group_norm_352:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_351\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.norm1.bias\n", "\n", 
"Output 0\n", "(512,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_351:#0->group_norm_352:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_354\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_352:#0->module_input_354\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_355\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 27.5 ms\n", "\n", "Input 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_354:#0->silu_355:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_357\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "silu_355:#0->module_input_357\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_360\n", "\n", "\n", "torch.conv2d\n", "duration: 452.1 ms\n", "\n", "Input 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(256,)\n", "torch.float32\n", "\n", "Input 2\n", "(256, 512, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_357:#0->conv2d_360:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_358\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.conv1.weight\n", "\n", "Output 0\n", "(256, 512, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_358:#0->conv2d_360:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_359\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.conv1.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_359:#0->conv2d_360:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_362\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_360:#0->module_input_362\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", 
"group_norm_365\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 15.9 ms\n", "\n", "Input 0\n", "(256,)\n", "torch.float32\n", "\n", "Input 1\n", "(256,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_362:#0->group_norm_365:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_363\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.norm2.weight\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_363:#0->group_norm_365:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_364\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.norm2.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_364:#0->group_norm_365:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_367\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_365:#0->module_input_367\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_368\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 14.8 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_367:#0->silu_368:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_370\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "silu_368:#0->module_input_370\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_371\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 20.3 µs\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_370:#0->dropout_371:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_373\n", "\n", "\n", 
"input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_371:#0->module_input_373\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_376\n", "\n", "\n", "torch.conv2d\n", "duration: 220.7 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(256,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_373:#0->conv2d_376:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_374\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.conv2.weight\n", "\n", "Output 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_374:#0->conv2d_376:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_375\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.conv2.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_375:#0->conv2d_376:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "add_383\n", "\n", "\n", "torch.Tensor.add\n", "duration: 13.8 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_376:#0->add_383:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "conv2d_381\n", "\n", "\n", "torch.conv2d\n", "duration: 53.7 ms\n", "\n", "Input 0\n", "(8, 512, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(256, 512, 1, 1)\n", "torch.float32\n", "\n", "Input 2\n", "(256,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_378:#0->conv2d_381:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_379\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.conv_shortcut.weight\n", "\n", "Output 0\n", "(256, 512, 1, 1)\n", 
"torch.float32\n", "\n", "\n", "\n", "torch_parameter_379:#0->conv2d_381:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_380\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.0.conv_shortcut.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_380:#0->conv2d_381:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_381:#0->add_383:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "div_384\n", "\n", "\n", "torch.Tensor.div\n", "duration: 13.5 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "add_383:#0->div_384:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_386\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "div_384:#0->module_input_386\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_387\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_386:#0->module_input_387\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_416\n", "\n", "\n", "torch.Tensor.add\n", "duration: 13.9 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_386:#0->add_416:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_390\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 15.8 ms\n", "\n", "Input 0\n", "(256,)\n", "torch.float32\n", "\n", "Input 1\n", "(256,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_387:#0->group_norm_390:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_388\n", "\n", "\n", "torch.nn.Parameter\n", 
"up_blocks.2.resnets.1.norm1.weight\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_388:#0->group_norm_390:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_389\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.1.norm1.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_389:#0->group_norm_390:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_392\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_390:#0->module_input_392\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_393\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 13.8 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_392:#0->silu_393:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_395\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "silu_393:#0->module_input_395\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_398\n", "\n", "\n", "torch.conv2d\n", "duration: 221.1 ms\n", "\n", "Input 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 2\n", "(256,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_395:#0->conv2d_398:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_396\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.1.conv1.weight\n", "\n", "Output 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_396:#0->conv2d_398:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_397\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.1.conv1.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", 
"\n", "\n", "torch_parameter_397:#0->conv2d_398:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_400\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_398:#0->module_input_400\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_403\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 15.8 ms\n", "\n", "Input 0\n", "(256,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 2\n", "(256,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_400:#0->group_norm_403:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_401\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.1.norm2.weight\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_401:#0->group_norm_403:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_402\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.1.norm2.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_402:#0->group_norm_403:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_405\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_403:#0->module_input_405\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_406\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 14.3 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_405:#0->silu_406:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_408\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "silu_406:#0->module_input_408\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_409\n", "\n", "\n", 
"torch.nn.functional.dropout\n", "duration: 21.7 µs\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_408:#0->dropout_409:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_411\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_409:#0->module_input_411\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_414\n", "\n", "\n", "torch.conv2d\n", "duration: 220.6 ms\n", "\n", "Input 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 2\n", "(256,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_411:#0->conv2d_414:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_412\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.1.conv2.weight\n", "\n", "Output 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_412:#0->conv2d_414:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_413\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.1.conv2.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_413:#0->conv2d_414:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_414:#0->add_416:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_417\n", "\n", "\n", "torch.Tensor.div\n", "duration: 13.6 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "add_416:#0->div_417:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_419\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "div_417:#0->module_input_419\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_420\n", 
"\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_419:#0->module_input_420\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_449\n", "\n", "\n", "torch.Tensor.add\n", "duration: 14.1 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_419:#0->add_449:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_423\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 16.5 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(256,)\n", "torch.float32\n", "\n", "Input 2\n", "(256,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_420:#0->group_norm_423:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_421\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.2.norm1.weight\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_421:#0->group_norm_423:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_422\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.2.norm1.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_422:#0->group_norm_423:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_425\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_423:#0->module_input_425\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_426\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 13.8 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_425:#0->silu_426:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", 
"module_input_428\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "silu_426:#0->module_input_428\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_431\n", "\n", "\n", "torch.conv2d\n", "duration: 219.4 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 1\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(256,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_428:#0->conv2d_431:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_429\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.2.conv1.weight\n", "\n", "Output 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_429:#0->conv2d_431:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_430\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.2.conv1.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_430:#0->conv2d_431:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_433\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_431:#0->module_input_433\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_436\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 16.2 ms\n", "\n", "Input 0\n", "(256,)\n", "torch.float32\n", "\n", "Input 1\n", "(256,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_433:#0->group_norm_436:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_434\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.2.norm2.weight\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_434:#0->group_norm_436:$1\n", "\n", "\n", "0 -> 1\n", "\n", 
"\n", "\n", "torch_parameter_435\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.2.norm2.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_435:#0->group_norm_436:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_438\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_436:#0->module_input_438\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_439\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 14.3 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_438:#0->silu_439:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_441\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "silu_439:#0->module_input_441\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_442\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 21.5 µs\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_441:#0->dropout_442:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_444\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_442:#0->module_input_444\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_447\n", "\n", "\n", "torch.conv2d\n", "duration: 221.1 ms\n", "\n", "Input 0\n", "(256,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Input 2\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_444:#0->conv2d_447:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_445\n", "\n", "\n", "torch.nn.Parameter\n", 
"up_blocks.2.resnets.2.conv2.weight\n", "\n", "Output 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_445:#0->conv2d_447:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_446\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.resnets.2.conv2.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_446:#0->conv2d_447:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_447:#0->add_449:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_450\n", "\n", "\n", "torch.Tensor.div\n", "duration: 13.6 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "add_449:#0->div_450:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_452\n", "\n", "\n", "hidden_states\n", "\n", "Output 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "\n", "\n", "div_450:#0->module_input_452\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "interpolate_453\n", "\n", "\n", "torch.nn.functional.interpolate\n", "duration: 54.9 ms\n", "\n", "Input 0\n", "(8, 256, 128, 128)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_452:#0->interpolate_453:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_454\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "interpolate_453:#0->module_input_454\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_457\n", "\n", "\n", "torch.conv2d\n", "duration: 880.3 ms\n", "\n", "Input 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "Input 2\n", "(256,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_454:#0->conv2d_457:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_455\n", "\n", "\n", 
"torch.nn.Parameter\n", "up_blocks.2.upsamplers.0.conv.weight\n", "\n", "Output 0\n", "(256, 256, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_455:#0->conv2d_457:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_456\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.2.upsamplers.0.conv.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_456:#0->conv2d_457:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_461\n", "\n", "\n", "hidden_states\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_457:#0->module_input_461\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_462\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_461:#0->module_input_462\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_463\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_462:#0->module_input_463\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_492\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_462:#0->module_input_492\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_466\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 61.3 ms\n", "\n", "Input 0\n", "(256,)\n", "torch.float32\n", "\n", "Input 1\n", "(256,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_463:#0->group_norm_466:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_464\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.norm1.weight\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_464:#0->group_norm_466:$1\n", "\n", "\n", "0 -> 1\n", 
"\n", "\n", "\n", "torch_parameter_465\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.norm1.bias\n", "\n", "Output 0\n", "(256,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_465:#0->group_norm_466:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_468\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_466:#0->module_input_468\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_469\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 55.1 ms\n", "\n", "Input 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_468:#0->silu_469:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_471\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "silu_469:#0->module_input_471\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_474\n", "\n", "\n", "torch.conv2d\n", "duration: 416.1 ms\n", "\n", "Input 0\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "Input 1\n", "(128, 256, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(128,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_471:#0->conv2d_474:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_472\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.conv1.weight\n", "\n", "Output 0\n", "(128, 256, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_472:#0->conv2d_474:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_473\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.conv1.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_473:#0->conv2d_474:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_476\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 
256)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_474:#0->module_input_476\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_479\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 30.8 ms\n", "\n", "Input 0\n", "(128,)\n", "torch.float32\n", "\n", "Input 1\n", "(128,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_476:#0->group_norm_479:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_477\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.norm2.weight\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_477:#0->group_norm_479:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_478\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.norm2.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_478:#0->group_norm_479:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_481\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_479:#0->module_input_481\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_482\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 28.0 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_481:#0->silu_482:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_484\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "silu_482:#0->module_input_484\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_485\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 21.9 µs\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", 
"\n", "module_input_484:#0->dropout_485:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_487\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_485:#0->module_input_487\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_490\n", "\n", "\n", "torch.conv2d\n", "duration: 212.0 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 1\n", "(128,)\n", "torch.float32\n", "\n", "Input 2\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_487:#0->conv2d_490:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_488\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.conv2.weight\n", "\n", "Output 0\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_488:#0->conv2d_490:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_489\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.conv2.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_489:#0->conv2d_490:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "add_497\n", "\n", "\n", "torch.Tensor.add\n", "duration: 27.3 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_490:#0->add_497:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "conv2d_495\n", "\n", "\n", "torch.conv2d\n", "duration: 60.0 ms\n", "\n", "Input 0\n", "(128,)\n", "torch.float32\n", "\n", "Input 1\n", "(128, 256, 1, 1)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 256, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_492:#0->conv2d_495:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_493\n", 
"\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.conv_shortcut.weight\n", "\n", "Output 0\n", "(128, 256, 1, 1)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_493:#0->conv2d_495:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_494\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.0.conv_shortcut.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_494:#0->conv2d_495:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_495:#0->add_497:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "div_498\n", "\n", "\n", "torch.Tensor.div\n", "duration: 26.9 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "add_497:#0->div_498:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_500\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "div_498:#0->module_input_500\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_501\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_500:#0->module_input_501\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_530\n", "\n", "\n", "torch.Tensor.add\n", "duration: 28.0 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_500:#0->add_530:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_504\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 31.3 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 1\n", "(128,)\n", "torch.float32\n", "\n", "Input 2\n", "(128,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", 
"module_input_501:#0->group_norm_504:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_502\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.1.norm1.weight\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_502:#0->group_norm_504:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_503\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.1.norm1.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_503:#0->group_norm_504:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_506\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_504:#0->module_input_506\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_507\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 27.5 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_506:#0->silu_507:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_509\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "silu_507:#0->module_input_509\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_512\n", "\n", "\n", "torch.conv2d\n", "duration: 214.7 ms\n", "\n", "Input 0\n", "(128,)\n", "torch.float32\n", "\n", "Input 1\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_509:#0->conv2d_512:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_510\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.1.conv1.weight\n", "\n", "Output 0\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_510:#0->conv2d_512:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", 
"torch_parameter_511\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.1.conv1.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_511:#0->conv2d_512:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_514\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_512:#0->module_input_514\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_517\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 31.0 ms\n", "\n", "Input 0\n", "(128,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 2\n", "(128,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_514:#0->group_norm_517:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_515\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.1.norm2.weight\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_515:#0->group_norm_517:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_516\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.1.norm2.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_516:#0->group_norm_517:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_519\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_517:#0->module_input_519\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_520\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 28.0 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_519:#0->silu_520:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_522\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", 
"torch.float32\n", "\n", "\n", "\n", "silu_520:#0->module_input_522\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_523\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 21.0 µs\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_522:#0->dropout_523:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_525\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_523:#0->module_input_525\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_528\n", "\n", "\n", "torch.conv2d\n", "duration: 211.6 ms\n", "\n", "Input 0\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "Input 1\n", "(128,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_525:#0->conv2d_528:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_526\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.1.conv2.weight\n", "\n", "Output 0\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_526:#0->conv2d_528:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_527\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.1.conv2.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_527:#0->conv2d_528:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_528:#0->add_530:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_531\n", "\n", "\n", "torch.Tensor.div\n", "duration: 27.0 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "add_530:#0->div_531:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_533\n", "\n", "\n", "input_tensor\n", "\n", "Output 0\n", "(8, 128, 256, 
256)\n", "torch.float32\n", "\n", "\n", "\n", "div_531:#0->module_input_533\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_534\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_533:#0->module_input_534\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "add_563\n", "\n", "\n", "torch.Tensor.add\n", "duration: 27.6 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_533:#0->add_563:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_537\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 31.2 ms\n", "\n", "Input 0\n", "(128,)\n", "torch.float32\n", "\n", "Input 1\n", "(128,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_534:#0->group_norm_537:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_535\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.2.norm1.weight\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_535:#0->group_norm_537:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_536\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.2.norm1.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_536:#0->group_norm_537:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_539\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_537:#0->module_input_539\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_540\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 27.6 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 
128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_539:#0->silu_540:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_542\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "silu_540:#0->module_input_542\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_545\n", "\n", "\n", "torch.conv2d\n", "duration: 211.7 ms\n", "\n", "Input 0\n", "(128,)\n", "torch.float32\n", "\n", "Input 1\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_542:#0->conv2d_545:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_543\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.2.conv1.weight\n", "\n", "Output 0\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_543:#0->conv2d_545:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_544\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.2.conv1.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_544:#0->conv2d_545:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_547\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "conv2d_545:#0->module_input_547\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_550\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 31.2 ms\n", "\n", "Input 0\n", "(128,)\n", "torch.float32\n", "\n", "Input 1\n", "(128,)\n", "torch.float32\n", "\n", "Input 2\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_547:#0->group_norm_550:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_548\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.2.norm2.weight\n", "\n", 
"Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_548:#0->group_norm_550:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_549\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.2.norm2.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_549:#0->group_norm_550:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_552\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_550:#0->module_input_552\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_553\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 27.9 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_552:#0->silu_553:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_555\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "silu_553:#0->module_input_555\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "dropout_556\n", "\n", "\n", "torch.nn.functional.dropout\n", "duration: 21.0 µs\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_555:#0->dropout_556:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_558\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "dropout_556:#0->module_input_558\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_561\n", "\n", "\n", "torch.conv2d\n", "duration: 211.5 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 1\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "Input 2\n", "(128,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", 
"module_input_558:#0->conv2d_561:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_559\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.2.conv2.weight\n", "\n", "Output 0\n", "(128, 128, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_559:#0->conv2d_561:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_560\n", "\n", "\n", "torch.nn.Parameter\n", "up_blocks.3.resnets.2.conv2.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_560:#0->conv2d_561:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "conv2d_561:#0->add_563:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "div_564\n", "\n", "\n", "torch.Tensor.div\n", "duration: 27.0 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "add_563:#0->div_564:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_567\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "div_564:#0->module_input_567\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "group_norm_570\n", "\n", "\n", "torch.nn.functional.group_norm\n", "duration: 31.9 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 1\n", "(128,)\n", "torch.float32\n", "\n", "Input 2\n", "(128,)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_567:#0->group_norm_570:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_568\n", "\n", "\n", "torch.nn.Parameter\n", "conv_norm_out.weight\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_568:#0->group_norm_570:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_569\n", "\n", "\n", "torch.nn.Parameter\n", "conv_norm_out.bias\n", "\n", "Output 0\n", "(128,)\n", "torch.float32\n", "\n", "\n", "\n", 
"torch_parameter_569:#0->group_norm_570:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n", "module_input_572\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "group_norm_570:#0->module_input_572\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "silu_573\n", "\n", "\n", "torch.nn.functional.silu\n", "duration: 27.7 ms\n", "\n", "Input 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_572:#0->silu_573:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "module_input_575\n", "\n", "\n", "input\n", "\n", "Output 0\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "silu_573:#0->module_input_575\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "conv2d_578\n", "\n", "\n", "torch.conv2d\n", "duration: 25.2 ms\n", "\n", "Input 0\n", "(3,)\n", "torch.float32\n", "\n", "Input 1\n", "(8, 128, 256, 256)\n", "torch.float32\n", "\n", "Input 2\n", "(3, 128, 3, 3)\n", "torch.float32\n", "\n", "Output 0\n", "(8, 3, 256, 256)\n", "torch.float32\n", "\n", "\n", "\n", "module_input_575:#0->conv2d_578:$0\n", "\n", "\n", "0 -> 0\n", "\n", "\n", "\n", "torch_parameter_576\n", "\n", "\n", "torch.nn.Parameter\n", "conv_out.weight\n", "\n", "Output 0\n", "(3, 128, 3, 3)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_576:#0->conv2d_578:$1\n", "\n", "\n", "0 -> 1\n", "\n", "\n", "\n", "torch_parameter_577\n", "\n", "\n", "torch.nn.Parameter\n", "conv_out.bias\n", "\n", "Output 0\n", "(3,)\n", "torch.float32\n", "\n", "\n", "\n", "torch_parameter_577:#0->conv2d_578:$2\n", "\n", "\n", "0 -> 2\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set user inputs:\n", "seed = 0 #@param {type:\"number\"}\n", "torch.manual_seed(seed)\n", "num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n", "cfg_scale = 4 #@param 
{type:\"slider\", min:1, max:10, step:0.1}\n", "class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:\"raw\"}\n", "samples_per_row = 4 #@param {type:\"number\"}\n", "\n", "# Create diffusion object:\n", "diffusion = create_diffusion(str(num_sampling_steps))\n", "\n", "# Create sampling noise:\n", "n = len(class_labels)\n", "z = torch.randn(n, 4, latent_size, latent_size, device=device)\n", "y = torch.tensor(class_labels, device=device)\n", "\n", "# Set up classifier-free guidance:\n", "z = torch.cat([z, z], 0)\n", "y_null = torch.tensor([1000] * n, device=device)\n", "y = torch.cat([y, y_null], 0)\n", "model_kwargs = dict(y=y, cfg_scale=cfg_scale)\n", "\n", "# Sample images:\n", "samples = diffusion.p_sample_loop(\n", "    model.forward_with_cfg, z.shape, z, clip_denoised=False,\n", "    model_kwargs=model_kwargs, progress=True, device=device\n", ")\n", "samples, _ = samples.chunk(2, dim=0)  # Remove null class samples\n", "\n", "# Trace the VAE decode step (only the operations run inside the `with` block are recorded),\n", "# then write the captured graph to an SVG file:\n", "with ttnn.tracer.trace():\n", "    samples = vae.decode(samples / 0.18215).sample\n", "\n", "ttnn.tracer.visualize(samples, file_name=ttnn.TMP_DIR / \"dit_model_graph.svg\")\n", "\n", "# Optionally save and display images from DiT:\n", "# save_image(samples, \"sample.png\", nrow=int(samples_per_row),\n", "#            normalize=True, value_range=(-1, 1))\n", "# samples = Image.open(\"sample.png\")\n", "# display(samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Display the graph" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import SVG, display\n", "\n", "def show_svg():\n", "    # Load the SVG written by ttnn.tracer.visualize in the previous cell.\n", "    return SVG(filename=ttnn.TMP_DIR / \"dit_model_graph.svg\")\n", "\n", "show_svg()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }