torch.multinominal方法可以根据给定权重对数组进行多次采样,返回采样后的元素下标
参数说明 input :权重,也就是取每个值的概率,可以是1维或2维。可以不进行归一化。 num_samples : 采样的次数。如果input是二维的,则表示每行的采样次数 replacement :默认值值是False,即不放回采样。如果replacement =False,则num_samples必须小于input中非零元素的数目
按权重采样 从四个元素中随机选择两个,每个元素被选择到的概率分别为:[0.2, 0.2, 0.3, 0.3]:
>>> weights = torch.Tensor([0.9, 0.25, 0.1, 0.15]) # 采样权重 >>> torch.multinomial(weights, 2) tensor([0, 1]) >>> torch.multinomial(weights, 2) tensor([1, 3]) >>> torch.multinomial(weights, 2) tensor([0, 3]) >>> torch.multinomial(weights, 2) tensor([3, 1]) >>> torch.multinomial(weights, 2) tensor([1, 0]) >>> torch.multinomial(weights, 2) tensor([1, 0]) >>> torch.multinomial(weights, 2) tensor([0, 1]) >>> torch.multinomial(weights, 2) tensor([0, 2]) >>> torch.multinomial(weights, 2) tensor([3, 0]) >>> torch.