是否有其他替代torch.Generator()的方法来支持MPS设备类型?(M1 Mac、PyTorch)

btqmn9zl  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(163)

我正在运行python代码在我的m1 mac上实现稳定扩散,并在我的text 2 img函数中得到此错误。我知道Pytorch最近开始支持m1 GPU。我得到此错误(RuntimeError:torch.Generator()api不支持设备类型MPS。)我把代码放在下面,并突出显示给我错误的行。如果能得到任何帮助,我将不胜感激,谢谢!

def txt2img(prompt, width, height, guidance_scale, steps, seed): 
  global pipe, pipe_type 

  if pipe_type != 'txt2img': 
    pipe = None 
    clear_memory() 

    pipe_type = 'txt2img' 
    pipe = StableDiffusionPipeline.from_pretrained( 
      "CompVis/stable-diffusion-v1-4", 
      revision="fp16", 
      torch_dtype=torch.float16,
      use_auth_token=YOUR_TOKEN # use huggingface token for private model
    ).to("mps") 

  seed = random.randint(0, 2**32) if seed == -1 else seed 
  generator = torch.Generator(device='mps').manual_seed(int(seed)) 

  pipe.enable_attention_slicing() 
  with autocast("mps"): 
    image = pipe(prompt=prompt,
                 height=height, width=width,
                 num_inference_steps=steps, guidance_scale=guidance_scale, 
                 generator=generator).images[0] 

  return [[image], seed]

错误所指涉的主要程式码行如下:generator = torch.Generator(device='mps').manual_seed(int(seed))

tvokkenx

tvokkenx1#

由于您只需要一个随机数,因此只需在CPU中生成即可:

generator = torch.Generator().manual_seed(int(seed))
ruarlubt

ruarlubt2#

如果我使用torch.has_mps.manual_seed(SEED),它会显示AttributeError:'bool'对象没有'manual_seed'属性

相关问题