如何将Matlab函数矢量化?

ctehm74n  于 2023-04-06  发布在  Matlab
关注(0)|答案(1)|浏览(175)

这个问题是来自超级用户的migrated,因为它可以在Stack Overflow. Migrated上回答5天前。
我试图通过矢量化来加速Matlab函数。该函数看起来像这样:

function tradePL = tradepl(signals, periodPL)
tradePL = [];
PL = 0;
for i = 2:length(signals)-1
    if signals(i) == 0
        if signals(i-1) ~= 0          % trade ends
            PL = PL + periodPL(i+1);  
            tradePL = [tradePL,PL];  
            PL = 0;
        end
    else
        if signals(i-1) == signals(i)  || signals(i-1) == 0
            PL = PL + periodPL(i+1);
        else
            tradePL = [tradePL, PL];
            PL = periodPL(i+1);
        end
    end
end

其工作原理如下:

signals = {0, 1, 0, -1, -1, -1, 1, 1, 0, 0};
periodPL = {0, 0, -0.0150, 3.0000, 0.9850, -0.0150, 1.0000, 
   1.0000, -3.0150, 0};
tradePL = tradepl[signals, periodPL]

{2.985,1.97,-2.015}

该函数试图做的是将周期性PL分组并总计为交易,这些交易跨越不同长度的周期。
周期t的交易由周期t-1的信号触发,当信号的符号改变时终止。在给定的例子中,信号中有三个(对)符号变化,因此有三笔交易。我想分组并添加periodPL值,以便每组的总和是相应交易的PL。

3pvhb19x

3pvhb19x1#

我使用cumsum对你的代码进行了向量化,如下所示:

function tradePL = tradepl_vec(signals, periodPL)
    % Get trigger indices for each trade
    trigger_idx_a = find(signals(2:end)~=0  & signals(1:end-1)==0) + 1; % start trade
    trigger_idx_b = find(signals(2:end)~=0 & abs(diff(signals))==2) + 1; % start and end trade
    trigger_idx_c = find(signals(2:end-1)==0  & signals(1:end-2)~=0) + 2; % end trade
    
    % extract start and end indices for cumultative sums
    start_idx = sort([trigger_idx_a, trigger_idx_b]);
    end_idx = sort([trigger_idx_b, trigger_idx_c]);
    % catch additional start and end indices
    if end_idx(1) <= start_idx(1)
        start_idx = [2, start_idx];
    end
    if length(start_idx) ~= length(end_idx)
        start_idx = start_idx(1:end-1);
    end
    % calculate sums of grouped trades
    periodPL_sums = cumsum(periodPL);
    tradePL = periodPL_sums(end_idx) - periodPL_sums(start_idx);
end

我从你的if-语句中提取了三个开始和/或结束交易的条件:

  • 当信号不是0并且之前是0时,交易在下一个指数开始
  • 当信号从-1切换到1或从1切换到-1时,交易在当前指数结束,新的交易在下一个指数开始
  • 当信号为0且在下一个指数交易结束前不为0时

这将计算给定示例的期望输出,并对我能想到的一些其他序列执行相同的操作。
注意:代码可能会在signals向量的末尾找到一个额外的开始索引,这个索引被忽略了,或者在代码的开头找到一个结束索引,其中给定的代码暗示了一个开始索引,所以我也添加了它。
这里显示了periodPLsignals的交易组比较。
Plot of signals and grouped periodPL
我不确定这是否是预期的行为(我希望组向左移动一个索引),但这是给定代码的行为,所以我希望,我不太理解代码背后的意图。如果不是,可以通过更改trigger_idx_变量行中的加法来调整移位。

编辑:我不得不纠正条件,检查第一个结束索引是否在第一个开始索引之前。此外,您的代码没有找到从第二个到最后一个到最后一个索引的零交叉,如果有意,trigger_idx_b应该只考虑end-1的索引而不是end。关于代码的速度,对于长度为10^6的随机信号,上面的函数比原始函数快7倍,比注解中建议的优化方法快3.5/3倍,使用tradePL(end+1) = PL;或预分配tradePL大小为signals并删除未使用的元素tradePL(j:end) = [];。但是,由于即使是您的原始代码在我的笔记本电脑上也只需要半秒多一点的时间,我不知道速度是否应该是您最关心的问题,正如评论中所指出的那样。

相关问题