SD 的 U-Net 既用到了自注意力,也用到了交叉注意力。自注意力用于图像特征自己内部信息聚合。交叉注意力用于让生成图像对齐文本,其 Q 来自图像特征,K, V 来自文本编码。
由于自注意力其实可以看成一种特殊的交叉注意力,我们可以把自注意力的 K, V 替换成来自另一幅参考图像的特征。这样,扩散模型的生成图片会既和原本要生成的图像相似,又和参考图像相似。当然,用来替换的特征必须和原来的特征「格式一致」,不然就生成不了有意义的结果了。
什么叫「格式一致」呢?我们知道,扩散模型在采样时有很多步,U-Net 中又有许多自注意力层。每一步时的每一个自注意力层的输入都有自己的「格式」。也就是说,如果你要把某时刻某自注意力层的 K, V 替换,就得先生成参考图像,用生成参考图像过程中此时刻此自注意力层的输入替换,而不能用其他时刻或者其他自注意力层的。
attn_processor_dict = {} for k in unet.attn_processors.keys(): if we_want_to_modify(k): attn_processor_dict[k] = MyAttnProcessor() else: attn_processor_dict[k] = AttnProcessor()
unet.set_attn_processor(attn_processor_dict)
实现帧间注意力处理类
熟悉了
AttentionProcessor
类的相关内容,我们来编写自己的帧间注意力处理类。在处理第一帧时,该类的行为不变。对于之后的每一帧,该类的 K, V 输入会被替换成视频第一帧和上一帧的输入在序列长度维度上的拼接结果,即:
你是否会感到疑惑:为什么 K, V 的序列长度可以修改?别忘了,在注意力计算中,Q, K, V 的形状分别是:
。注意力计算只要求 K,V 的序列长度 相同,并没有要求 Q, K 的序列长度相同。
if encoder_hidden_states is None: # Is self attention cross_map = torch.cat( (self.first_maps[t], self.prev_maps[t]), dim=1) res = super().__call__(attn, hidden_states, cross_map, **kwargs)
else: # Is cross attention res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)
if encoder_hidden_states is None: # Is self attention if self.state == FIRST_FRAME: res = super().__call__(attn, hidden_states, cross_map, **kwargs) # update maps else: cross_map = torch.cat( (self.first_maps[t], self.prev_maps[t]), dim=1) res = super().__call__(attn, hidden_states, cross_map, **kwargs) # update maps
else: # Is cross attention res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)
return resedit(frames[0]) set_attn_state(SUBSEQUENT_FRAMES) for i in range(1, len(frames)): edit(frames[i])edit(frames[0]) set_attn_state(SUBSEQUENT_FRAMES) for i in range(1, len(frames)): edit(frames[i])
if encoder_hidden_states is None: # Is self attention
if self.attn_state.state == AttnState.STORE: res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs) else: cross_map = torch.cat( (self.first_maps[t], self.prev_maps[t]), dim=1) res = super().__call__(attn, hidden_states, cross_map, **kwargs) else: # Is cross attention res = super().__call__(attn, hidden_states, encoder_hidden_states, **kwargs)
return res
到目前为止,假设已经维护好了之前的输入,我们的注意力处理类能执行两种不同的行为了。现在,我们来实现之前输入的维护。使用之前的注意力输入时,我们其实需要知道当前的时刻
t
。当前的时刻也算是另一个状态,最好是也在状态管理类里维护。但为了简化我们的代码,我们可以偷懒让每个处理类自己维护当前时刻。具体做法是:如果知道了去噪迭代的总时刻数,我们就可以令当前时刻从0开始不断自增,直到最大时刻时,再重置为0。加入了时刻处理及之前输入维护的完整代码如下: