PyTorch阶乘函数

ruarlubt  于 9个月前  发布在  其他
关注(0)|答案(3)|浏览(134)

PyTorch似乎没有计算阶乘的函数。在PyTorch中是否有方法可以做到这一点?我希望在Torch中手动计算泊松分布(我知道这是存在的:https://pytorch.org/docs/stable/generated/torch.poisson.html)并且公式需要分母中的阶乘。
泊松分布:https://en.wikipedia.org/wiki/Poisson_distribution

nbewdwxp

nbewdwxp1#

我想你可以找到它作为torch.jit._builtins.math.factorialpytorch以及numpyscipyFactorial in numpy and scipy)使用python的内置math.factorial

import math

import numpy as np
import scipy as sp
import torch

print(torch.jit._builtins.math.factorial is math.factorial)
print(np.math.factorial is math.factorial)
print(sp.math.factorial is math.factorial)

个字符
但是,相比之下,除了“主流”math.factorial之外,scipy还包含非常“特殊”的阶乘函数scipy.special.factorial。与math模块中的函数不同,它对数组进行操作:

from scipy import special

print(special.factorial is math.factorial)
False
# the all known factorial functions
factorials = (
    math.factorial,
    torch.jit._builtins.math.factorial,
    np.math.factorial,
    sp.math.factorial,
    special.factorial,
)

# Let's run some tests
tnsr = torch.tensor(3)

for fn in factorials:
    try:
        out = fn(tnsr)
    except Exception as err:
        print(fn.__name__, fn.__module__, ':', err)
    else:
        print(fn.__name__, fn.__module__, ':', out)
factorial math : 6
factorial math : 6
factorial math : 6
factorial math : 6
factorial scipy.special._basic : tensor(6., dtype=torch.float64)
tnsr = torch.tensor([1, 2, 3])

for fn in factorials:
    try:
        out = fn(tnsr)
    except Exception as err:
        print(fn.__name__, fn.__module__, ':', err)
    else:
        print(fn.__name__, fn.__module__, ':', out)
factorial math : only integer tensors of a single element can be converted to an index
factorial math : only integer tensors of a single element can be converted to an index
factorial math : only integer tensors of a single element can be converted to an index
factorial math : only integer tensors of a single element can be converted to an index
factorial scipy.special._basic : tensor([1., 2., 6.], dtype=torch.float64)
a0x5cqrl

a0x5cqrl2#

内置的math模块(docs)提供了一个函数,将给定积分的阶乘作为int返回。

import math

x = math.factorial(5)
print(x)
print(type(x))

字符串

输出

120
<class 'int'>

sg3maiej

sg3maiej3#

torch.lgamma计算伽马的对数,其等于(x-1)!

import torch

def log_factorial(x):
  return torch.lgamma(x + 1)

def factorial(x):
  return torch.exp(log_factorial(x))

字符串

相关问题