Chinese-CLIP load_from_name 加入 flash-attn 支持

5n0oy7gb  于 4个月前  发布在  其他
关注(0)|答案(4)|浏览(49)

感谢你提供的代码实现,它对我的帮助很大。然而,在使用load_from_name函数时,我发现它并不支持flash-attn,因此我自己实现了这部分代码。尽管它可以正常运行,但我仍然不确定实现是否正确。如果作者有空,能否帮我检查一下这个实现是否正确呢?如果是正确的,作者可以将我的实现加入到仓库中。非常感谢!

以下是代码片段:

------- ps: add use_flash_attention keyword -------

def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
download_root: str = None, vision_model_name: str = None, text_model_name: str = None,
input_resolution: int = None, use_flash_attention: bool = False):
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
elif os.path.isfile(name):
assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
model_path = name
model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

with open(model_path, 'rb') as opened_file:
    # loading saved checkpoint
    checkpoint = torch.load(opened_file, map_location="cpu")

model = create_model(model_name, checkpoint, use_flash_attention=use_flash_attention)
if str(device) == "cpu":
    model.float()
else:
    model.to(device)
return model, image_transform(model_input_resolution)
------- ps: convert flash_attention weight -------

def create_model(model_name, checkpoint=None, use_flash_attention=False):
vision_model, text_model = model_name.split('@')
# Initialize the model.
vision_model_config_file = Path(
file).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
print('Loading vision model config from', vision_model_config_file)
assert os.path.exists(vision_model_config_file)

text_model_config_file = Path(
    __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
print('Loading text model config from', text_model_config_file)
assert os.path.exists(text_model_config_file)

with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
    model_info = json.load(fv)
    for k, v in json.load(ft).items():
        model_info[k] = v
if isinstance(model_info['vision_layers'], str):
    model_info['vision_layers'] = eval(model_info['vision_layers'])
print('Model info', model_info)
if use_flash_attention:
    model_info['use_flash_attention'] = use_flash_attention
model = CLIP(**model_info)
convert_weights(model)
        
if checkpoint:
    if use_flash_attention:
        sd = checkpoint["state_dict"]
        sd = {k: v for k, v in sd.items() if "bert.pooler" not in k}
        if next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
        # Resize the positional embedding by interpolation, if needed
        resize_pos_embed(sd, model, prefix="module.")
        # Adapt flash attention
        sd = convert_state_dict(sd)
        # Load the state dict
    else:
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
    model.load_state_dict(sd)
return model
oxiaedzo

oxiaedzo1#

您好,目前在启动flash-attn训练时,保存的ckpt格式与不启动是完全一致的。因此用flash-attn训练得到的ckpt应该是直接可以load进来的,您可以先尝试一下。

Chinese-CLIP/cn_clip/training/train.py
第309行
| | "state_dict": model.state_dict() if not args.use_flash_attention else convert_state_dict(model.state_dict()), |

zed5wv10

zed5wv102#

感谢您的回复,但是您似乎误解了我的意思。
我想做的事情是,在我自己编写的代码段中,直接调用load_from_name函数来获取模型,并且该模型具有直接切换为flash-attn模式的功能。然而,目前的load_from_name方法并没有提供flash-attn的选项。

a6b3iqyw

a6b3iqyw3#

我明白你的意思~我理解目前代码中定义的flash-attn格式只适用于Chinese-CLIP这个项目,而Chinese-CLIP训练得到的模型会自动将flash-attn模型转化为正常模式,所以我想知道目前是在什么情况下需要加载一个flash-attn格式的模型呢。

6qftjkof

6qftjkof4#

@DtYXs

例如,我想将您训练好的chinese-clip应用于其他下游任务。

我可能会有一个针对该下游任务的baseline代码,如果我想更换backbone,希望通过调用load_from_name函数创建一个clip的backbone。如果我进一步想要微调clip,我认为添加一个flash-attn可以更好地帮助我加速代码。这样~

也就是说,我将您的仓库视为一个包来使用,那么实际上我只需要注意load_from_name这个函数。如果有flash-attn的支持,可能会帮助更多人将其应用于下游任务中?

相关问题