PyTorch: Stable Diffusion running in a FastAPI container does not release GPU memory

czq61nw1 · posted on 2022-11-09

I'm running Stable Diffusion inside a FastAPI Docker container. It works correctly, but after several inference calls I noticed that the GPU's vRAM fills up and inference fails. It is as if memory is not released right after inference. Is there a way to force it to be released?
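One way to see the growth is to log PyTorch's allocator counters after each call, for example (using the log and torch objects from the script below):

log.info(f'allocated: {torch.cuda.memory_allocated() / 2**20:.0f} MiB')
log.info(f'reserved:  {torch.cuda.memory_reserved() / 2**20:.0f} MiB')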

Here is the script in main.py:

import logging
import os
import random
import time
import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional

# Configure default logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

# Load Stable Diffusion model

log.info('Load Stable Diffusion model')
model_path = './models/stable-diffusion-v1-4'
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    revision='fp16',
    torch_dtype=torch.float16
)

# Move pipeline to GPU for faster inference

pipe = pipe.to('cuda')
pipe.enable_attention_slicing()

# Declare inputs and outputs data types for the API endpoint

class Payload(BaseModel):
    prompt: str                 # String of text used to generate the images
    num_images: int = 1           # Number of images to generate
    height: int = 512             # Height of the generated images, in pixels
    width: int = 512              # Width of the generated images, in pixels
    seed: Optional[int] = None    # Random integer used as a seed to guide the image generator
    num_steps: int = 40           # Number of inference steps; more steps improve results at the cost of slower inference
    guidance_scale: float = 8.5   # Pushes generation to match the prompt more closely; 7 to 8.5 gives good results, and larger values match more closely but are less diverse

class Response(BaseModel):
    images: List[str]
    nsfw_content_detected: List[bool]
    prompt: str
    num_images: int
    height: int
    width: int
    seed: int
    num_steps: int
    guidance_scale: float

# Create FastAPI app

log.info('Start API')
app = FastAPI(title='Stable Diffusion')
app.mount("/static", StaticFiles(directory="./static"), name="static") # Mount folder to expose generated images

# Declare imagine endpoint for inference

@app.post('/imagine', response_model=Response, description='Runs inferences with Stable Diffusion.')
def imagine(payload: Payload, request: Request):
    """The imagine function generates the /imagine endpoint and runs inferences"""

    try:
        # Check payload
        log.info(f'Payload: {payload}')

        # Default seed with a random integer if it is not provided by user
        if payload.seed is None:
            payload.seed = random.randint(-999999999, 999999999)
        generator = torch.Generator('cuda').manual_seed(payload.seed)

        # Create multiple prompts according to the number of images
        prompt = [payload.prompt] * payload.num_images

        # Run inference on GPU
        log.info('Run inference')
        with torch.autocast('cuda'):
            result = pipe(
                prompt=prompt,
                height=payload.height,
                width=payload.width,
                num_inference_steps=payload.num_steps,
                guidance_scale=payload.guidance_scale,
                generator=generator
            )
        log.info('Inference completed')

        # Save images
        images_urls = []
        for image in result.images:
            image_name = str(time.time()).replace('.', '') + '.png'
            image_path = os.path.join('static', image_name)
            image.save(image_path)
            image_url = request.url_for('static', path=image_name)
            images_urls.append(image_url)

        # Build response object
        response = {}
        response['images'] = images_urls
        response['nsfw_content_detected'] = result.nsfw_content_detected
        response['prompt'] = payload.prompt
        response['num_images'] = payload.num_images
        response['height'] = payload.height
        response['width'] = payload.width
        response['seed'] = payload.seed
        response['num_steps'] = payload.num_steps
        response['guidance_scale'] = payload.guidance_scale

        return response

    except Exception as e:
        log.error(repr(e))
        raise HTTPException(status_code=500, detail=repr(e))

t98cgbkg 1#

I was able to solve this by adding the snippet below after running inference... I really think this should be added to the various examples in the documentation. Credit goes to my colleague, who found the snippet in the Stable Diffusion WebUI repository.

# Release cached GPU memory back to the driver after inference
if torch.cuda.is_available():
    torch.cuda.empty_cache()   # return unused cached blocks held by PyTorch's allocator
    torch.cuda.ipc_collect()   # clean up CUDA memory freed via inter-process (IPC) handles
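A minimal sketch of one way to wire this into the endpoint above: wrap the snippet in a helper and call it from a finally block, so the cache is cleared even when a request fails. The torch_gc helper name and the extra gc.collect() call are illustrative additions, not part of the snippet above:

import gc

def torch_gc():
    # Drop unreferenced Python objects first so their CUDA tensors can be freed,
    # then return the caching allocator's unused blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

@app.post('/imagine', response_model=Response, description='Runs inferences with Stable Diffusion.')
def imagine(payload: Payload, request: Request):
    try:
        ...  # run the pipeline and build the response exactly as in main.py above
    finally:
        # Executes after every request, including ones that raised,
        # so cached GPU memory is released between calls.
        torch_gc()

Note that empty_cache() can only return memory the allocator has already freed; vRAM held by live references (for example, a result object still in scope) is released only once those references go away.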
