python JAX中的负采样

zzzyeukh  于 2023-01-12  发布在  Python
关注(0)|答案(1)|浏览(150)

我目前正在JAX中实现一个负采样算法。其思想是从一个范围中采样负,该范围排除了一些不可接受的输出。我目前的解决方案接近于以下内容:

import jax.numpy as jnp
    import jax
    max_range = 5
    n_samples = 2
    true_cases = jnp.array(
        [
            [1,2],
            [1,4],
            [0,5]
        ]
    )
    # i combine the true cases in a dictionary of the following form:
    non_acceptable_as_negatives = {
        0: jnp.array([5]),
        1: jnp.array([2,4]),
        2: jnp.array([]),
        3: jnp.array([]),
        4: jnp.array([]),
        5: jnp.array([])
    }
    negatives = []
    key = jax.random.PRNGKey(42)
    for i in true_cases[:,0]:
        key,use_key  = jax.random.split(key,2)
        p = jnp.ones((max_range+1,))
        p = p.at[non_acceptable_as_negatives[int(i)]].set(0)
        p = p / p.sum()
        negatives.append(
            jax.random.choice(use_key,
                jnp.arange(max_range+1),
                (1, n_samples),
                replace=False,
                p=p,
                )
        )

然而,这看起来a)相当复杂,B)性能不是很好,因为原始的真实案例包含~200_000个条目,最大范围是~ 50_000。我如何改进这个解决方案?是否有一种更JAX的方式来存储不同大小的数组,我目前存储在non_acceptable_as_negatives dict中?提前感谢

owfi6suc

owfi6suc1#

Jax数组是不可变的,这意味着不复制整个数组就不能编辑它,这里的主要问题是每次迭代都要创建两次向量p,我建议你通过numpy只计算一次概率:

import numpy as np

non_acceptable_as_negatives = {
    0: np.array([5]),
    1: np.array([2,4]),
    2: np.array([]),
    3: np.array([]),
    4: np.array([]),
    5: np.array([])
}

probas = np.ones((max_range+1, max_range+1))
for k, idx in non_acceptable_as_negatives.items():
    for i in idx:
        probas[k, i] = 0
probas = probas / probas.sum(axis=1, keepdims=True)
probas = jnp.array(probas)

然后,为了进一步加快算法的速度,您可以编译choice函数。

from functools import partial

@partial(jax.jit, static_argnums=1)
def sample(key, max_range, probas):
    key, use_key  = jax.random.split(key, 2)
    return jax.random.choice(use_key,
            jnp.arange(max_range+1),
            (1, n_samples),
            replace=False,
            p=probas[i],
            ), key

最后:

for i in true_cases[:,0]:
    neg, key = aux(key, max_range, probas)
    negatives.append(neg)

相关问题