Mixed Precision (Per-Tensor Weight Dtype Overrides)

When uniform weight conversion (e.g. experimental_weight_dtype: "bfp_bf8") causes accuracy degradation in specific layers, you can specify dtype overrides on a per-tensor basis. This lets you keep sensitive layers at higher precision (e.g. bf16) while converting the rest to a lower format (e.g. bfp_bf8 or bfp_bf4).

Note: Currently, only matmul/linear layer weight overrides are propagated and respected. Lower-precision convolution weights are not yet supported by the compiler.

Method 1: Python dict

Build a dict mapping parameter names (or glob patterns) to target dtypes and call apply_weight_dtype_overrides():

from tt_torch import apply_weight_dtype_overrides

# Override specific weights by name.
apply_weight_dtype_overrides(model, {
    "fc2.weight": "bfp_bf8",
})

# Or use glob patterns to target groups of layers.
apply_weight_dtype_overrides(model, {
    "model.layers.*.mlp.gate_proj.weight": "bfp_bf4",
    "model.layers.*.mlp.up_proj.weight": "bfp_bf4",
    "model.layers.0.self_attn.q_proj.weight": "bf16",
})

# A "default" key applies to all weights, with specific overrides taking precedence.
apply_weight_dtype_overrides(model, {
    "default": "bfp_bf8",
    "model.layers.0.self_attn.q_proj.weight": "bf16",
})

Call this after creating the model and before torch.compile. See examples/pytorch/mnist_performant.py for a complete working example that lowers the last linear layer weight to bfp_bf8.
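To make the precedence rule concrete, the sketch below resolves a dtype for each parameter name in plain Python. It is an illustration only, not the tt_torch implementation. resolve_dtype is a hypothetical helper, and it assumes fnmatch-style globbing, that an exact name beats a glob pattern, and that "default" is the fallback. The text above only states that specific overrides take precedence over "default".

```python
from fnmatch import fnmatchcase


def resolve_dtype(param_name, overrides):
    """Hypothetical helper: pick the dtype string for one parameter name.

    Assumed order: exact name, then first matching glob pattern, then "default".
    Returns None when nothing applies (the weight would be left untouched).
    """
    # An exact parameter name is treated as the most specific override.
    if param_name != "default" and param_name in overrides:
        return overrides[param_name]
    # Otherwise take the first glob pattern that matches, in dict order.
    for pattern, dtype in overrides.items():
        if pattern != "default" and fnmatchcase(param_name, pattern):
            return dtype
    # Fall back to the "default" entry, if one was given.
    return overrides.get("default")


overrides = {
    "default": "bfp_bf8",
    "model.layers.*.mlp.gate_proj.weight": "bfp_bf4",
    "model.layers.0.self_attn.q_proj.weight": "bf16",
}

for name in (
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.7.mlp.gate_proj.weight",
    "model.layers.7.self_attn.k_proj.weight",
):
    print(name, "->", resolve_dtype(name, overrides))
```

Under these assumptions the three names resolve to bf16, bfp_bf4, and bfp_bf8 respectively. If two glob patterns could match the same weight, the real library's tie-breaking may differ from the dict-order rule used here.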

Method 2: JSON config + CLI

For large models with hundreds of weight parameters, use the tt-gen-weight-template CLI to generate a JSON template, then edit it:

tt-gen-weight-template --loader third_party/tt_forge_models/llama/causal_lm/pytorch/loader.py

CLI options:

| Option | Description |
|---|---|
| --loader | (Required) Path to a model loader.py file |
| --variant | Variant enum name (e.g. LLAMA_3_1_8B). Defaults to the loader's DEFAULT_VARIANT |
| --list-variants | List available variants and exit |
| --default-dtype | Default dtype for all entries: bfp_bf8 (default), bfp_bf4, or bf16 |
| --output-dir | Override output directory (default: mixed_precision_configs/ next to the loader) |
| --auto-class | transformers Auto* class to use (default: AutoModelForCausalLM) |

The output is a JSON file that maps every weight parameter to a dtype string. Edit it to fine-tune per-layer dtypes; as the example shows, glob patterns can be used as keys:

{
    "model.layers.*.mlp.gate_proj.weight": "bfp_bf4",
    "model.layers.*.mlp.up_proj.weight": "bfp_bf4",
    "model.layers.0.self_attn.q_proj.weight": "bf16"
}

Then pass the JSON file path to apply_weight_dtype_overrides():

apply_weight_dtype_overrides(model, "path/to/config.json")
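For large templates, hand-editing can be replaced by a short script. The sketch below is an example under stated assumptions: it writes a tiny stand-in template to a temporary directory (a real one would come from tt-gen-weight-template), and promote is a hypothetical helper that sets every entry matching a glob pattern to a given dtype. Only the Python standard library is used.

```python
import json
import tempfile
from fnmatch import fnmatchcase
from pathlib import Path


def promote(config_path, pattern, dtype):
    """Hypothetical helper: set every entry whose key matches `pattern` to `dtype`.

    Rewrites the JSON file in place and returns the number of entries changed.
    """
    path = Path(config_path)
    config = json.loads(path.read_text())
    changed = 0
    for name in config:
        if fnmatchcase(name, pattern) and config[name] != dtype:
            config[name] = dtype
            changed += 1
    path.write_text(json.dumps(config, indent=4) + "\n")
    return changed


# Stand-in for a generated template; the parameter names are illustrative.
template = {
    "model.layers.0.self_attn.q_proj.weight": "bfp_bf8",
    "model.layers.0.self_attn.k_proj.weight": "bfp_bf8",
    "model.layers.0.mlp.gate_proj.weight": "bfp_bf8",
    "model.layers.1.self_attn.q_proj.weight": "bfp_bf8",
}
config_path = Path(tempfile.mkdtemp()) / "config.json"
config_path.write_text(json.dumps(template, indent=4))

# Keep the first layer's attention projections at higher precision.
n_changed = promote(config_path, "model.layers.0.self_attn.*", "bf16")
result = json.loads(config_path.read_text())
print(n_changed, "entries changed")
```

The edited file can then be passed to apply_weight_dtype_overrides() by path, as shown above.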

Auto-discovery for tests and benchmarks: JSON configs placed in mixed_precision_configs/ next to the model's loader.py are automatically discovered by the model test runner (tests/runner/test_models.py) and LLM benchmarks (tests/benchmark/benchmarks/llm_benchmark.py).

In-model annotation

If you control the model code, you can annotate weights directly in the forward pass using torch.ops.tt.weight_dtype_override:

def forward(self, x):
    w = torch.ops.tt.weight_dtype_override(self.fc.weight, "bfp_bf8")
    return torch.matmul(x, w)

This is useful for custom models or when you need dtype overrides to interact with other operations (e.g. tensor-parallel sharding). In practice this is rarely needed, because the dict and JSON methods above cover most use cases.

How it works

Overrides are applied transparently via torch.nn.utils.parametrize, so there is no need to edit model forward functions or manually insert custom ops (unless you use in-model annotation). The apply_weight_dtype_overrides() function registers a parametrization on each matched weight that injects a torch.ops.tt.weight_dtype_override call. During compilation, a C++ frontend pass extracts these annotations and propagates them as per-argument attributes for the tt-mlir weight dtype conversion pass.

Note: If apply_weight_dtype_overrides() is called multiple times on the same model (e.g. first with a dict, then with a JSON config), the first call has priority for any given weight: already-parametrized weights are not overridden by subsequent calls.
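The first-call-wins rule can be modeled in a few lines. This sketch is not the library's code: merge_first_wins is a hypothetical helper that takes already-resolved name-to-dtype mappings in call order and reports the dtype each weight would end up with, assuming (per the note above) that a weight parametrized by an earlier call is skipped by later ones.

```python
def merge_first_wins(*calls):
    """Hypothetical model of repeated override calls.

    Each argument is a dict of fully resolved parameter names to dtype strings,
    given in the order the calls were made. Earlier entries are never replaced.
    """
    effective = {}
    for resolved in calls:
        for name, dtype in resolved.items():
            # setdefault keeps the existing value if the name was seen before.
            effective.setdefault(name, dtype)
    return effective


first_call = {"fc1.weight": "bf16"}
second_call = {"fc1.weight": "bfp_bf8", "fc2.weight": "bfp_bf8"}

effective = merge_first_wins(first_call, second_call)
print(effective)  # {'fc1.weight': 'bf16', 'fc2.weight': 'bfp_bf8'}
```

Under this model, applying the most specific overrides first and broader ones afterwards gives the specific entries priority.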