感谢你提供的代码实现,它对我的帮助很大。然而,在使用load_from_name
函数时,我发现它并不支持flash-attn
,因此我自己实现了这部分代码。尽管它可以正常运行,但我仍然不确定实现是否正确。如果作者有空,能否帮我检查一下这个实现是否正确呢?如果是正确的,作者可以将我的实现加入到仓库中。非常感谢!
以下是代码片段:
------- ps: add use_flash_attention keyword -------
def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
download_root: str = None, vision_model_name: str = None, text_model_name: str = None,
input_resolution: int = None, use_flash_attention: bool = False):
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
elif os.path.isfile(name):
assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
model_path = name
model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, 'rb') as opened_file:
# loading saved checkpoint
checkpoint = torch.load(opened_file, map_location="cpu")
model = create_model(model_name, checkpoint, use_flash_attention=use_flash_attention)
if str(device) == "cpu":
model.float()
else:
model.to(device)
return model, image_transform(model_input_resolution)
------- ps: convert flash_attention weight -------
def create_model(model_name, checkpoint=None, use_flash_attention=False):
vision_model, text_model = model_name.split('@')
# Initialize the model.
vision_model_config_file = Path(
file).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
print('Loading vision model config from', vision_model_config_file)
assert os.path.exists(vision_model_config_file)
text_model_config_file = Path(
__file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
print('Loading text model config from', text_model_config_file)
assert os.path.exists(text_model_config_file)
with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
model_info = json.load(fv)
for k, v in json.load(ft).items():
model_info[k] = v
if isinstance(model_info['vision_layers'], str):
model_info['vision_layers'] = eval(model_info['vision_layers'])
print('Model info', model_info)
if use_flash_attention:
model_info['use_flash_attention'] = use_flash_attention
model = CLIP(**model_info)
convert_weights(model)
if checkpoint:
if use_flash_attention:
sd = checkpoint["state_dict"]
sd = {k: v for k, v in sd.items() if "bert.pooler" not in k}
if next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
# Resize the positional embedding by interpolation, if needed
resize_pos_embed(sd, model, prefix="module.")
# Adapt flash attention
sd = convert_state_dict(sd)
# Load the state dict
else:
sd = checkpoint["state_dict"]
if next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
model.load_state_dict(sd)
return model
4条答案
按热度按时间oxiaedzo1#
您好,目前在启动flash-attn训练时,保存的ckpt格式与不启动是完全一致的。因此用flash-attn训练得到的ckpt应该是直接可以load进来的,您可以先尝试一下。
Chinese-CLIP/cn_clip/training/train.py
第309行
| | "state_dict": model.state_dict() if not args.use_flash_attention else convert_state_dict(model.state_dict()), |
zed5wv102#
感谢您的回复,但是您似乎误解了我的意思。
我想做的事情是,在我自己编写的代码段中,直接调用
load_from_name
函数来获取模型,并且该模型具有直接切换为flash-attn模式的功能。然而,目前的load_from_name
方法并没有提供flash-attn的选项。a6b3iqyw3#
我明白你的意思~我理解目前代码中定义的flash-attn格式只适用于Chinese-CLIP这个项目,而Chinese-CLIP训练得到的模型会自动将flash-attn模型转化为正常模式,所以我想知道目前是在什么情况下需要加载一个flash-attn格式的模型呢。
6qftjkof4#
@DtYXs
例如,我想将您训练好的chinese-clip应用于其他下游任务。
我可能会有一个针对该下游任务的baseline代码,如果我想更换backbone,希望通过调用load_from_name函数创建一个clip的backbone。如果我进一步想要微调clip,我认为添加一个flash-attn可以更好地帮助我加速代码。这样~
也就是说,我将您的仓库视为一个包来使用,那么实际上我只需要注意load_from_name这个函数。如果有flash-attn的支持,可能会帮助更多人将其应用于下游任务中?