stablehlo-builder
stablehlo-builder is a tool for creating StableHLO operations. It supports generating MLIR modules from user-constructed ops.
Getting started
StableHLOBuilder is a builder class that provides the API for creating StableHLO ops. The Python package builder contains everything needed to create ops through a StableHLOBuilder object, and builder.stablehlo.stablehlo_utils contains the APIs for wrapping op-creating functions into MLIR modules and flatbuffer files.
from builder.stablehlo.stablehlo_builder import StableHLOBuilder
from builder.stablehlo.stablehlo_utils import build_stablehlo_module
Creating a StableHLO module
build_stablehlo_module defines an MLIR module specified as a Python function. It wraps fn in an MLIR FuncOp, wraps that FuncOp in an MLIR module, and finally ties the arguments of the FuncOp to the test function inputs. It instantiates a StableHLOBuilder object and passes it as the last argument of fn. Each op returns an OpView, which is a type of Operand and can therefore be passed into another builder op as an input.
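To make the wrapping and chaining concrete, here is a toy, self-contained model of the pattern described above. ToyBuilder, ToyOperand, tensor_type and toy_build_module are hypothetical names invented for this sketch; they are not part of the builder package, and the real build_stablehlo_module builds actual MLIR objects rather than strings. The sketch only mirrors the shape of the output shown further down: inputs become %argN function arguments, the builder is passed last, and each op's result can feed the next op.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]


def tensor_type(shape: Shape, elem: str = "f32") -> str:
    # e.g. (32, 32) -> "tensor<32x32xf32>"
    return "tensor<" + "x".join(str(d) for d in shape) + "x" + elem + ">"


@dataclass
class ToyOperand:
    """Stand-in for an SSA value: its MLIR name and its type."""
    name: str
    type: str


@dataclass
class ToyBuilder:
    """Hypothetical builder that records ops as MLIR-like text lines."""
    lines: List[str] = field(default_factory=list)
    counter: int = 0

    def add(self, lhs: ToyOperand, rhs: ToyOperand) -> ToyOperand:
        result = ToyOperand(f"%{self.counter}", lhs.type)
        self.counter += 1
        self.lines.append(
            f"{result.name} = stablehlo.add {lhs.name}, {rhs.name} : {lhs.type}"
        )
        # The returned operand can be passed into the next builder call.
        return result


def toy_build_module(fn: Callable, inputs_shapes: List[Shape]) -> Tuple[str, ToyBuilder]:
    builder = ToyBuilder()
    args = [ToyOperand(f"%arg{i}", tensor_type(s)) for i, s in enumerate(inputs_shapes)]
    # The builder goes in as the last positional argument, after the inputs.
    out = fn(*args, builder)
    sig = ", ".join(f"{a.name}: {a.type}" for a in args)
    body = "\n".join(f"    {line}" for line in builder.lines)
    text = (
        "module {\n"
        f"  func.func @{fn.__name__}({sig}) -> {out.type} {{\n"
        f"{body}\n"
        f"    return {out.name} : {out.type}\n"
        "  }\n"
        "}"
    )
    return text, builder


def model(in0, in1, in2, builder):
    # Chaining: the result of the first add is an input to the second.
    tmp = builder.add(in0, in1)
    return builder.add(tmp, in2)


module_text, _ = toy_build_module(model, [(32, 32)] * 3)
print(module_text)
```

Running it prints a module whose body contains `%0 = stablehlo.add %arg0, %arg1` followed by `%1 = stablehlo.add %0, %arg2`, showing how one op's result becomes another op's operand.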
def build_stablehlo_module(
    fn: Callable,
    inputs_shapes: List[Shape],
    inputs_types: Optional[List[Union[torch.dtype, TypeInfo]]] = None,
    mesh_shape: Optional[Tuple[int, int]] = None,
    module_dump: bool = False,
    base: Optional[str] = None,
    output_root: str = ".",
)
Example
from builder.base.builder import Operand
from builder.stablehlo.stablehlo_builder import StableHLOBuilder
from builder.stablehlo.stablehlo_utils import build_stablehlo_module
shapes = [(32, 32), (32, 32), (32, 32)]
def model(in0: Operand, in1: Operand, in2: Operand, builder: StableHLOBuilder):
    return builder.add(in0, in1)
module, builder = build_stablehlo_module(model, shapes)
Returns
An MLIR module containing the MLIR op graph defined by fn, and the StableHLOBuilder object used to create it.
module {
  func.func @model(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<32x32xf32>
    return %0 : tensor<32x32xf32>
  }
}
Creating a StableHLO module with Shardy annotations
StableHLOBuilder lets you attach Shardy annotations to the generated MLIR graph.
Example
from builder.base.builder import Operand
from builder.stablehlo.stablehlo_builder import StableHLOBuilder
from builder.stablehlo.stablehlo_utils import build_stablehlo_module
shapes = [(32, 32), (32, 32)]
def model(in0: Operand, in1: Operand, shlo_builder: StableHLOBuilder):
    tensor_sharding_attr = shlo_builder.tensor_sharding_attr(
        mesh_name="mesh",
        dimension_shardings=[
            shlo_builder.dimension_sharding_attr(
                axes=[shlo_builder.axis_ref_attr(name="x")],
                is_closed=True,
            ),
            shlo_builder.dimension_sharding_attr(
                axes=[shlo_builder.axis_ref_attr(name="y")],
                is_closed=False,
            ),
        ],
    )
    shlo_builder.sharding_constraint(in0, tensor_sharding_attr=tensor_sharding_attr)
    return shlo_builder.add(in0, in1)
module, shlo_builder = build_stablehlo_module(model, shapes)
Returns
An MLIR module containing Shardy annotations, and the StableHLOBuilder object used to create it.
module {
  sdy.mesh @mesh = <["x"=1, "y"=8]>
  func.func @model(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %0 = sdy.sharding_constraint %arg0 <@mesh, [{"x"}, {"y", ?}]> : tensor<32x32xf32>
    %1 = stablehlo.add %arg0, %arg1 : tensor<32x32xf32>
    return %1 : tensor<32x32xf32>
  }
}
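The sharding attribute `<@mesh, [{"x"}, {"y", ?}]>` in this output can be read mechanically: each tensor dimension lists the mesh axes it is sharded along, and an open dimension (is_closed=False) ends with `?`, meaning sharding propagation may still add axes to it. The following self-contained sketch does not use the builder package; DimSharding, render_sharding and local_shape are hypothetical helpers written for illustration. It reproduces that attribute text and computes the per-device shard shape implied by the mesh `["x"=1, "y"=8]`, assuming each dimension is split evenly across the product of its axis sizes.

```python
from dataclasses import dataclass
from math import prod
from typing import Dict, List, Tuple


@dataclass
class DimSharding:
    axes: List[str]   # mesh axis names this dimension is split along
    is_closed: bool   # False -> open dimension, rendered with a trailing "?"


def render_sharding(mesh_name: str, dims: List[DimSharding]) -> str:
    """Render a Shardy-style tensor sharding attribute as text."""
    parts = []
    for d in dims:
        items = [f'"{a}"' for a in d.axes]
        if not d.is_closed:
            items.append("?")
        parts.append("{" + ", ".join(items) + "}")
    return f"<@{mesh_name}, [{', '.join(parts)}]>"


def local_shape(global_shape: Tuple[int, ...], dims: List[DimSharding],
                mesh: Dict[str, int]) -> Tuple[int, ...]:
    """Per-device shape, assuming every dimension divides evenly."""
    out = []
    for size, d in zip(global_shape, dims):
        ways = prod(mesh[a] for a in d.axes)  # prod of no axes is 1 (unsharded)
        if size % ways != 0:
            raise ValueError("uneven sharding is out of scope for this sketch")
        out.append(size // ways)
    return tuple(out)


mesh = {"x": 1, "y": 8}  # mirrors sdy.mesh @mesh = <["x"=1, "y"=8]>
dims = [DimSharding(["x"], is_closed=True), DimSharding(["y"], is_closed=False)]

print(render_sharding("mesh", dims))      # <@mesh, [{"x"}, {"y", ?}]>
print(local_shape((32, 32), dims, mesh))  # (32, 4)
```

With this mesh, dimension 0 is split across `x` (size 1) and stays 32, while dimension 1 is split across `y` (size 8), so each of the 8 devices holds a 32x4 slice of %arg0.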