NumPy - 优化 Python 函数 SmoothStep 中的多个条件分支,使其便于 Numba 向量化

azpvetkf  于 2023-10-19  发布在  Python
关注(0)|答案(1)|浏览(82)

我实现了一个使用SmoothStep创建平滑矩形函数的函数:

import numpy as np
from numba import jit, njit
import matplotlib.pyplot as plt

@njit
def GenSmoothStep( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
    """Write a smooth rectangle of vX into vY (in place).

    The output is 1 on [lowVal, highVal] and 0 outside [lowClip, highClip],
    where lowClip = max(lowVal - rollOffWidth, 0) and
    highClip = min(highVal + rollOffWidth, 1). In between, a cubic smoothstep
    connects the two levels on each side.

    Elements are reached through .flat and vX.size, so both arrays are walked
    in flat order regardless of their shape.
    """
    lowClip = max(lowVal - rollOffWidth, 0)
    highClip = min(highVal + rollOffWidth, 1)
    lowSpan = lowVal - lowClip
    highSpan = highClip - highVal

    for kk in range(vX.size):
        xVal = vX.flat[kk]
        # tN is the normalized distance from the plateau: 0 on the plateau and
        # 1 once fully outside. The output is 1 - smoothstep(tN), which is
        # exactly 1.0 for tN == 0 and exactly 0.0 for tN == 1.
        if xVal < lowClip:
            tN = 1.0
        elif xVal < lowVal:
            tN = (lowVal - xVal) / lowSpan
        elif xVal > highClip:
            tN = 1.0
        elif xVal > highVal:
            tN = (xVal - highVal) / highSpan
        else:
            tN = 0.0
        vY.flat[kk] = 1 - (tN * tN * (3 - (2 * tN)))

# Demo: plateau [0.15, 0.75] with a roll-off width of 0.3 on each side
# (the roll-off bounds are clipped to the [0, 1] domain by GenSmoothStep).
numGridPts = 1000

lowVal, highVal = 0.15, 0.75
rollOffWidth = 0.3

vX = np.linspace(0, 1, numGridPts)
vY = np.empty_like(vX)  # output buffer, filled in place
GenSmoothStep(vX, lowVal, highVal, vY, rollOffWidth)

plt.plot(vX, vY)

该函数包含多个条件分支,这对向量化并不友好。
我想知道是否有一些简单的方法,能让这个函数对 Numba 更友好(更易于向量化)。

更新

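我采用了 @AndrejKesely 的代码并做了修改,以处理我的代码中的边界情况(lowVal = 0.0 和/或 highVal = 1.0:此时 lowClip == lowVal 或 highClip == highVal,smoothstep 中的分母 (highBound - lowBound) 为零,会导致除以零)。
我还添加了一个不使用 if 分支、而是用 min/max 进行裁剪(clamp)的变体。以下是当前的代码: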

# %% 

import numpy as np
from numba import jit, njit
import matplotlib.pyplot as plt
from timeit import timeit

@njit
def GenSmoothStep000( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
    """Reference (branchy) smooth-rectangle generator; writes into vY in place.

    Output is 0 below lowClip = max(lowVal - rollOffWidth, 0) and above
    highClip = min(highVal + rollOffWidth, 1), 1 on [lowVal, highVal], and a
    cubic smoothstep transition on [lowClip, lowVal] and [highVal, highClip].
    Arrays are accessed via .flat, i.e. in flat order.
    """
    lowClip  = max(lowVal - rollOffWidth, 0)
    highClip = min(highVal + rollOffWidth, 1)

    for idx in range(vX.size):
        x = vX.flat[idx]

        # Left of the rising edge.
        if x < lowClip:
            vY.flat[idx] = 0.0
            continue

        # Rising edge: distance to lowVal, normalized by the roll-off span.
        if x < lowVal:
            dist = (lowVal - x) / (lowVal - lowClip)
            vY.flat[idx] = 1 - (dist * dist * (3 - (2 * dist)))
            continue

        # Right of the falling edge.
        if x > highClip:
            vY.flat[idx] = 0.0
            continue

        # Falling edge: distance from highVal, normalized by the roll-off span.
        if x > highVal:
            dist = (x - highVal) / (highClip - highVal)
            vY.flat[idx] = 1 - (dist * dist * (3 - (2 * dist)))
            continue

        # Plateau.
        vY.flat[idx] = 1.0

@njit
def Clamp001( x: float, lowBound: float = 0.0, highBound: float = 1.0 ):
    """Clamp x into [lowBound, highBound] using explicit comparisons.

    The lower bound is tested first, so when lowBound > highBound and
    x < lowBound the result is lowBound. A NaN x fails both comparisons and
    is returned unchanged.
    """
    return lowBound if x < lowBound else (highBound if x > highBound else x)

@njit
def Clamp002( x: float, lowBound: float = 0.0, highBound: float = 1.0 ):
    """Clamp x into [lowBound, highBound] with min/max instead of if statements.

    The upper bound is applied first and the lower bound last, so lowBound
    wins when lowBound > highBound.
    """
    cappedAbove = min(x, highBound)
    return max(cappedAbove, lowBound)

@njit
def SmoothStep001( x: float, lowBound: float = 0.0, highBound: float = 1.0 ):
    """Cubic smoothstep of x over [lowBound, highBound] (clamp via Clamp001).

    For lowBound < highBound: t = (x - lowBound) / (highBound - lowBound) is
    clamped to [0, 1] and 3t^2 - 2t^3 is returned, i.e. 0 at/below lowBound,
    1 at/above highBound.

    If lowBound == highBound (e.g. GenSmoothStep001 called with
    rollOffWidth = 0) the ramp has zero width. Dividing by zero there would
    raise ZeroDivisionError under Numba's default 'python' error model, so a
    hard step is returned instead: 0.0 for x < lowBound, 1.0 otherwise
    (same convention as GLSL step()).
    """
    # Depends only on the bounds, so inside a caller's loop it is loop
    # invariant and can be hoisted by the compiler.
    if highBound == lowBound:
        return 1.0 if x >= lowBound else 0.0
    x = Clamp001((x - lowBound) / (highBound - lowBound), 0.0, 1.0)
    return x * x * (3.0 - 2.0 * x)

@njit
def SmoothStep002( x: float, lowBound: float = 0.0, highBound: float = 1.0 ):
    """Cubic smoothstep of x over [lowBound, highBound] (clamp via Clamp002, min/max).

    For lowBound < highBound: t = (x - lowBound) / (highBound - lowBound) is
    clamped to [0, 1] and 3t^2 - 2t^3 is returned, i.e. 0 at/below lowBound,
    1 at/above highBound.

    If lowBound == highBound (e.g. GenSmoothStep002 called with
    rollOffWidth = 0) the ramp has zero width. Dividing by zero there would
    raise ZeroDivisionError under Numba's default 'python' error model, so a
    hard step is returned instead: 0.0 for x < lowBound, 1.0 otherwise
    (same convention as GLSL step()).
    """
    # Depends only on the bounds, so inside a caller's loop it is loop
    # invariant and can be hoisted by the compiler.
    if highBound == lowBound:
        return 1.0 if x >= lowBound else 0.0
    x = Clamp002((x - lowBound) / (highBound - lowBound), 0.0, 1.0)
    return x * x * (3.0 - 2.0 * x)

@njit
def GenSmoothStep001( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
    """Fill vY in place with a smooth rectangle of vX built from SmoothStep001.

    vY is 1 on [lowVal, highVal]. It rises over [lowClip, lowVal] and falls over
    [highVal, highClip], where lowClip = max(lowVal - rollOffWidth, 0) and
    highClip = min(highVal + rollOffWidth, 1).

    An edge at or beyond the domain boundary (lowVal <= 0 or highVal >= 1) is
    skipped. After clipping, its roll-off interval is either empty
    (lowClip == lowVal / highClip == highVal, where SmoothStep001 would divide
    by zero) or reversed (lowVal < 0 / highVal > 1, where the ramp evaluates to
    0 for every x in [0, 1] and would zero the whole output).
    Note: with rollOffWidth == 0, an interior edge passes equal bounds to
    SmoothStep001.

    Parameters
    ----------
    vX : np.ndarray
        Sample positions; indexed as vX[ii] for ii < vX.size, so 1-D is assumed.
    lowVal, highVal : float
        Start and end of the plateau where vY == 1.
    vY : np.ndarray
        Output buffer with the same length as vX, written in place.
    rollOffWidth : float
        Width of each transition before clipping to [0, 1].
    """
    lowClip  = max(lowVal - rollOffWidth, 0)
    highClip = min(highVal + rollOffWidth, 1)

    # Decide once, outside the loops, which edges are active so that each case
    # runs its own loop without per-element case selection.
    noLowEdge  = lowVal <= 0.0
    noHighEdge = highVal >= 1.0

    if noHighEdge and noLowEdge:
        for ii in range(vX.size):
            vY[ii] = 1.0
    elif noHighEdge:
        for ii in range(vX.size):
            vY[ii] = SmoothStep001(vX[ii], lowClip, lowVal)
    elif noLowEdge:
        for ii in range(vX.size):
            vY[ii] = 1 - SmoothStep001(vX[ii], highVal, highClip)
    else:
        for ii in range(vX.size):
            vY[ii] = SmoothStep001(vX[ii], lowClip, lowVal) * (1 - SmoothStep001(vX[ii], highVal, highClip))

@njit
def GenSmoothStep002( vX: np.ndarray, lowVal: float, highVal: float, vY: np.ndarray, rollOffWidth: float = 0.1 ):
    """Fill vY in place with a smooth rectangle of vX built from SmoothStep002.

    vY is 1 on [lowVal, highVal]. It rises over [lowClip, lowVal] and falls over
    [highVal, highClip], where lowClip = max(lowVal - rollOffWidth, 0) and
    highClip = min(highVal + rollOffWidth, 1).

    An edge at or beyond the domain boundary (lowVal <= 0 or highVal >= 1) is
    skipped. After clipping, its roll-off interval is either empty
    (lowClip == lowVal / highClip == highVal, where SmoothStep002 would divide
    by zero) or reversed (lowVal < 0 / highVal > 1, where the ramp evaluates to
    0 for every x in [0, 1] and would zero the whole output).
    Note: with rollOffWidth == 0, an interior edge passes equal bounds to
    SmoothStep002.

    Parameters
    ----------
    vX : np.ndarray
        Sample positions; indexed as vX[ii] for ii < vX.size, so 1-D is assumed.
    lowVal, highVal : float
        Start and end of the plateau where vY == 1.
    vY : np.ndarray
        Output buffer with the same length as vX, written in place.
    rollOffWidth : float
        Width of each transition before clipping to [0, 1].
    """
    lowClip  = max(lowVal - rollOffWidth, 0)
    highClip = min(highVal + rollOffWidth, 1)

    # Decide once, outside the loops, which edges are active so that each case
    # runs its own loop without per-element case selection.
    noLowEdge  = lowVal <= 0.0
    noHighEdge = highVal >= 1.0

    if noHighEdge and noLowEdge:
        for ii in range(vX.size):
            vY[ii] = 1.0
    elif noHighEdge:
        for ii in range(vX.size):
            vY[ii] = SmoothStep002(vX[ii], lowClip, lowVal)
    elif noLowEdge:
        for ii in range(vX.size):
            vY[ii] = 1 - SmoothStep002(vX[ii], highVal, highClip)
    else:
        for ii in range(vX.size):
            vY[ii] = SmoothStep002(vX[ii], lowClip, lowVal) * (1 - SmoothStep002(vX[ii], highVal, highClip))


# %%
# Validation + JIT Compilation
# Run each variant once on the same inputs (the first call also compiles it)
# and plot the three results side by side for visual comparison.
numGridPts = 10_000

lowVal  = 0.35
highVal = 0.55
rollOffWidth = 0.3

vX = np.linspace(0, 1, numGridPts)

hF, vHa = plt.subplots(nrows = 1, ncols = 3, figsize = (16, 5))
for hA, hGenFun in zip(vHa, (GenSmoothStep000, GenSmoothStep001, GenSmoothStep002)):
    vY = np.empty_like(vX)  # fresh output buffer per variant
    hGenFun(vX, lowVal, highVal, vY, rollOffWidth = rollOffWidth)
    hA.plot(vX, vY)

# %%
# Check Performance
# 10_000 calls of each variant on the same inputs and output buffer.

time000, time001, time002 = (
    timeit(stmt, number = 10_000, globals = globals())
    for stmt in (
        "GenSmoothStep000(vX, lowVal, highVal, vY, rollOffWidth = rollOffWidth)",
        "GenSmoothStep001(vX, lowVal, highVal, vY, rollOffWidth = rollOffWidth)",
        "GenSmoothStep002(vX, lowVal, highVal, vY, rollOffWidth = rollOffWidth)",
    )
)

print(time000, time001, time002, sep = "\n")

输出为(在我的计算机上,Intel Core i7-6800K):

0.23776450000877958
0.23713289998704568
0.23025239999697078

可以看到,三个版本的耗时仍然非常接近。

xzlaal3s

xzlaal3s1#

如果我理解正确(IIUC),你只是想把两个 smoothstep 组合起来:

import matplotlib.pyplot as plt
import numpy as np
from numba import njit

@njit
def smoothstep(edge0, edge1, x):
    """GLSL-style smoothstep of x between edge0 and edge1.

    x may be a NumPy array (evaluated elementwise). For edge0 < edge1 the
    result is 0 at/below edge0, 1 at/above edge1 and 3t^2 - 2t^3 in between,
    with t = (x - edge0) / (edge1 - edge0).
    """
    t = (x - edge0) / (edge1 - edge0)
    t = np.clip(t, 0, 1)  # saturate so the cubic stays at 0 / 1 outside the edges
    return t * t * (3.0 - 2.0 * t)

# Rectangle as a product of a rising smoothstep over [0, lowVal] and a
# falling one over [highVal, 1].
numGridPts = 1000

lowVal = 0.15
highVal = 0.75

vX = np.linspace(0, 1, numGridPts)
vRise = smoothstep(0, lowVal, vX)
vFall = 1 - smoothstep(highVal, 1, vX)
vY = vRise * vFall

plt.plot(vX, vY)
plt.show()

显示此图表:

编辑:新版本(思路仍然是组合两个 smoothstep),不分配临时数组,也避免了原来的 if/elif 分支链(并支持 rollOffWidth):

@njit
def clamp(x, lowerlimit=0.0, upperlimit=1.0):
    """Limit x to [lowerlimit, upperlimit].

    The lower limit is checked first, so it takes precedence when
    lowerlimit > upperlimit. A NaN x fails both comparisons and is returned
    as is.
    """
    result = x
    if x < lowerlimit:
        result = lowerlimit
    elif x > upperlimit:
        result = upperlimit
    return result

@njit
def smoothstep(edge0, edge1, x):
    """Scalar GLSL-style smoothstep of x between edge0 and edge1.

    For edge0 < edge1: t = (x - edge0) / (edge1 - edge0) is clamped to [0, 1]
    and 3t^2 - 2t^3 is returned, i.e. 0 at/below edge0 and 1 at/above edge1.

    When edge0 == edge1 (e.g. GenSmoothStep2 with lowVal = 0, where
    lowClip == lowVal) the ramp has zero width. The division would then raise
    ZeroDivisionError under Numba's default 'python' error model, so a hard
    step is returned instead: 0.0 for x < edge0, 1.0 otherwise (like GLSL
    step()).
    """
    # Depends only on the edges, so it is loop invariant inside a caller's loop.
    if edge1 == edge0:
        return 1.0 if x >= edge0 else 0.0
    x = clamp((x - edge0) / (edge1 - edge0), 0, 1)
    return x * x * (3.0 - 2.0 * x)

@njit
def GenSmoothStep2(
    vX: np.ndarray,
    lowVal: float,
    highVal: float,
    vY: np.ndarray,
    rollOffWidth: float = 0.1,
):
    """Fill vY in place with a smooth rectangle of vX (product of two smoothsteps).

    vY is 1 on [lowVal, highVal]. It rises over [lowClip, lowVal] and falls over
    [highVal, highClip], with lowClip = max(lowVal - rollOffWidth, 0) and
    highClip = min(highVal + rollOffWidth, 1).

    An edge at or beyond the domain boundary (lowVal <= 0 or highVal >= 1) is
    not applied. After clipping, its roll-off interval is either empty
    (lowClip == lowVal / highClip == highVal, where smoothstep would divide by
    zero) or reversed (lowVal < 0 / highVal > 1, where the ramp evaluates to 0
    for every x in [0, 1]).

    Parameters
    ----------
    vX : np.ndarray
        Sample positions; indexed as vX[i] for i < len(vX), so 1-D is assumed.
    lowVal, highVal : float
        Start and end of the plateau where vY == 1.
    vY : np.ndarray
        Output buffer with the same length as vX, written in place.
    rollOffWidth : float
        Width of each transition before clipping to [0, 1].
    """
    lowClip = max(lowVal - rollOffWidth, 0)
    highClip = min(highVal + rollOffWidth, 1)

    # Loop-invariant flags, so the checks can be hoisted out of the loop.
    hasLowEdge = lowVal > 0.0
    hasHighEdge = highVal < 1.0

    for i in range(len(vX)):
        y = 1.0
        if hasLowEdge:
            y = smoothstep(lowClip, lowVal, vX[i])
        if hasHighEdge:
            y *= 1 - smoothstep(highVal, highClip, vX[i])
        vY[i] = y

基准测试:

numGridPts = 1000

lowVal = 0.45
highVal = 0.65
rollOffWidth = 0.3

vX = np.linspace(0, 1, numGridPts)
vY = np.empty_like(vX)

# warm up numba: the first call of each function compiles it, and that
# compile time must not end up in the timings below.
GenSmoothStep(vX, lowVal, highVal, vY, rollOffWidth=rollOffWidth)
GenSmoothStep2(vX, lowVal, highVal, vY, rollOffWidth=rollOffWidth)

from timeit import timeit

# 10_000 calls of each variant, same inputs, same output buffer.
t1, t2 = (
    timeit(stmt, number=10_000, globals=globals())
    for stmt in (
        "GenSmoothStep(vX, lowVal, highVal, vY, rollOffWidth=rollOffWidth)",
        "GenSmoothStep2(vX, lowVal, highVal, vY, rollOffWidth=rollOffWidth)",
    )
)

print(t1, t2, sep="\n")

在我的计算机(AMD 5700X)上的输出:

0.010718749952502549
0.007840660051442683

因此,新函数比原函数快约 36%(0.01072 / 0.00784 ≈ 1.37 倍,即耗时减少约 27%)。

相关问题