Runtime Stitching
Runtime stitching gives the runtime the ability to stitch together multiple, independently compiled programs at runtime, i.e. without the compiler knowing how the binary programs will ultimately be composed.
Motivation
To flexibly support arbitrary training schedules and compositions of multiple models, we want the runtime to be able to stitch graphs together. Achieving this requires defining an ABI-like interface between the compiler and the runtime.
Simple Example
mod_a = forge.compile(PyTorch_module_a)
mod_b = forge.compile(PyTorch_module_b)

for i in range(10):
    outs_a = mod_a(ins_a)
    outs_b = mod_b(outs_a)
`mod_a` and `mod_b` are two independent compile steps. During the compile step for `mod_a`, the compiler should be completely unaware that `mod_b` will take place, and vice versa.
To achieve this, we propose a new runtime concept called stitching:
- forge invokes the compile step for `mod_a`; the tt-mlir compiler determines where the inputs (`ins_a`) should live: host, device DRAM, or device L1. tt-mlir returns metadata to forge describing where it wants the tensors to reside before flatbuffer submission.
- forge invokes the compile step for `mod_b`; the same happens as in the previous bullet.
- `mod_a` is invoked at runtime; the forge runtime needs to inspect the compiler metadata to determine where the tensors should live. The runtime manually invokes a new data copy command to move the tensors to the correct memory space / correct memory address.
- The forge runtime invokes the `mod_a` program submit.
- `mod_b` is invoked at runtime; this time it might be that the compiler left the tensor outputs in L1, so no data copy is needed to start running `mod_b` since the inputs are already in the correct location.
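To make the flow concrete, here is a minimal runtime-side sketch. It uses the `toLayout` API proposed later in this document; `Binary::getInputDescs` and `Binary::submit` are hypothetical stand-ins for the real runtime entry points, and `Tensor` / `::tt::target::TensorDesc` are assumed to exist as referenced elsewhere in this proposal.

```cpp
// Minimal sketch of the stitching flow described above, under the
// assumptions stated in the preceding paragraph.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<Tensor> runProgram(Binary &binary, std::uint32_t programIndex,
                               std::vector<Tensor> inputs) {
  // Per-input tensor descs populated by the compiler at compile time.
  std::vector<::tt::target::TensorDesc *> descs =
      binary.getInputDescs(programIndex);

  // Move / relayout each input to where the compiler expects it
  // (host, device DRAM, or device L1) before submission.
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    inputs[i] = toLayout(inputs[i], descs[i]);
  }

  // Submit the program. Outputs are left wherever the compiler chose,
  // e.g. in L1, so a downstream program can consume them without a copy.
  return binary.submit(programIndex, inputs);
}

// Stitching two independently compiled programs: the outputs of mod_a feed
// directly into mod_b, and toLayout becomes a no-op when they already sit
// in the memory space mod_b expects.
//   std::vector<Tensor> outs_a = runProgram(mod_a, /*programIndex=*/0, ins_a);
//   std::vector<Tensor> outs_b = runProgram(mod_b, /*programIndex=*/0, outs_a);
```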
A more concrete use case would be a training loop, where multiple graphs are often composed together (#82), or, when we eventually support torch 2.0, where the torch runtime can arbitrarily break the graph anywhere.
Proposed Changes
Compiler Metadata
The compiler will encode the input tensor layout information directly into the flatbuffer tensor desc. The flatbuffer schema already exists to express this; we just need to start populating it instead of assuming a canonical host layout.
The compiler will decide where the tensors should live: host, device DRAM, or device L1.
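As a rough illustration only (not the actual flatbuffer schema, which already exists in tt-mlir), the per-tensor metadata boils down to something like the following; `MemorySpace` and `TensorPlacement` here are hypothetical names:

```cpp
// Hypothetical sketch of the per-input placement metadata the compiler
// encodes into each flatbuffer tensor desc. The real fields live in the
// existing tt-mlir flatbuffer schema; this only illustrates the idea.
#include <cstdint>
#include <vector>

enum class MemorySpace {
  Host,       // tensor stays in host memory
  DeviceDRAM, // tensor should reside in device DRAM
  DeviceL1,   // tensor should reside in device L1
};

struct TensorPlacement {
  MemorySpace memorySpace;          // where the compiler wants the tensor
  std::vector<std::int64_t> shape;  // logical shape
  std::vector<std::int64_t> stride; // layout / strides the program expects
};
```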
Runtime
- The runtime will inspect the tensor desc metadata to determine where the tensors need to end up and what layout they should be in before invoking the program.
- New runtime API: `Tensor toLayout(Tensor tensor, ::tt::target::TensorDesc* tensorDesc);`
- The runtime will need to invoke `toLayout` on all input tensors before invoking the program (see the sketch below).
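A minimal sketch of what `toLayout` might do internally. Only the signature comes from this proposal; `MemorySpace` (from the hypothetical sketch above), `desiredMemorySpace`, and `copyToMemorySpace` are illustrative helpers, not real runtime API:

```cpp
// Hypothetical internals of the proposed toLayout API.
Tensor toLayout(Tensor tensor, ::tt::target::TensorDesc *tensorDesc) {
  MemorySpace desired = desiredMemorySpace(tensorDesc);

  // The tensor is already where the compiler wants it: avoid the copy entirely.
  if (tensor.memorySpace() == desired) {
    return tensor;
  }

  // Otherwise issue a data copy into the requested memory space
  // (host <-> device DRAM <-> device L1) and return the relocated tensor.
  return copyToMemorySpace(tensor, desired);
}
```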
Test Plan
- Add a new test to the runtime gtest suite that verifies the runtime can correctly stitch together 2 independently compiled programs.
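A sketch of what such a gtest could look like. Every call here (`loadBinary`, `makeRandomInputs`, `runProgram` from the earlier sketch, `goldenReference`, `allClose`) and the file paths are placeholders for whatever the real runtime and test utilities end up being:

```cpp
// Hypothetical runtime stitching test; names and paths are placeholders.
#include <gtest/gtest.h>
#include <vector>

TEST(RuntimeStitching, TwoIndependentlyCompiledPrograms) {
  // Two flatbuffers produced by independent compile steps.
  Binary mod_a = loadBinary("mod_a.fb"); // placeholder paths
  Binary mod_b = loadBinary("mod_b.fb");

  std::vector<Tensor> ins_a = makeRandomInputs(mod_a);

  // Run mod_a, relayouting its inputs as dictated by its tensor descs.
  std::vector<Tensor> outs_a = runProgram(mod_a, /*programIndex=*/0, ins_a);

  // Feed mod_a's outputs straight into mod_b; toLayout should be a no-op
  // (or a cheap move) when the tensors are already where mod_b expects them.
  std::vector<Tensor> outs_b = runProgram(mod_b, /*programIndex=*/0, outs_a);

  // Compare against a golden result computed on host.
  EXPECT_TRUE(allClose(outs_b, goldenReference(ins_a)));
}
```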
Concerns
- Tensors that pass through device memory spaces (DRAM, L1) will have a dynamic address; an arbitrary run order of flatbuffers could cause tensors to end up in non-ideal locations in memory. This is especially true for L1, where a poorly placed tensor might not be movable to a better location without a bounce through DRAM.