@@ -1067,6 +1067,41 @@ def stateful_randc(
10671067 "Backend '{}' has not implemented `stateful_randc`." .format (self .name )
10681068 )
10691069
1070+ def probability_sample (
1071+ self : Any , shots : int , p : Tensor , status : Optional [Tensor ] = None , g : Any = None
1072+ ) -> Tensor :
1073+ """
1074+ Drawn ``shots`` samples from probability distribution p, given the external randomness
1075+ determined by uniform distributed ``status`` tensor or backend random generator ``g``.
1076+ This method is similar with ``stateful_randc``, but it supports ``status`` beyond ``g``,
1077+ which is convenient when jit or vmap
1078+
1079+ :param shots: Number of samples to draw with replacement
1080+ :type shots: int
1081+ :param p: prbability vector
1082+ :type p: Tensor
1083+ :param status: external randomness as a tensor with each element drawn uniformly from [0, 1],
1084+ defaults to None
1085+ :type status: Optional[Tensor], optional
1086+ :param g: backend random genrator, defaults to None
1087+ :type g: Any, optional
1088+ :return: The drawn sample as an int tensor
1089+ :rtype: Tensor
1090+ """
1091+ if status is not None :
1092+ status = self .convert_to_tensor (status )
1093+ elif g is not None :
1094+ status = self .stateful_randu (g , shape = [shots ])
1095+ else :
1096+ status = self .implicit_randu (shape = [shots ])
1097+ p = p / self .sum (p )
1098+ p_cuml = self .cumsum (p )
1099+ r = p_cuml [- 1 ] * (1 - self .cast (status , p .dtype ))
1100+ ind = self .searchsorted (p_cuml , r )
1101+ a = self .arange (shots )
1102+ res = self .gather1d (a , ind )
1103+ return res
1104+
10701105 def gather1d (self : Any , operand : Tensor , indices : Tensor ) -> Tensor :
10711106 """
10721107 Return ``operand[indices]``, both ``operand`` and ``indices`` are rank-1 tensor.
0 commit comments