Test Infra
The test infra consists of the main “tester” classes plus a few helper classes. Its main goal is to make writing tests easy.
Here is a brief class diagram of the infra:

Op and Graph Tests
OpTester exposes two easy-to-use functions:
run_op_test(...)
run_op_test_with_random_inputs(...)
They wrap the instantiation of OpTester and hide all the underlying complexity. You only need to
pass the op (a Python function) you want to test to one of these functions, like this:
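```python
import jax
import jax.numpy as jnp
# run_op_test_with_random_inputs is provided by the test infra.

def test_add(x_shape: tuple, y_shape: tuple):
    def add(x: jax.Array, y: jax.Array) -> jax.Array:
        return jnp.add(x, y)
    run_op_test_with_random_inputs(add, [x_shape, y_shape])
```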
That is all that is required.
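To make the layering concrete, here is a minimal, self-contained sketch of how a wrapper like run_op_test_with_random_inputs can hide tester setup. Everything in it is hypothetical. ToyOpTester, toy_run_op_test and toy_run_op_test_with_random_inputs are illustrative stand-ins, not the infra's classes. The comparison strategy is also an assumption, chosen only so the example runs without JAX: a float64 reference run is compared against a simulated float32 "device" run on flat Python lists.

```python
import math
import random
import struct
from typing import Callable, Sequence

# Flat row-major buffer standing in for jax.Array in this toy example.
Tensor = list[float]


def to_float32(values: Tensor) -> Tensor:
    """Round each value to the nearest float32, imitating lower device precision."""
    return [struct.unpack("f", struct.pack("f", v))[0] for v in values]


class ToyOpTester:
    """Hypothetical stand-in for OpTester: compares a reference run with a simulated device run."""

    def __init__(self, atol: float = 1e-5) -> None:
        self.atol = atol

    def test(self, op: Callable[..., Tensor], inputs: Sequence[Tensor]) -> None:
        # Reference: run the op on the original (float64) Python floats.
        expected = op(*inputs)
        # Simulated device: float32 inputs in, float32 outputs out.
        actual = to_float32(op(*[to_float32(x) for x in inputs]))
        assert len(actual) == len(expected), "output size mismatch"
        for a, e in zip(actual, expected):
            assert math.isclose(a, e, abs_tol=self.atol), f"mismatch: {a} vs {e}"


def toy_run_op_test(op: Callable[..., Tensor], inputs: Sequence[Tensor]) -> None:
    # The wrapper owns tester construction, so test code never touches it.
    ToyOpTester().test(op, inputs)


def toy_run_op_test_with_random_inputs(
    op: Callable[..., Tensor], shapes: Sequence[tuple[int, ...]], seed: int = 0
) -> None:
    # Generate one random flat buffer per input shape, then delegate.
    rng = random.Random(seed)
    inputs = [[rng.uniform(-1.0, 1.0) for _ in range(math.prod(s))] for s in shapes]
    toy_run_op_test(op, inputs)


# Usage, mirroring test_add above:
def test_add() -> None:
    def add(x: Tensor, y: Tensor) -> Tensor:
        return [a + b for a, b in zip(x, y)]

    toy_run_op_test_with_random_inputs(add, [(32, 32), (32, 32)])
```

Test code only sees the top-level call. Changing the tester's tolerance or comparison logic would not require touching any individual test.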
GraphTester is currently identical to OpTester and exposes the analogous functions
run_graph_test(...)
run_graph_test_with_random_inputs(...)
which are used the same way as their op-test counterparts.
Model Tests
To test a model, inherit from one of the *ModelTester classes and override its required methods.
See the docstring of the class you inherit from for details.
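One way to picture how a *ModelTester base class and its overrides fit together is the template-method pattern. In that pattern the base class owns test(), and subclasses only supply the hooks. The sketch below is a simplified, hypothetical ToyModelTester, not the actual JaxModelTester. It borrows the hook names _get_model, _get_forward_method_name and _get_forward_method_args from the JAX example in this section. Its test() only checks that the forward call returns something; it does not compare outputs.

```python
from abc import ABC, abstractmethod
from typing import Any, Sequence


class ToyModelTester(ABC):
    """Hypothetical, simplified base class showing the template-method pattern."""

    def __init__(self) -> None:
        # The base class builds the model once through the subclass hook.
        self._model = self._get_model()

    @abstractmethod
    def _get_model(self) -> Any:
        """Return the model instance under test."""

    @abstractmethod
    def _get_forward_method_name(self) -> str:
        """Return the name of the model method that runs the forward pass."""

    @abstractmethod
    def _get_forward_method_args(self) -> Sequence[Any]:
        """Return the positional arguments for the forward method."""

    def run_forward(self) -> Any:
        # Look the forward method up by name and call it with the hook-provided args.
        forward = getattr(self._model, self._get_forward_method_name())
        return forward(*self._get_forward_method_args())

    def test(self) -> None:
        # A real tester would compare outputs across backends; this only checks one is produced.
        assert self.run_forward() is not None


# Usage: a Flax-like toy model whose forward method is apply(params, x).
class ScaleModel:
    def apply(self, params: dict, x: list[float]) -> list[float]:
        return [params["scale"] * v for v in x]


class ScaleModelTester(ToyModelTester):
    def _get_model(self) -> ScaleModel:
        return ScaleModel()

    def _get_forward_method_name(self) -> str:
        return "apply"

    def _get_forward_method_args(self) -> Sequence[Any]:
        return [{"scale": 2.0}, [1.0, 2.0, 3.0]]


ScaleModelTester().test()
```

Suppose the real base class also builds the model in its constructor. That would explain why MNISTMLPTester assigns self._hidden_sizes before calling super().__init__().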
JAX Model Example
First, you define a model:
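```python
import flax.linen as nn
import jax


class MNISTMLPModel(nn.Module):
    hidden_sizes: tuple[int, ...]

    @nn.compact
    def __call__(self, x: jax.Array):
        # Flatten (B, H, W, C) images into (B, H*W*C) feature vectors.
        x = x.reshape((x.shape[0], -1))
        for h in self.hidden_sizes:
            x = nn.Dense(features=h)(x)
            x = nn.relu(x)
        x = nn.Dense(features=10)(x)  # 10 MNIST digit classes
        x = nn.softmax(x)
        return x
```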
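As a quick sanity check on model size, you can count MNISTMLPModel's parameters by hand. By default, each flax.linen Dense layer has a kernel of shape (in_features, features) plus a bias of shape (features,), and relu/softmax add no parameters. mlp_param_count is a hypothetical helper for this arithmetic, not part of the test infra.

```python
import math
from typing import Sequence


def mlp_param_count(
    input_shape: Sequence[int], hidden_sizes: Sequence[int], num_classes: int = 10
) -> int:
    """Count kernel + bias parameters of the Dense stack in MNISTMLPModel.

    input_shape excludes the batch dimension, e.g. (28, 28, 1).
    """
    in_features = math.prod(input_shape)  # x.reshape((batch, -1)) flattens H*W*C
    total = 0
    for out_features in [*hidden_sizes, num_classes]:
        # kernel (in_features, out_features) + bias (out_features,)
        total += in_features * out_features + out_features
        in_features = out_features
    return total
```

For the (256, 128, 64) configuration used in the test further down, mlp_param_count((28, 28, 1), (256, 128, 64)) returns 242762.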
Then you define a tester by inheriting from JaxModelTester:
```python
from collections.abc import Sequence

import flax.linen as nn
import jax

# JaxModelTester, ComparisonConfig and RunMode are provided by the test infra.


class MNISTMLPTester(JaxModelTester):
    def __init__(
        self,
        hidden_sizes: Sequence[int],
        comparison_config: ComparisonConfig = ComparisonConfig(),
        run_mode: RunMode = RunMode.INFERENCE,
    ) -> None:
        # Set before super().__init__() in case the base class calls _get_model() there.
        self._hidden_sizes = hidden_sizes
        super().__init__(comparison_config, run_mode)

    # @override
    def _get_model(self) -> nn.Module:
        return MNISTMLPModel(self._hidden_sizes)

    # @override
    def _get_forward_method_name(self) -> str:
        return "apply"

    # @override
    def _get_input_activations(self) -> Sequence[jax.Array]:
        key = jax.random.PRNGKey(37)
        # B, H, W, C; channels is 1 because MNIST is grayscale.
        img = jax.random.normal(key, (4, 28, 28, 1))
        return img

    # @override
    def _get_forward_method_args(self):
        inp = self._get_input_activations()
        parameters = self._model.init(jax.random.PRNGKey(42), inp)
        # Arguments for the forward method, i.e. apply(parameters, inp).
        return [parameters, inp]
```
Finally, you run the test:
```python
import pytest


@pytest.fixture
def inference_tester(request) -> MNISTMLPTester:
    # With indirect=True, each parametrized value arrives here as request.param.
    return MNISTMLPTester(request.param)


@pytest.mark.parametrize(
    "inference_tester", [(256, 128, 64)], indirect=True, ids=lambda val: f"{val}"
)
def test_mnist_mlp_inference(inference_tester: MNISTMLPTester):
    inference_tester.test()
```
Serialization and FileCheck
Serializing IR to Disk
To serialize compilation artifacts (MLIR, TTNN IRs) to disk, use the --serialize flag:
```shell
pytest path/to/test.py::test_name --serialize
```
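For serialization to work, your test must take pytest's request fixture and pass it to the infra:
For op/graph tests:
```python
import torch

def test_my_op(request):
    # MyOp is a placeholder for the op under test.
    run_op_test(MyOp(), [torch.randn(32, 32)], request=request)
```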
For model tests:
```python
def test_my_model(model_tester: MyModelTester, request):
    model_tester.test(request=request)
```
Artifacts are written to output_artifact/<sanitized_test_name>/.
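The exact sanitization rule behind <sanitized_test_name> isn't spelled out here. A practical way to see what a run produced is to find the most recently written file under output_artifact/ and list everything in its test directory. latest_artifact_files is a hypothetical convenience, not part of the infra. It assumes pytest was run from the directory that contains output_artifact/ and that each test writes into its own subdirectory.

```python
from pathlib import Path


def latest_artifact_files(root: str = "output_artifact") -> list[Path]:
    """Return all files in the per-test directory that holds the newest artifact."""
    root_path = Path(root)
    if not root_path.is_dir():
        return []
    # Only consider files inside a per-test subdirectory, not loose files in root.
    files = [
        p for p in root_path.rglob("*")
        if p.is_file() and len(p.relative_to(root_path).parts) > 1
    ]
    if not files:
        return []
    # File mtimes change on every rewrite, unlike directory mtimes.
    newest = max(files, key=lambda p: p.stat().st_mtime)
    test_dir = root_path / newest.relative_to(root_path).parts[0]
    return sorted(p for p in test_dir.rglob("*") if p.is_file())
```

This approach looks at file timestamps rather than directory timestamps. Overwriting an existing file updates the file's modification time but not its parent directory's, so re-running the same test still surfaces the right directory.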
Running FileCheck
To verify IR transformations, use the @pytest.mark.filecheck decorator:
```python
import pytest
import torch

@pytest.mark.filecheck(["add.ttnn.mlir", "matmul_fusion.ttir.mlir"])
def test_my_op(request):
    run_op_test(MyOp(), [torch.randn(32, 32)], request=request)
```
With this marker, the test automatically serializes its artifacts, runs FileCheck pattern matching on them, and fails on any mismatch.
For pattern file syntax and conventions, see tests/filecheck/filecheck.md.
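tests/filecheck/filecheck.md remains the reference for the pattern files themselves. To illustrate the core matching idea, here is a toy sketch of FileCheck's basic CHECK: directive, assuming the patterns follow LLVM FileCheck's conventions. Each CHECK: line must match somewhere in the IR after the end of the previous match. check_in_order is hypothetical. It does plain substring matching and ignores every other directive (CHECK-NEXT, CHECK-NOT, regex blocks, variables), so it is not a replacement for the real tool.

```python
def check_in_order(pattern_text: str, ir_text: str, prefix: str = "CHECK:") -> None:
    """Toy version of FileCheck's basic CHECK directive.

    Every line of pattern_text containing `prefix` contributes the text after it.
    Those strings must appear in ir_text in the same order, each starting after
    the end of the previous match. All other lines are ignored.
    """
    pos = 0
    for lineno, line in enumerate(pattern_text.splitlines(), start=1):
        marker = line.find(prefix)
        if marker == -1:
            continue  # not a CHECK: directive
        needle = line[marker + len(prefix):].strip()
        found = ir_text.find(needle, pos)
        if found == -1:
            raise AssertionError(
                f"pattern line {lineno}: {needle!r} not found after offset {pos}"
            )
        pos = found + len(needle)
```

For example, the patterns `// CHECK: func.func` followed by `// CHECK: return` pass against IR that contains both strings in that order, but fail if the two lines are swapped.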