我目前正在JAX中实现一个负采样算法。其思想是从一个范围中采样负,该范围排除了一些不可接受的输出。我目前的解决方案接近于以下内容:
import jax.numpy as jnp
import jax
max_range = 5
n_samples = 2
true_cases = jnp.array(
[
[1,2],
[1,4],
[0,5]
]
)
# i combine the true cases in a dictionary of the following form:
non_acceptable_as_negatives = {
0: jnp.array([5]),
1: jnp.array([2,4]),
2: jnp.array([]),
3: jnp.array([]),
4: jnp.array([]),
5: jnp.array([])
}
negatives = []
key = jax.random.PRNGKey(42)
for i in true_cases[:,0]:
key,use_key = jax.random.split(key,2)
p = jnp.ones((max_range+1,))
p = p.at[non_acceptable_as_negatives[int(i)]].set(0)
p = p / p.sum()
negatives.append(
jax.random.choice(use_key,
jnp.arange(max_range+1),
(1, n_samples),
replace=False,
p=p,
)
)
然而,这看起来a)相当复杂,B)性能不是很好,因为原始的真实案例包含~200_000个条目,最大范围是~ 50_000。我如何改进这个解决方案?是否有一种更JAX的方式来存储不同大小的数组,我目前存储在non_acceptable_as_negatives dict中?提前感谢
1条答案
按热度按时间owfi6suc1#
Jax数组是不可变的,这意味着不复制整个数组就不能编辑它,这里的主要问题是每次迭代都要创建两次向量
p
,我建议你通过numpy只计算一次概率:然后,为了进一步加快算法的速度,您可以编译
choice
函数。最后: