Converting torch Model to ttnn

Note

This particular example only works on Grayskull.

Converted models may not be functional on all Tenstorrent hardware (Grayskull, Wormhole, or others); functionality is determined on a case-by-case basis.

There are many ways to convert a torch model to ttnn.

This is the recommended approach:
  1. Rewrite the torch model using functional torch APIs

  2. Convert the operations of the functional torch model to ttnn operations

  3. Optimize the functional ttnn model

Step 1 - Rewriting the Model

Given a torch model, it can be rewritten using functional torch APIs.

For example, given the following torch model:

# From transformers.models.bert.modeling_bert.BertIntermediate

import torch

class BertIntermediate(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = torch.nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.nn.functional.gelu(hidden_states)
        return hidden_states

Following TDD, the first step is to write a test for the model:

import pytest
import torch
import transformers

import ttnn
import torch_bert

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(model_name, batch_size, sequence_size):
    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
    torch_output = model(torch_hidden_states) # Golden output

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model, # Function to initialize the model
        convert_to_ttnn=lambda *_: False, # Keep the weights as torch tensors
    )

    output = torch_bert.bert_intermediate(
        torch_hidden_states,
        parameters=parameters,
    )

    assert_with_pcc(torch_output, output, 0.9999)
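
assert_with_pcc compares the two outputs using the Pearson correlation coefficient (PCC) and fails if it falls below the given threshold (0.9999 here). Conceptually, the comparison looks something like this sketch (not the utility's exact implementation):

import torch

def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation coefficient of the flattened tensors
    stacked = torch.stack([expected.flatten(), actual.flatten()])
    return torch.corrcoef(stacked)[0, 1].item()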

And finally, the model can be rewritten using functional torch APIs to make the test pass:

# torch_bert.py

import torch

def bert_intermediate(hidden_states, *, parameters):
    hidden_states = hidden_states @ parameters.dense.weight
    hidden_states = hidden_states + parameters.dense.bias
    hidden_states = torch.nn.functional.gelu(hidden_states)
    return hidden_states
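
Note that torch.nn.Linear stores its weight with shape (out_features, in_features), so a plain x @ weight would need a transpose; the rewrite above assumes preprocess_model_parameters hands back the weight already transposed for right-multiplication. A quick pure-torch sanity check of the underlying identity:

import torch

linear = torch.nn.Linear(4, 8)
x = torch.randn(2, 4)

# torch.nn.Linear computes x @ weight.T + bias
expected = torch.nn.functional.gelu(linear(x))
actual = torch.nn.functional.gelu(x @ linear.weight.T + linear.bias)
assert torch.allclose(expected, actual)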

Note

parameters is a dictionary that also exposes its keys as attributes, so both parameters["dense"]["weight"] and parameters.dense.weight are valid.

The structure of parameters mirrors the structure of the model class. In this case, BertIntermediate has a single attribute dense, so parameters has a single attribute dense. Because dense is a torch.nn.Linear module, it in turn has two attributes: weight and bias.
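
For illustration, a minimal sketch of such a dot-accessible dictionary (hypothetical; ttnn's actual parameter container is more featureful):

class DotDict(dict):
    # A dict whose keys are also readable as attributes (illustrative only)
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

parameters = DotDict(dense=DotDict(weight="W", bias="b"))
assert parameters["dense"]["weight"] == parameters.dense.weight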

Step 2 - Switching to ttnn Operations

Starting off with the test:

import pytest
import torch
import transformers

import ttnn
import ttnn_bert

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(device, model_name, batch_size, sequence_size):
    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
    torch_output = model(torch_hidden_states)

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,
        device=device, # Device to put the parameters on
    )

    hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    output = ttnn_bert.bert_intermediate(
        hidden_states,
        parameters=parameters,
    )
    output = ttnn.to_torch(output)

    assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.999)
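
Unlike the Step 1 test, this test takes a device argument: a pytest fixture provided by the tt-metal repository's conftest. When experimenting outside the repo, a minimal equivalent fixture could look like this (a sketch; the repo's own fixture handles additional configuration):

import pytest
import ttnn

@pytest.fixture
def device():
    device = ttnn.open_device(device_id=0)
    yield device
    ttnn.close_device(device)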

Then implementing the function using ttnn operations:

# ttnn_bert.py

import ttnn

def bert_intermediate(
    hidden_states,
    *,
    parameters,
):
    output = hidden_states @ parameters.dense.weight
    output = output + parameters.dense.bias
    output = ttnn.gelu(output)
    return output
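
On ttnn tensors, the @ and + infix operators dispatch to ttnn operations; the same function can also spell them out explicitly (an equivalent sketch, assuming the infix operators map to ttnn.matmul and ttnn.add):

import ttnn

def bert_intermediate(hidden_states, *, parameters):
    output = ttnn.matmul(hidden_states, parameters.dense.weight)
    output = ttnn.add(output, parameters.dense.bias)
    output = ttnn.gelu(output)
    return output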

Step 3 - Optimizing the Model

Starting off with the test:

import pytest
import torch
import transformers

import ttnn
import ttnn_bert

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc

@pytest.mark.parametrize("model_name", ["phiyodr/bert-large-finetuned-squad2"])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("sequence_size", [384])
def test_bert_intermediate(device, model_name, batch_size, sequence_size):
    torch.manual_seed(0)

    config = transformers.BertConfig.from_pretrained(model_name)
    model = transformers.models.bert.modeling_bert.BertIntermediate(config).eval()

    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
    torch_output = model(torch_hidden_states)

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,
        device=device, # Device to put the parameters on
        custom_preprocessor=ttnn_bert.custom_preprocessor, # Use custom_preprocessor to set ttnn.bfloat8_b data type for the weights and biases
    )

    hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    output = ttnn_bert.bert_intermediate(
        hidden_states,
        parameters=parameters,
    )
    output = ttnn.to_torch(output)

    assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.999)

And the optimized model can be something like this:

# ttnn_optimized_bert.py

import ttnn
import transformers

def custom_preprocessor(model, name):
    parameters = {}
    if isinstance(model, transformers.models.bert.modeling_bert.BertIntermediate):
        # Convert the dense layer's weight and bias to ttnn.bfloat8_b tensors
        # laid out for ttnn.linear; nest them under "dense" so that
        # parameters.dense.weight resolves as in the previous steps
        parameters["dense"] = {}
        parameters["dense"]["weight"] = ttnn.model_preprocessing.preprocess_linear_weight(
            model.dense.weight, dtype=ttnn.bfloat8_b
        )
        parameters["dense"]["bias"] = ttnn.model_preprocessing.preprocess_linear_bias(
            model.dense.bias, dtype=ttnn.bfloat8_b
        )

    return parameters

def bert_intermediate(
    hidden_states,
    *,
    parameters,
    num_cores_x=12,
):
    batch_size, *_ = hidden_states.shape

    output = ttnn.linear(
        hidden_states,
        parameters.dense.weight,
        bias=parameters.dense.bias,
        memory_config=ttnn.L1_MEMORY_CONFIG, # Put the output into local core (L1) memory
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), # Specify a manual core grid to get the best possible performance
        activation="gelu", # Fuse the gelu activation into the matmul
    )
    return output
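
Fusing gelu into ttnn.linear avoids a separate elementwise pass over the output, and ttnn.L1_MEMORY_CONFIG keeps the result in on-core L1 memory rather than DRAM. Usage mirrors Step 2, with an optional num_cores_x keyword (defaulting to 12 in the sketch above) to tune the core grid:

# hidden_states and parameters prepared as in the test above
output = bert_intermediate(hidden_states, parameters=parameters)
output = bert_intermediate(hidden_states, parameters=parameters, num_cores_x=8)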

More examples

Additional examples can be found in the integration tests.