ttnn.sampling
- ttnn.sampling(input_values_tensor: ttnn.Tensor, input_indices_tensor: ttnn.Tensor, k: ttnn.Tensor, p: ttnn.Tensor, temp: ttnn.Tensor, seed: int = 0, sub_core_grids: ttnn.CoreRangeSet = None, optional_output_tensor: ttnn.Tensor = None) → ttnn.Tensor
-
Samples from the input_values_tensor based on provided top-k and top-p constraints.
This operation samples values from the input_values_tensor based on the provided thresholds k (top-k sampling) and p (top-p, or nucleus, sampling). The operation uses the input_indices_tensor for indexing and applies sampling under the given seed for reproducibility.
The op first converts the input_values_tensor into probabilities with a softmax. In top-k sampling, the op considers only the k highest-probability values from the input distribution; the remaining values are ignored, regardless of their probabilities. In top-p sampling, the op selects values from the input distribution such that their cumulative probability mass is less than or equal to a threshold p. When combining top-k and top-p sampling, the op applies the top-k filter first and then the top-p filter.
Within this filtered set, multinomial sampling is applied. Multinomial sampling selects a value from the distribution by comparing cumulative probabilities with a randomly generated number between 0 and 1: the operation identifies the first cumulative probability that exceeds the random threshold.
The op finally returns input_indices_tensor[final_index], where final_index is the index of the first cumulative probability exceeding the random number found in the multinomial step.
Currently, this operation supports inputs and outputs with specific memory layout and data type constraints.
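The procedure above (softmax, top-k filter, top-p filter, multinomial draw) can be sketched for a single user in plain NumPy. This is an illustrative reference model only, not the ttnn implementation; the function name reference_sample and its handling of ties are assumptions.

```python
import numpy as np

def reference_sample(values, indices, k, p, temp, rng):
    """Illustrative top-k / top-p / multinomial sampling for one user."""
    # Temperature scaling (temp is 1/T) followed by a numerically stable softmax.
    scaled = values * temp
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    # Top-k: keep only the k highest-probability entries.
    keep = np.argsort(probs)[::-1][:k]

    # Top-p: within the top-k set, keep the smallest prefix whose cumulative
    # probability mass is <= p (always keeping at least one entry).
    cum = np.cumsum(probs[keep])
    keep = keep[: max(1, int(np.sum(cum <= p)))]

    # Multinomial draw: the first cumulative probability that exceeds a
    # uniform random number in [0, 1) selects the winner.
    kept = probs[keep] / probs[keep].sum()
    r = rng.random()
    final = min(int(np.searchsorted(np.cumsum(kept), r, side="right")),
                len(keep) - 1)
    return int(indices[keep[final]])
```

With k = 1 this reduces to greedy (argmax) selection; larger k and p trade determinism for diversity.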
- Parameters:
-
input_values_tensor (ttnn.Tensor) – The input tensor containing values to sample from.
input_indices_tensor (ttnn.Tensor) – The input tensor containing indices to assist with sampling.
k (ttnn.Tensor) – Top-k values for sampling.
p (ttnn.Tensor) – Top-p (nucleus) probabilities for sampling.
temp (ttnn.Tensor) – Temperature tensor for scaling (1/T).
seed (int, optional) – Seed for sampling randomness. Defaults to 0.
sub_core_grids (ttnn.CoreRangeSet, optional) – Core range set for multicore execution. Defaults to None.
optional_output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.
Note
This operation only supports inputs and outputs with the following data types and layouts:
- input_values_tensor: dtype BFLOAT16, layout TILE
- input_indices_tensor: dtype UINT32 or INT32, layout ROW_MAJOR
- k: dtype UINT32, layout ROW_MAJOR
- p, temp: dtype BFLOAT16, layout ROW_MAJOR
If no optional_output_tensor is provided, the returned tensor will be as follows:
- output_tensor (default): dtype UINT32, layout ROW_MAJOR
If optional_output_tensor is provided, the supported data types and layout are:
- output_tensor (if provided): dtype INT32 or UINT32, layout ROW_MAJOR
- Returns:
-
ttnn.Tensor: The output tensor containing sampled indices.
- Memory Support:
-
Interleaved: DRAM and L1
- Limitations:
-
Inputs must be 4D tensors with shape [N, C, H, W], and must be located on the device.
The input tensors must represent exactly 32 users based on their shape (i.e. N*C*H = 32).
The last dimension of input_values_tensor must be padded to a multiple of 32.
The overall shape of input_values_tensor must match that of input_indices_tensor.
k: must contain 32 values, each in the range (0, 32].
p, temp: must contain 32 values, each in the range [0.0, 1.0].
sub_core_grids (if provided): the number of cores must equal the number of users (which is constrained to 32).
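A host-side check of these limitations can be sketched as follows. This is a hypothetical helper; the name check_sampling_inputs and the use of assertions are assumptions, not part of the ttnn API.

```python
def check_sampling_inputs(values_shape, indices_shape, k, p, temp):
    """Validate the documented shape and range constraints before calling
    ttnn.sampling (illustrative host-side checks only)."""
    n, c, h, w = values_shape
    # Exactly 32 users, last dim padded to a multiple of 32.
    assert n * c * h == 32, "inputs must represent exactly 32 users (N*C*H == 32)"
    assert w % 32 == 0, "last dim of input_values_tensor must be a multiple of 32"
    assert tuple(values_shape) == tuple(indices_shape), \
        "input_values_tensor and input_indices_tensor shapes must match"
    # Per-user sampling parameters: 32 entries each, within documented ranges.
    assert len(k) == 32 and all(0 < v <= 32 for v in k), "k: 32 values in (0, 32]"
    assert len(p) == 32 and all(0.0 <= v <= 1.0 for v in p), "p: 32 values in [0.0, 1.0]"
    assert len(temp) == 32 and all(0.0 <= v <= 1.0 for v in temp), \
        "temp: 32 values in [0.0, 1.0]"
```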
Example
# Input values tensor for N*C*H = 32 users
input_tensor = ttnn.rand([1, 1, 32, 64], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Input indices tensor: this example uses sequential indices [0, 1, 2, ..., W-1]
# for each of the 32 users, resulting in a final shape of [1, 1, 32, 64]
indices_1d = ttnn.arange(0, 64, dtype=ttnn.int32, device=device)
indices_reshaped = ttnn.reshape(indices_1d, [1, 1, 1, 64])
input_indices_tensor = ttnn.repeat(indices_reshaped, (1, 1, 32, 1))

# k tensor: 32 values in range (0, 32] for top-k sampling
k_tensor = ttnn.full([32], fill_value=10, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# p tensor: 32 values in range [0.0, 1.0] for top-p sampling
p_tensor = ttnn.full([32], fill_value=0.9, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# temp tensor: 32 temperature values in range [0.0, 1.0]
temp_tensor = ttnn.ones([32], dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

output = ttnn.sampling(input_tensor, input_indices_tensor, k=k_tensor, p=p_tensor, temp=temp_tensor)
logger.info(f"Sampling result: {output}")