python tqdm不是继续前进,而是重新开始

4xrmg8kj  于 2023-09-29  发布在  Python
关注(0)|答案(2)|浏览(164)

我有以下工作代码。我想打印所有的第5000个项目的损失每一个(% 5000)在一个新的行**,但进度条应继续其方式一次和所有(应只打印一次)**,并显示总进度。我应该如何修改代码?

import torch
from math import tanh,cos
from tqdm import tqdm
from time import sleep

batch, dim_in, dim_h, dim_out = 1, 100, 10, 1

input_X = torch.randn(batch, dim_in)
output_Y = torch.randn(batch, dim_out)

SGD_model = torch.nn.Sequential(
    torch.nn.Linear(dim_in, dim_h),
    torch.nn.Tanh(),
    torch.nn.Linear(dim_h, dim_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

rate_learning = 0.01

optim = torch.optim.SGD(SGD_model.parameters(), lr=rate_learning, momentum=0.01)
    
for values in tqdm(range(1000)):
    pred_y = SGD_model(input_X)
    loss = loss_fn(pred_y, output_Y)
    if values % 100 == 0:
        print(values, loss.item())  
    optim.zero_grad()
    loss.backward()
    optim.step()

三个进度条而不是一个:

chatGPT 4给了我这个代码,但仍然有4个进度条:

import torch
from math import tanh
from time import sleep

batch, dim_in, dim_h, dim_out = 32, 10, 5, 1

input_X = torch.randn(batch, dim_in)
output_Y = torch.randn(batch, dim_out)

SGD_model = torch.nn.Sequential(
    torch.nn.Linear(dim_in, dim_h),
    torch.nn.Tanh(),
    torch.nn.Linear(dim_h, dim_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

rate_learning = 0.0001

optim = torch.optim.SGD(SGD_model.parameters(), lr=rate_learning, momentum=0.4)

# Define the training loop with a range
epochs = 1000
for epoch in range(epochs):
    progress = (epoch + 1) / epochs
    bar_length = 50
    block = int(round(bar_length * progress))
    text = f"Epoch {epoch+1}/{epochs} [{'#' * block + '-' * (bar_length - block)}] {100 * progress:.2f}%"
    pred_y = SGD_model(input_X)
    loss = loss_fn(pred_y, output_Y)
    if epoch % 100 == 0:
        text += f', Loss: {loss.item()}'
    print(text, end='\r')
    optim.zero_grad()
    loss.backward()
    optim.step()

print("\nWeights:", SGD_model[0].weight)
print("Bias:", SGD_model[0].bias)
print("Bias:", SGD_model[2].bias)
cczfrluj

cczfrluj1#

而不是print( ... ),更喜欢tqdm.write( ... )
另外,考虑用logger替换打印。然后,您可以依赖tqdm上下文管理器。从文档中:

with logging_redirect_tqdm():
        for i in trange(9):
            if i == 4:
                LOG.info("console logging redirected to `tqdm.write()`")

作为与原始代码的上下文差异,“use .write()!“归结为:

@@ -85,3 +85,3 @@ for values in tqdm(range(1000)):
     if values % 100 == 0:
-        print(values, loss.item())
+        tqdm.write(f"{values}, {loss.item()}")
     optim.zero_grad()

此外,如果您以绑定的进度条变量开始循环

bar = tqdm(range(1000))
for values in bar:

你可以在循环中使用bar.write( ... )。或者更好的方法是调整bar.set_description( ... )bar.bar_format。这可以防止旧的损失数据从屏幕顶部滚动出来。

ds97pgxw

ds97pgxw2#

考虑使用Enlighten而不是tqdm,因为它本身处理打印。
对于range(),还有一个差一的问题,因为要从0到999,而不是从1到1000。
只是更换

for values in tqdm(range(1000)):

import enlighten

manager = enlighten.get_manager()
pbar = manager.counter(total=1000)

for values in pbar(range(1, 1001)):

你会得到这样的东西。不需要对输出做任何特殊的处理。

100 2.220446049250313e-16
200 2.220446049250313e-16
300 2.220446049250313e-16
400 2.220446049250313e-16
500 2.220446049250313e-16
600 2.220446049250313e-16
700 2.220446049250313e-16
800 2.220446049250313e-16
900 2.220446049250313e-16
1000 2.220446049250313e-16

100%|███████████████████████████████████████| 1000/1000 [00:00<00:00, 8639.66/s]

如果您想自定义格式或添加颜色,请查看documentation

相关问题