我正在使用tsai 0.3.5进行时间序列分类。但是训练一个纪元需要花费不寻常的时间。有人能告诉我如何确保tsai使用GPU而不是CPU吗?
请在下面找到我的代码.
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
from pickle import load
from multiprocessing import Process
import numpy as np
from tsai.all import *
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
num_datesets = 1
for dataset_idx in range(num_datesets):
X_train = load(open(r"X_train_"+str(dataset_idx)+".pkl", 'rb'))
y_train = load(open(r"y_train_"+str(dataset_idx)+".pkl", 'rb'))
X_test = load(open(r"X_test_"+str(dataset_idx)+".pkl", 'rb'))
y_test = load(open(r"y_test_"+str(dataset_idx)+".pkl", 'rb'))
print("dataset loaded")
learn = TSClassifier(X_train, y_train, arch=InceptionTimePlus, arch_config=dict(fc_dropout=0.5))
print("training started")
learn.fit_one_cycle(5, 0.0005)
learn.export("tsai_"+str(dataset_idx)+".pkl")
probas, target, preds = learn.get_X_preds(X_test, y_test)
precision, recall, thresholds = precision_recall_curve(target, probas)
plt.clf()
plt.fill_between(recall, precision)
plt.ylabel("Precision")
plt.xlabel("Recall")
plt.title("tsai_"+str(dataset_idx)+"_precision_recall_curve")
plt.savefig("tsai_"+str(dataset_idx)+".png")
plt.show()
1条答案
按热度按时间kq0g1dla1#
它只使用GPU。增加批处理大小似乎可以解决这个问题。Tsai中的默认批处理大小非常小。