ResNet Block
[1]:
import torch
import ttnn
from ttnn.tracer import trace, visualize
from ttnn.model_preprocessing import preprocess_model_parameters, fold_batch_norm2d_into_conv2d
2024-09-17 03:34:11.168 | DEBUG | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/ubuntu/.cache/ttnn,model_cache_path=/home/ubuntu/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}
2024-09-17 03:34:11.240 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.pearson_correlation_coefficient be migrated to C++?
2024-09-17 03:34:11.242 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.Conv1d be migrated to C++?
2024-09-17 03:34:11.246 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.conv2d be migrated to C++?
2024-09-17 03:34:11.247 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.reshape be migrated to C++?
2024-09-17 03:34:11.248 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.unsqueeze_to_4D be migrated to C++?
2024-09-17 03:34:11.249 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.squeeze be migrated to C++?
2024-09-17 03:34:11.249 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.from_torch be migrated to C++?
2024-09-17 03:34:11.250 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.to_torch be migrated to C++?
2024-09-17 03:34:11.251 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.to_device be migrated to C++?
2024-09-17 03:34:11.252 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.from_device be migrated to C++?
2024-09-17 03:34:11.253 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.allocate_tensor_on_device be migrated to C++?
2024-09-17 03:34:11.254 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.copy_host_to_device_tensor be migrated to C++?
2024-09-17 03:34:11.254 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.deallocate be migrated to C++?
2024-09-17 03:34:11.255 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.reallocate be migrated to C++?
2024-09-17 03:34:11.256 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.load_tensor be migrated to C++?
2024-09-17 03:34:11.257 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.dump_tensor be migrated to C++?
2024-09-17 03:34:11.258 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.as_tensor be migrated to C++?
2024-09-17 03:34:11.262 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.avg_pool2d be migrated to C++?
2024-09-17 03:34:11.266 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.conv2d be migrated to C++?
2024-09-17 03:34:11.268 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.avg_pool2d be migrated to C++?
2024-09-17 03:34:11.269 | WARNING | ttnn.decorators:operation_decorator:768 - Should ttnn.Conv1d be migrated to C++?
[2]:
torch.manual_seed(0)
device_params = {"l1_small_size": 24576}
device = ttnn.CreateDevice(device_id=0, **device_params)
Device | INFO | Opening user mode device driver
2024-09-17 03:34:11.310 | INFO | SiliconDriver - Detected 1 PCI device : [0]
2024-09-17 03:34:11.324 | WARNING | SiliconDriver - init_detect_tt_device_numanodes(): Could not determine NumaNodeSet for TT device (physical_device_id: 0 pci_bus_id: 0000:07:00.0)
2024-09-17 03:34:11.324 | WARNING | SiliconDriver - Could not find NumaNodeSet for TT Device (physical_device_id: 0 pci_bus_id: 0000:07:00.0)
2024-09-17 03:34:11.325 | WARNING | SiliconDriver - bind_area_memory_nodeset(): Unable to determine TT Device to NumaNode mapping for physical_device_id: 0. Skipping membind.
---- ttSiliconDevice::init_hugepage: bind_area_to_memory_nodeset() failed (physical_device_id: 0 ch: 0). Hugepage allocation is not on NumaNode matching TT Device. Side-Effect is decreased Device->Host perf (Issue #893).
2024-09-17 03:34:11.363 | INFO | SiliconDriver - Software version 6.0.0, Ethernet FW version 6.9.0 (Device 0)
Metal | INFO | Initializing device 0. Program cache is NOT enabled
Metal | INFO | AI CLK for device 0 is: 1000 MHz
Torch Module (from torchvision)
[3]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> torch.nn.Conv2d:
"""3x3 convolution with padding"""
return torch.nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)
class TorchBasicBlock(torch.nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample=None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer=None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = torch.nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = torch.nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
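With stride=1 and no downsample layer the block preserves the input shape, which is what makes the out += identity skip connection valid. A minimal sanity check of this in plain torch (a sketch; the block and input sizes match the cell below):

with torch.no_grad():
    block = TorchBasicBlock(inplanes=64, planes=64, stride=1).eval()
    y = block(torch.rand(1, 64, 56, 56))
print(y.shape)  # torch.Size([1, 64, 56, 56]) -- same shape as the input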
Create the torch module and preprocess it to get the ttnn parameters
[4]:
torch_model = TorchBasicBlock(inplanes=64, planes=64, stride=1)
torch_input_tensor = torch.rand((8, 64, 56, 56), dtype=torch.float32)
state_dict = torch_model.state_dict()
def create_custom_preprocessor(device):
def custom_preprocessor(torch_model, name, ttnn_module_args):
parameters = {}
conv_weight_1, conv_bias_1 = fold_batch_norm2d_into_conv2d(torch_model.conv1, torch_model.bn1)
parameters["conv1"] = {}
parameters["conv2"] = {}
parameters["conv1"]["weight"] = ttnn.from_torch(conv_weight_1, dtype=ttnn.bfloat16)
parameters["conv1"]["bias"] = ttnn.from_torch(torch.reshape(conv_bias_1, (1, 1, 1, -1)), dtype=ttnn.bfloat16)
conv_weight_2, conv_bias_2 = fold_batch_norm2d_into_conv2d(torch_model.conv2, torch_model.bn2)
parameters["conv2"]["weight"] = ttnn.from_torch(conv_weight_2, dtype=ttnn.bfloat16)
parameters["conv2"]["bias"] = ttnn.from_torch(torch.reshape(conv_bias_2, (1, 1, 1, -1)), dtype=ttnn.bfloat16)
return parameters
return custom_preprocessor
parameters = preprocess_model_parameters(
initialize_model=lambda: torch_model, custom_preprocessor=create_custom_preprocessor(device), device=None
)
2024-09-17 03:34:12.682 | DEBUG | ttnn:manage_config:90 - Set ttnn.CONFIG.enable_logging to False
2024-09-17 03:34:12.683 | DEBUG | ttnn:manage_config:90 - Set ttnn.CONFIG.enable_comparison_mode to False
2024-09-17 03:34:12.684 | WARNING | ttnn.model_preprocessing:from_torch:499 - ttnn: model cache can be enabled by passing model_name argument to preprocess_model[_parameters] and setting env variable TTNN_CONFIG_OVERRIDES='{"enable_model_cache": true}'
2024-09-17 03:34:12.684 | WARNING | ttnn.model_preprocessing:_initialize_model_and_preprocess_parameters:449 - Putting the model in eval mode
2024-09-17 03:34:12.717 | DEBUG | ttnn:manage_config:93 - Restored ttnn.CONFIG.enable_comparison_mode to False
2024-09-17 03:34:12.718 | DEBUG | ttnn:manage_config:93 - Restored ttnn.CONFIG.enable_logging to False
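fold_batch_norm2d_into_conv2d absorbs the batch-norm scale and shift into the convolution weights and bias, so the ttnn module below needs no explicit batch-norm ops. The folding itself is a per-output-channel rescale; a minimal reference sketch of the math (assuming an eval-mode BatchNorm2d with running statistics, not the ttnn helper itself):

def fold_bn_into_conv_reference(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d):
    # scale = gamma / sqrt(running_var + eps), one value per output channel
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    bias = (bias - bn.running_mean) * scale + bn.bias
    return weight, bias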
Display the parameters of the module
[5]:
parameters
[5]:
{
conv1: {
weight: ttnn.Tensor(shape=ttnn.Shape([64, 64, 3, 3]), layout=Layout.ROW_MAJOR, dtype=DataType.BFLOAT16),
bias: ttnn.Tensor(shape=ttnn.Shape([1, 1, 1, 64]), layout=Layout.ROW_MAJOR, dtype=DataType.BFLOAT16)
},
conv2: {
weight: ttnn.Tensor(shape=ttnn.Shape([64, 64, 3, 3]), layout=Layout.ROW_MAJOR, dtype=DataType.BFLOAT16),
bias: ttnn.Tensor(shape=ttnn.Shape([1, 1, 1, 64]), layout=Layout.ROW_MAJOR, dtype=DataType.BFLOAT16)
}
}
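The preprocessed parameters are host-side ttnn tensors (note device=None in the preprocessing call above), so they can be pulled back into torch at any point for inspection, for example:

print(ttnn.to_torch(parameters["conv1"]["weight"]).shape)  # torch.Size([64, 64, 3, 3])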
Display the traced torch graph
[6]:
from IPython.display import SVG
SVG('/tmp/ttnn/model_resnet_block_graph.svg')
[6]:
Implement the ttnn version of the module, passing the parameters into the constructor.
[7]:
class Conv:
def __init__(
self,
conv_params,
input_shape,
parameters,
*,
act_block_h=None,
reshard=False,
deallocate=True,
height_sharding=True,
activation="",
groups=1,
dtype=ttnn.bfloat16,
) -> None:
self.weights = parameters["weight"]
if "bias" in parameters:
self.bias = parameters["bias"]
else:
self.bias = None
self.kernel_size = (self.weights.shape[2], self.weights.shape[3])
self.conv_params = conv_params
self.out_channels = self.weights.shape[0]
self.act_block_h = act_block_h
self.reshard = reshard
self.deallocate = deallocate
self.activation = activation
self.groups = groups
self.dtype = dtype
self.shard_layout = (
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
)
self.input_shape = input_shape
def __call__(self, device, input_tensor):
conv_config = ttnn.Conv2dConfig(
dtype=self.dtype,
weights_dtype=ttnn.bfloat16,
math_fidelity=ttnn.MathFidelity.LoFi,
activation=self.activation,
shard_layout=self.shard_layout,
fp32_dest_acc_enabled=False,
packer_l1_accum_enabled=False,
input_channels_alignment=16 if self.input_shape[3] < 16 else 32,
deallocate_activation=self.deallocate,
)
if self.act_block_h is not None:
conv_config.act_block_h_override = self.act_block_h
[output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.weights,
bias_tensor=self.bias,
in_channels=self.input_shape[3],
out_channels=self.out_channels,
device=device,
kernel_size=self.kernel_size,
stride=(self.conv_params[0], self.conv_params[1]),
padding=(self.conv_params[2], self.conv_params[3]),
batch_size=self.input_shape[0],
input_height=self.input_shape[1],
input_width=self.input_shape[2],
conv_config=conv_config,
groups=self.groups,
return_output_size=True,
return_prepared_device_weights=True
)
return output_tensor
class TTNNBasicBlock:
def __init__(
self,
parameters,
) -> None:
self.conv1 = Conv([1, 1, 1, 1], [8, 56, 56, 64], parameters=parameters["conv1"])
self.conv2 = Conv([1, 1, 1, 1], [8, 56, 56, 64], parameters=parameters["conv2"])
if "downsample" in parameters:
self.downsample = parameters.downsample
else:
self.downsample = None
def __call__(self, x, device):
identity = x
out = self.conv1(device, x)
out = ttnn.relu(out)
out = self.conv2(device, out)
if self.downsample is not None:
identity = self.downsample(x)
identity = ttnn.reshape(identity, out.shape)
identity = ttnn.to_memory_config(
identity,
memory_config=ttnn.get_memory_config(out),
dtype=ttnn.bfloat16,
)
out = ttnn.add(out, identity, memory_config=ttnn.get_memory_config(out))
out = ttnn.relu(out)
return out
def run_model(model, torch_input_tensor, device):
input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1))
input_tensor = ttnn.from_torch(input_tensor, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device)
output_tensor = model(input_tensor, device)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)
output_tensor = torch.permute(output_tensor, (0, 3, 1, 2))
output_tensor = torch.reshape(output_tensor, torch_input_tensor.shape)
output_tensor = output_tensor.to(torch_input_tensor.dtype)
return output_tensor
Run the ttnn module (graph tracing and visualization are disabled pending Issue 12638)
[8]:
ttnn_model = TTNNBasicBlock(parameters)
# with ttnn.tracer.trace():  # disabled pending Issue 12638
output_tensor = run_model(ttnn_model, torch_input_tensor, device=device)
# ttnn.tracer.visualize(output_tensor)
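Since the batch norms were folded into the convolutions, the ttnn output can be checked directly against the torch module run in eval mode. A sketch of such a comparison, assuming ttnn.pearson_correlation_coefficient(expected, actual) returns a scalar PCC as in the other ttnn tutorials:

torch_model.eval()
with torch.no_grad():
    torch_output_tensor = torch_model(torch_input_tensor)
# A PCC close to 1.0 means the bfloat16 ttnn block matches the fp32 torch reference
print(ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor))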
[9]:
ttnn.close_device(device)
Metal | INFO | Closing device 0
Metal | INFO | Disabling and clearing program cache on device 0