Converting PyTorch Model to TT-NN

Note

This particular example only works on Grayskull.

Not all converted models are guaranteed to be functional on every Tenstorrent device (Grayskull, Wormhole, or others); functionality is determined on a case-by-case basis.

There are many ways to convert a torch model to ttnn.

This is the recommended approach:
  1. Rewrite the torch model using functional torch APIs

  2. Convert the operations of the functional torch model to ttnn operations

  3. Optimize the functional ttnn model

Step 1 - Rewriting the Model

A torch model can be rewritten using functional torch APIs, i.e. as plain functions that receive their parameters explicitly instead of holding them as module state.

For example, given the following torch model:

# From transformers.models.bert.modeling_bert.BertIntermediate

import torch

class BertIntermediate(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = torch.nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.nn.functional.gelu(hidden_states)
        return hidden_states

Following test-driven development (TDD), the first step is to write a test for the functional model:

import pytest
import torch
import transformers

import ttnn
import torch_bert

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(model_name, batch_size, sequence_size):
    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
    torch_output = model(torch_hidden_states) # Golden output

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model, # Function to initialize the model
        convert_to_ttnn=lambda *_: False, # Keep the weights as torch tensors
    )

    output = torch_bert.bert_intermediate(
        torch_hidden_states,
        parameters=parameters,
    )

    assert_with_pcc(torch_output, output, 0.9999)
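
assert_with_pcc compares the golden torch output against the functional implementation using the Pearson correlation coefficient (PCC) and fails the test if it falls below the threshold. A minimal sketch of the idea, for illustration only (the real helper lives in tests/ttnn/utils_for_testing.py):

import torch

def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation of the flattened tensors; 1.0 means a perfect match
    stacked = torch.stack([expected.flatten().double(), actual.flatten().double()])
    return torch.corrcoef(stacked)[0, 1].item()

def assert_with_pcc_sketch(expected, actual, threshold):
    # Hypothetical stand-in for the real assert_with_pcc
    assert pcc(expected, actual) >= threshold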

And finally, the model can be rewritten using functional torch APIs to make the test pass:

# torch_bert.py

import torch

def bert_intermediate(hidden_states, *, parameters):
    hidden_states = hidden_states @ parameters.dense.weight
    hidden_states = hidden_states + parameters.dense.bias
    hidden_states = torch.nn.functional.gelu(hidden_states)
    return hidden_states

Note

parameters is a dictionary-like object that exposes its keys as attributes, so both parameters["dense"]["weight"] and parameters.dense.weight are valid.

The structure of parameters follows the structure of the model class. In this case, BertIntermediate has a single attribute dense, so parameters has a single attribute dense. And dense is a torch.nn.Linear object, so it in turn has two attributes: weight and bias. Note that torch.nn.Linear stores its weight as (out_features, in_features), so preprocess_model_parameters stores linear weights pre-transposed; this is why hidden_states @ parameters.dense.weight works without an explicit transpose.
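
A minimal sketch of such a dictionary, purely for illustration (ttnn provides its own implementation; AttrDict is a hypothetical name):

class AttrDict(dict):
    # Expose dictionary keys as attributes, mirroring how parameters behaves
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as error:
            raise AttributeError(name) from error

parameters = AttrDict(dense=AttrDict(weight=..., bias=...))
assert parameters["dense"]["weight"] is parameters.dense.weight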

Step 2 - Switching to ttnn Operations

Starting off with the test. Note that the test now takes a device argument, which is supplied by a pytest fixture from ttnn's test infrastructure:

import pytest
import torch
import transformers

import ttnn
import ttnn_bert

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(device, model_name, batch_size, sequence_size):
    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
    torch_output = model(torch_hidden_states)

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,
        device=device, # Device to put the parameters on
    )

    hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    output = ttnn_bert.bert_intermediate(
        hidden_states,
        parameters=parameters,
    )
    output = ttnn.to_torch(output)

    assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.999)

Then the function is implemented using ttnn operations. Note that the PCC threshold is relaxed from 0.9999 to 0.999 because the computation now runs in bfloat16 on the device, which introduces additional rounding error:

# ttnn_bert.py

import ttnn

def bert_intermediate(
    hidden_states,
    *,
    parameters,
):
    output = hidden_states @ parameters.dense.weight
    output = output + parameters.dense.bias
    output = ttnn.gelu(output)
    return output
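
On ttnn tensors, the @ and + operators above dispatch to ttnn operations. For illustration, the same function written with the explicit calls (a sketch of what the overloads resolve to; bert_intermediate_explicit is a hypothetical name):

def bert_intermediate_explicit(hidden_states, *, parameters):
    # Equivalent to the operator-overload version above
    output = ttnn.matmul(hidden_states, parameters.dense.weight)
    output = ttnn.add(output, parameters.dense.bias)
    output = ttnn.gelu(output)
    return output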

Step 3 - Optimizing the Model

Starting off with the test, which differs from the Step 2 test only in passing a custom_preprocessor:

import pytest
import torch
import transformers

import ttnn
import ttnn_optimized_bert

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(device, model_name, batch_size, sequence_size):
    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
    torch_output = model(torch_hidden_states)

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,
        device=device, # Device to put the parameters on
        custom_preprocessor=ttnn_optimized_bert.custom_preprocessor, # Use custom_preprocessor to set ttnn.bfloat8_b data type for the weights and biases
    )

    hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    output = ttnn_optimized_bert.bert_intermediate(
        hidden_states,
        parameters=parameters,
    )
    output = ttnn.to_torch(output)

    assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.999)
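
preprocess_model_parameters also accepts an optional model_name argument which, if set, caches the preprocessed parameters so later runs can skip the conversion. A hedged sketch, assuming any unique string works as the cache key:

parameters = preprocess_model_parameters(
    model_name="ttnn_bert_intermediate", # cache key (assumed); subsequent runs reuse the cached parameters
    initialize_model=lambda: model,
    device=device,
    custom_preprocessor=ttnn_optimized_bert.custom_preprocessor,
)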

And the optimized model can be something like this:

# ttnn_optimized_bert.py

import ttnn
import transformers

def custom_preprocessor(model, name):
    parameters = {}
    if isinstance(model, transformers.models.bert.modeling_bert.BertIntermediate):
        # BertIntermediate holds its linear layer in the dense attribute,
        # so mirror that structure in the returned parameters
        parameters["dense"] = {}
        parameters["dense"]["weight"] = ttnn.model_preprocessing.preprocess_linear_weight(model.dense.weight, dtype=ttnn.bfloat8_b)
        parameters["dense"]["bias"] = ttnn.model_preprocessing.preprocess_linear_bias(model.dense.bias, dtype=ttnn.bfloat8_b)

    return parameters

def bert_intermediate(
    hidden_states,
    *,
    parameters,
):
    batch_size, *_ = hidden_states.shape

    num_cores_x = 12
    output = ttnn.linear(
        hidden_states,
        parameters.dense.weight,
        bias=parameters.dense.bias,
        memory_config=ttnn.L1_MEMORY_CONFIG, # Put the output into local L1 core memory
        core_grid=(batch_size, num_cores_x), # Specify a manual core grid to get the best possible performance
        activation="gelu", # Fuse the gelu activation into the linear operation
    )
    return output
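
The key optimization is that ttnn.linear fuses the matrix multiplication, bias add, and gelu activation that Step 2 ran as three separate operations into a single one, so the intermediate results never need to be materialized as separate tensors. Side by side:

# Step 2: three separate device operations
output = hidden_states @ parameters.dense.weight
output = output + parameters.dense.bias
output = ttnn.gelu(output)

# Step 3: one fused operation (matmul + bias + gelu)
output = ttnn.linear(
    hidden_states,
    parameters.dense.weight,
    bias=parameters.dense.bias,
    activation="gelu",
)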

More examples

Additional examples can be found in the integration tests.