Test Infra
The test infra consists of the main "tester" classes and a few helper classes. Its main goal is to make writing tests easy.
Here is a brief class diagram of the infra:

Op and Graph Tests
The op tester exposes easy-to-use functions:
run_op_test(...)
run_op_test_with_random_inputs(...)
They wrap the instantiation of the OpTester and all of the underlying complexity. Users just need to
pass the op (a Python function) they want to test to one of these functions, like this:
import jax
import jax.numpy as jnp


def test_add(x_shape: tuple, y_shape: tuple):
    def add(x: jax.Array, y: jax.Array) -> jax.Array:
        return jnp.add(x, y)

    run_op_test_with_random_inputs(add, [x_shape, y_shape])
That's all there is to it.
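The sketch below makes the idea of "wrapping the OpTester" concrete. It is self-contained, and every name in it is hypothetical: SketchOpTester, run_op_test_sketch and run_op_test_with_random_inputs_sketch are illustrative stand-ins, and inputs are flat Python lists rather than jax.Arrays. In the sketch, the "device" run defaults to re-running the op on the host. The real infra presumably compiles the op, runs it on the device and compares the result against a CPU reference.

```python
import math
import random
from typing import Callable, List, Optional, Sequence, Tuple


class SketchOpTester:
    """Hypothetical tester: runs an op on a reference path and a 'device' path, then compares."""

    def __init__(self, atol: float = 1e-5) -> None:
        self.atol = atol

    def test(self, op: Callable, inputs: Sequence[List[float]],
             device_op: Optional[Callable] = None) -> None:
        reference = op(*inputs)  # golden result computed on the host
        # Stand-in for compiling and running the op on the device.
        device_result = (device_op or op)(*inputs)
        assert len(reference) == len(device_result), "shape mismatch"
        for r, d in zip(reference, device_result):
            assert abs(r - d) <= self.atol, f"mismatch: {r} vs {d}"


def run_op_test_sketch(op: Callable, inputs: Sequence[List[float]],
                       device_op: Optional[Callable] = None) -> None:
    # The wrapper hides tester construction from the test author.
    SketchOpTester().test(op, inputs, device_op)


def run_op_test_with_random_inputs_sketch(op: Callable, shapes: Sequence[Tuple[int, ...]],
                                          device_op: Optional[Callable] = None,
                                          seed: int = 0) -> None:
    rng = random.Random(seed)
    # One flat list of uniform random values per requested shape.
    inputs = [[rng.uniform(-1.0, 1.0) for _ in range(math.prod(shape))]
              for shape in shapes]
    run_op_test_sketch(op, inputs, device_op)


def add(x: List[float], y: List[float]) -> List[float]:
    return [a + b for a, b in zip(x, y)]


run_op_test_with_random_inputs_sketch(add, [(2, 3), (2, 3)])  # passes silently
```

The convenience functions let a test supply only the op and its inputs (or input shapes). Constructing and configuring the tester stays in one place.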
GraphTester is currently identical to OpTester, and it likewise exposes
run_graph_test(...)
run_graph_test_with_random_inputs(...)
which are used in the same way as their op-test counterparts.
Model Tests
Models are tested by inheriting from one of the *ModelTester classes and overriding its required methods.
Please read the docstring of the class you want to inherit from for more information.
JAX Model Example
First, you define a model:
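import jax
import flax.linen as nn


class MNISTMLPModel(nn.Module):
    hidden_sizes: tuple[int, ...]

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        # Flatten (B, H, W, C) images to (B, H * W * C).
        x = x.reshape((x.shape[0], -1))
        for h in self.hidden_sizes:
            x = nn.Dense(features=h)(x)
            x = nn.relu(x)
        # One output per digit class.
        x = nn.Dense(features=10)(x)
        x = nn.softmax(x)
        return x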
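As a sanity check on what the model computes, you can work out the activation shapes and parameter count by hand. This small sketch assumes flax's nn.Dense defaults (a kernel plus a bias per layer) and the (4, 28, 28, 1) input that the tester below generates. mlp_shapes_and_params is a hypothetical helper and is not part of the infra.

```python
import math
from typing import List, Sequence, Tuple


def mlp_shapes_and_params(input_shape: Tuple[int, ...], hidden_sizes: Sequence[int],
                          num_classes: int = 10) -> Tuple[List[Tuple[int, int]], int]:
    """Trace shapes and Dense parameter counts through an MLP laid out like MNISTMLPModel."""
    batch = input_shape[0]
    features = math.prod(input_shape[1:])  # x.reshape((B, -1))
    shapes = [(batch, features)]
    params = 0
    for out in list(hidden_sizes) + [num_classes]:
        params += features * out + out  # kernel (in x out) plus bias (out)
        features = out
        shapes.append((batch, features))  # relu / softmax keep the shape
    return shapes, params


shapes, params = mlp_shapes_and_params((4, 28, 28, 1), (256, 128, 64))
print(shapes)  # [(4, 784), (4, 256), (4, 128), (4, 64), (4, 10)]
print(params)  # 242762
```

For hidden_sizes=(256, 128, 64), the softmax output therefore holds one probability per digit for each of the 4 images.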
Then you define a tester by inheriting from JaxModelTester:
from typing import Sequence

# JaxModelTester, ComparisonConfig and RunMode are provided by the test infra.


class MNISTMLPTester(JaxModelTester):
    def __init__(
        self,
        hidden_sizes: Sequence[int],
        comparison_config: ComparisonConfig = ComparisonConfig(),
        run_mode: RunMode = RunMode.INFERENCE,
    ) -> None:
        self._hidden_sizes = hidden_sizes
        super().__init__(comparison_config, run_mode)

    # @override
    def _get_model(self) -> nn.Module:
        return MNISTMLPModel(self._hidden_sizes)

    # @override
    def _get_forward_method_name(self) -> str:
        return "apply"

    # @override
    def _get_input_activations(self) -> Sequence[jax.Array]:
        key = jax.random.PRNGKey(37)
        # B, H, W, C; channels is 1 because MNIST is grayscale.
        img = jax.random.normal(key, (4, 28, 28, 1))
        return img

    # @override
    def _get_forward_method_args(self):
        inp = self._get_input_activations()
        parameters = self._model.init(jax.random.PRNGKey(42), inp)
        return [parameters, inp]
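Your code does not call the overridden hooks; the base class does. The sketch below shows one plausible way a base tester could drive them, using a template-method pattern. SketchModelTester, ScaleModel and ScaleTester are hypothetical and heavily simplified. The real JaxModelTester also handles comparison configs, run modes and device execution, and its internals may differ. In this sketch the base __init__ builds the model, which would explain why a subclass like MNISTMLPTester sets self._hidden_sizes before calling super().__init__().

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Sequence


class SketchModelTester(ABC):
    """Hypothetical, simplified base class that drives the four hooks."""

    def __init__(self) -> None:
        # The model is built once, up front, so later hooks can use self._model.
        self._model = self._get_model()

    @abstractmethod
    def _get_model(self) -> Any: ...

    @abstractmethod
    def _get_forward_method_name(self) -> str: ...

    @abstractmethod
    def _get_input_activations(self) -> Any: ...

    @abstractmethod
    def _get_forward_method_args(self) -> Sequence[Any]: ...

    def _run_forward(self) -> Any:
        # Look the forward method up by name (e.g. "apply") and call it.
        forward = getattr(self._model, self._get_forward_method_name())
        return forward(*self._get_forward_method_args())

    def test(self) -> Any:
        # A real tester would run this on CPU and on device and compare the
        # outputs; the sketch just runs the forward pass once.
        return self._run_forward()


class ScaleModel:
    # Mimics the flax-style model.apply(params, x) calling convention.
    def apply(self, params: Dict[str, float], x: List[float]) -> List[float]:
        return [params["scale"] * v for v in x]


class ScaleTester(SketchModelTester):
    def _get_model(self) -> ScaleModel:
        return ScaleModel()

    def _get_forward_method_name(self) -> str:
        return "apply"

    def _get_input_activations(self) -> List[float]:
        return [1.0, 2.0, 3.0]

    def _get_forward_method_args(self) -> List[Any]:
        return [{"scale": 2.0}, self._get_input_activations()]


print(ScaleTester().test())  # [2.0, 4.0, 6.0]
```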
Finally, you run the test:
import pytest


@pytest.fixture
def inference_tester(request) -> MNISTMLPTester:
    # With indirect=True, the parametrized value arrives as request.param.
    return MNISTMLPTester(request.param)


@pytest.mark.parametrize(
    "inference_tester", [(256, 128, 64)], indirect=True, ids=lambda val: f"{val}"
)
def test_mnist_mlp_inference(inference_tester: MNISTMLPTester):
    inference_tester.test()
Serialization and FileCheck
Serializing IR to Disk
To serialize compilation artifacts (MLIR, TTNN IRs) to disk, use the --serialize flag:
pytest path/to/test.py::test_name --serialize
For serialization to work, your test must take the pytest request fixture and pass it on to the infra:
For op/graph tests:
def test_my_op(request):
    run_op_test(MyOp(), [torch.randn(32, 32)], request=request)
For model tests:
def test_my_model(model_tester: MyModelTester, request):
    model_tester.test(request=request)
Artifacts are written to output_artifact/<sanitized_test_name>/.
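The infra needs request because pytest exposes both the command-line options and the current test's name through it. The sketch below shows one way the two could be combined. It assumes --serialize is registered in a conftest.py via pytest's pytest_addoption hook and that the directory name is derived from request.node.name. The sanitization rule and the helper names (sanitize_test_name, artifact_dir) are hypothetical and may not match what the infra actually does.

```python
import re
from pathlib import Path
from typing import Any, Optional


def pytest_addoption(parser: Any) -> None:
    # conftest.py hook: registers a boolean --serialize command-line flag.
    parser.addoption(
        "--serialize",
        action="store_true",
        default=False,
        help="Serialize compilation artifacts to disk.",
    )


def sanitize_test_name(name: str) -> str:
    # Hypothetical rule: anything outside [A-Za-z0-9_.-] becomes "_".
    return re.sub(r"[^A-Za-z0-9_.-]", "_", name)


def artifact_dir(request: Any, root: str = "output_artifact") -> Optional[Path]:
    # Both the flag value and the test name come from the request fixture,
    # which is why tests have to pass it through to the infra.
    if not request.config.getoption("serialize"):
        return None
    return Path(root) / sanitize_test_name(request.node.name)
```

Under this sketch's rule, test_mnist_mlp_inference[(256, 128, 64)] would map to output_artifact/test_mnist_mlp_inference__256__128__64__. Check the directory the real infra creates rather than relying on this exact name.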
Running FileCheck
To verify IR transformations, use the @pytest.mark.filecheck decorator:
@pytest.mark.filecheck(["add.ttnn.mlir", "matmul_fusion.ttir.mlir"])
def test_my_op(request):
    run_op_test(MyOp(), [torch.randn(32, 32)], request=request)
The marker automatically serializes the artifacts, runs FileCheck with the listed pattern files against them, and fails the test on any mismatch.
For pattern file syntax and conventions, see tests/filecheck/filecheck.md.
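Here is what "pattern matching" means in practice. LLVM FileCheck scans the input for each CHECK: pattern in order and fails if one is missing. CHECK-NOT: forbids a pattern between neighbouring matches. The sketch below reimplements just those two directives with plain substring matching. It is a simplification for illustration: real FileCheck also supports regex blocks, variables, CHECK-NEXT, CHECK-SAME and more, and can match several patterns on one line. The IR snippet is made up, and mini_filecheck is a hypothetical name.

```python
from typing import List


def mini_filecheck(check_text: str, ir_text: str) -> bool:
    """Return True if the CHECK / CHECK-NOT directives in check_text hold for ir_text."""
    lines: List[str] = ir_text.splitlines()
    pos = 0                       # first IR line the next CHECK may match
    pending_nots: List[str] = []  # CHECK-NOT patterns awaiting the next CHECK
    for directive in check_text.splitlines():
        if "CHECK-NOT:" in directive:
            pending_nots.append(directive.split("CHECK-NOT:", 1)[1].strip())
        elif "CHECK:" in directive:
            pattern = directive.split("CHECK:", 1)[1].strip()
            match = next((i for i in range(pos, len(lines)) if pattern in lines[i]), None)
            if match is None:
                return False  # a required pattern is missing or out of order
            # No CHECK-NOT pattern may appear between the previous match and this one.
            if any(p in line for p in pending_nots for line in lines[pos:match]):
                return False
            pending_nots = []
            pos = match + 1  # simplification: continue on the next line
    # Trailing CHECK-NOTs apply to everything after the last match.
    return not any(p in line for p in pending_nots for line in lines[pos:])


IR = """func.func @main(%a: tensor<32x32xf32>, %b: tensor<32x32xf32>) {
  %0 = "ttnn.add"(%a, %b)
  %1 = "ttnn.matmul"(%0, %b)
  return %1
}"""

print(mini_filecheck("// CHECK: ttnn.add\n// CHECK: ttnn.matmul", IR))  # True
print(mini_filecheck("// CHECK: ttnn.matmul\n// CHECK: ttnn.add", IR))  # False
```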