python 如何确保timerseriesAI/tsai使用GPU

0yycz8jy  于 2023-03-16  发布在  Python
关注(0)|答案(1)|浏览(184)

我正在使用tsai 0.3.5进行时间序列分类。但是训练一个纪元需要花费不寻常的时间。有人能告诉我如何确保tsai使用GPU而不是CPU吗?
请在下面找到我的代码.

import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
from pickle import load
from multiprocessing import Process
import numpy as np
from tsai.all import *
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve

num_datesets = 1
for dataset_idx in range(num_datesets):
    X_train = load(open(r"X_train_"+str(dataset_idx)+".pkl", 'rb'))
    y_train = load(open(r"y_train_"+str(dataset_idx)+".pkl", 'rb'))
    X_test = load(open(r"X_test_"+str(dataset_idx)+".pkl", 'rb'))
    y_test = load(open(r"y_test_"+str(dataset_idx)+".pkl", 'rb'))
    print("dataset loaded")

    learn = TSClassifier(X_train, y_train, arch=InceptionTimePlus, arch_config=dict(fc_dropout=0.5))

    print("training started")
    learn.fit_one_cycle(5, 0.0005)
    learn.export("tsai_"+str(dataset_idx)+".pkl") 
    
    probas, target, preds = learn.get_X_preds(X_test, y_test)
    precision, recall, thresholds = precision_recall_curve(target, probas)
    plt.clf()
    plt.fill_between(recall, precision)
    plt.ylabel("Precision")
    plt.xlabel("Recall")
    plt.title("tsai_"+str(dataset_idx)+"_precision_recall_curve")
    plt.savefig("tsai_"+str(dataset_idx)+".png")
    plt.show()
kq0g1dla

kq0g1dla1#

它只使用GPU。增加批处理大小似乎可以解决这个问题。Tsai中的默认批处理大小非常小。

相关问题