ttnn.sampling

ttnn.sampling(input_values_tensor: ttnn.Tensor, input_indices_tensor: ttnn.Tensor, k: ttnn.Tensor, p: ttnn.Tensor, temp: ttnn.Tensor, seed: int = 0, sub_core_grids: ttnn.CoreRangeSet = None, optional_output_tensor: ttnn.Tensor = None) → ttnn.Tensor

Samples from input_values_tensor using the provided top-k and top-p constraints.

This operation samples from input_values_tensor using the thresholds k (top-k sampling) and p (top-p, or nucleus, sampling). The sampled position is looked up in input_indices_tensor to produce the returned index, and the random draw is controlled by seed for reproducibility.

The op first converts input_values_tensor into probabilities by applying a softmax.
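For intuition, the softmax step for one user's row can be sketched in plain Python. This is not the ttnn kernel. The inv_temp argument is a hypothetical stand-in for the temp tensor and assumes its value (1/T) multiplies the values before the softmax; how ttnn actually applies temp is not confirmed here.

```python
import math


def softmax(values, inv_temp=1.0):
    """Numerically stable softmax over one user's row of values.

    inv_temp is a hypothetical stand-in for the temp tensor (1/T); it is
    assumed here to multiply the values before the softmax.
    """
    scaled = [v * inv_temp for v in values]
    peak = max(scaled)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.0])
print([round(prob, 3) for prob in probs])
```

Subtracting the row maximum does not change the result but avoids overflow for large values.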

In top-k sampling, the op considers only the k highest-probability values from the input distribution. The remaining values are ignored, regardless of their probabilities. In top-p sampling, the op selects values from the input distribution such that the cumulative probability mass is less than or equal to a threshold p. When combining top-k and top-p sampling, the op first applies the top-k filter and then the top-p filter.
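The filtering order described above (top-k first, then top-p) can be sketched for a single user in plain Python. This is an illustration, not the ttnn implementation: the helper name top_k_top_p is hypothetical, and the sketch assumes the probabilities are renormalized over the top-k set before the top-p cutoff and that at least one value is always kept.

```python
def top_k_top_p(probs, k, p):
    """Return (position, prob) pairs kept by top-k followed by top-p.

    Assumptions (not confirmed for ttnn): probabilities are renormalized
    over the top-k set, and at least one entry is always kept even if its
    probability alone exceeds p.
    """
    # Top-k: keep the k highest probabilities, in descending order.
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)[:k]
    total = sum(prob for _, prob in ranked)
    ranked = [(pos, prob / total) for pos, prob in ranked]

    # Top-p: keep the leading values while the cumulative mass stays <= p.
    kept, cumulative = [], 0.0
    for pos, prob in ranked:
        if kept and cumulative + prob > p:
            break
        kept.append((pos, prob))
        cumulative += prob
    return kept


print(top_k_top_p([0.1, 0.4, 0.2, 0.3], k=3, p=0.8))
```

With k=3 the lowest value (position 0) is dropped; with p=0.8 the renormalized third value would push the cumulative mass past p, so only positions 1 and 3 remain.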

Multinomial sampling is then applied within this selected set. A random number is drawn uniformly between 0 and 1 and compared against the cumulative probabilities of the selected values; the sample is the position at which the cumulative probability first exceeds the random number (that is, the smallest cumulative probability greater than it).

The op finally returns input_indices_tensor[final_index], where final_index is the position selected by the multinomial sampling.
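The multinomial draw and the final index lookup can be sketched for one user with inverse-CDF sampling. The helper multinomial_pick and its arguments are hypothetical, and Python's random.Random stands in for the device's random number generator. The sketch does not reproduce that generator, so the same seed will not yield the same index as ttnn.sampling.

```python
import random


def multinomial_pick(kept, indices, rng):
    """Inverse-CDF draw over kept (position, prob) pairs for one user.

    indices stands in for that user's row of input_indices_tensor; the
    function returns indices[final_index].
    """
    total = sum(prob for _, prob in kept)
    threshold = rng.random() * total  # uniform draw in [0, total)
    cumulative = 0.0
    for pos, prob in kept:
        cumulative += prob
        if cumulative > threshold:  # first cumulative probability above the draw
            return indices[pos]
    return indices[kept[-1][0]]  # guard against float rounding at the tail


rng = random.Random(0)  # fixed seed, so repeated runs give the same draws
kept = [(1, 0.5), (3, 0.3), (2, 0.2)]
indices = [100, 101, 102, 103]
print(multinomial_pick(kept, indices, rng))
```

Scaling the draw by the kept mass has the same effect as renormalizing the kept probabilities to sum to 1.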

Inputs and outputs are currently restricted to specific data types and layouts, listed in the Note below.

Parameters:
  • input_values_tensor (ttnn.Tensor) – The input tensor containing values to sample from.

  • input_indices_tensor (ttnn.Tensor) – The input tensor containing indices to assist with sampling.

  • k (ttnn.Tensor) – Top-k values for sampling.

  • p (ttnn.Tensor) – Top-p (nucleus) probabilities for sampling.

  • temp (ttnn.Tensor) – Temperature tensor for scaling (1/T).

  • seed (int, optional) – Seed for sampling randomness. Defaults to 0.

  • sub_core_grids (ttnn.CoreRangeSet, optional) – Core range set for multicore execution. Defaults to None.

  • optional_output_tensor (ttnn.Tensor, optional) – Preallocated output tensor. Defaults to None.

Note

This operation only supports inputs and outputs with the following data types and layouts:

Tensor                  dtype           layout
input_values_tensor     BFLOAT16        TILE
input_indices_tensor    UINT32, INT32   ROW_MAJOR
k                       UINT32          ROW_MAJOR
p, temp                 BFLOAT16        ROW_MAJOR

If optional_output_tensor is not provided, the returned tensor has the following data type and layout:

Tensor                    dtype           layout
output tensor (default)   UINT32          ROW_MAJOR

If optional_output_tensor is provided, it must use one of the following data types and layout:

Tensor                    dtype           layout
optional_output_tensor    INT32, UINT32   ROW_MAJOR

Returns:

ttnn.Tensor: The output tensor containing sampled indices.

Memory Support:
  • Interleaved: DRAM and L1

Limitations:
  • Inputs must be 4D tensors with shape [N, C, H, W], and must be located on the device.

  • The input tensors must represent exactly 32 users based on their shape (i.e. N*C*H = 32).

  • The last dimension of input_values_tensor must be padded to a multiple of 32.

  • The overall shape of input_values_tensor must match that of input_indices_tensor.

  • k: Must contain 32 values in the range (0, 32].

  • p, temp: Must contain 32 values in the range [0.0, 1.0].

  • sub_core_grids (if provided): the number of cores must equal the number of users (which is constrained to 32).
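The shape and range limitations above can be checked on the host before calling the op. The helper check_sampling_inputs is hypothetical (not part of ttnn): it works on plain Python shape tuples and lists rather than ttnn.Tensor objects, mirrors only the constraints as documented here, and does not check dtypes, layouts, or sub_core_grids.

```python
def check_sampling_inputs(values_shape, indices_shape, k, p, temp):
    """Host-side checks mirroring the documented limitations.

    Hypothetical helper: shapes are 4-tuples [N, C, H, W]; k, p and temp
    are Python lists of per-user values. Returns True or raises ValueError.
    """
    if len(values_shape) != 4:
        raise ValueError("inputs must be 4D [N, C, H, W]")
    n, c, h, w = values_shape
    if n * c * h != 32:
        raise ValueError(f"N*C*H must be 32, got {n * c * h}")
    if w % 32 != 0:
        raise ValueError(f"last dimension must be a multiple of 32, got {w}")
    if tuple(indices_shape) != tuple(values_shape):
        raise ValueError("input_values_tensor and input_indices_tensor shapes differ")
    if len(k) != 32 or not all(0 < v <= 32 for v in k):
        raise ValueError("k must hold 32 values in (0, 32]")
    for name, vals in (("p", p), ("temp", temp)):
        if len(vals) != 32 or not all(0.0 <= v <= 1.0 for v in vals):
            raise ValueError(f"{name} must hold 32 values in [0.0, 1.0]")
    return True


print(check_sampling_inputs((1, 1, 32, 64), (1, 1, 32, 64), [10] * 32, [0.9] * 32, [1.0] * 32))
```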

Example

import ttnn
from loguru import logger

device = ttnn.open_device(device_id=0)

# Input values tensor for N*C*H = 32 users
input_tensor = ttnn.rand([1, 1, 32, 64], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Input indices tensor: this example uses sequential indices [0, 1, 2, ..., W-1] for each of the 32 users,
# resulting in a final shape of [1, 1, 32, 64]
indices_1d = ttnn.arange(0, 64, dtype=ttnn.int32, device=device)
indices_reshaped = ttnn.reshape(indices_1d, [1, 1, 1, 64])
input_indices_tensor = ttnn.repeat(indices_reshaped, (1, 1, 32, 1))

# k tensor: 32 values in range (0, 32] for top-k sampling
k_tensor = ttnn.full([32], fill_value=10, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# p tensor: 32 values in range [0.0, 1.0] for top-p sampling
p_tensor = ttnn.full([32], fill_value=0.9, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# temp tensor: 32 temperature values in range [0.0, 1.0]
temp_tensor = ttnn.ones([32], dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

# Returns one sampled index per user (UINT32, ROW_MAJOR by default)
output = ttnn.sampling(input_tensor, input_indices_tensor, k=k_tensor, p=p_tensor, temp=temp_tensor)
logger.info(f"Sampling result: {output}")

ttnn.close_device(device)