o
    gL                     @   sx   d dl mZ d dlmZmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZmZ d dlmZ eeZG dd	 d	eZdS )
    )	getLogger)TupleUnionN)Fusion)NumpyHelper)	NodeProtoTensorProtohelper)	OnnxModelc                       s*  e Zd ZdZdedededededef fdd	Zd3dededefddZ	dd Z
	
d3dedededeeef fddZdedededededededeedf fddZdedededededededeedf fd d!Zd"d# Zd$d% Zd&d' Zd(d) Zd*d+ Zd,efd-d.Zd/d0 Zd1d2 Z  ZS )4FusionAttentionUnetzB
    Fuse Attention subgraph of UNet into one Attention node.
    modelhidden_size	num_headsis_cross_attentionenable_packed_qkvenable_packed_kvc                    sL   t  ||r
|r
dnddg || _|| _|| _|| _|| _d| _d| _d S )N	AttentionMultiHeadAttentionLayerNormalizationT)	super__init__r   r   r   r   r   num_heads_warninghidden_size_warning)selfr   r   r   r   r   r   	__class__ e/var/www/visachat/venv/lib/python3.10/site-packages/onnxruntime/transformers/fusion_attention_unet.pyr      s   	
zFusionAttentionUnet.__init__F	reshape_q	is_torch2returnc                 C   s   d}|r5| j |d}|r4|jdkr4t|jdkr4| j |jd }t|tjr4t	|j
dgkr4t|}n| j |jd }t|tjrRt	|j
dgkrRt|d }t|tr]|dkr]|S dS )zDetect num_heads from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
        Returns:
            int: num_heads, or 0 if not found
        r      Concat      )r   
get_parentop_typeleninputget_constant_value
isinstancenpndarraylistshapeint)r   r   r   r   reshape_parentq_shape_valuer   r   r   get_num_heads3   s   	z!FusionAttentionUnet.get_num_headsc                 C   s*   | j |jd }|rt|jd S dS )zDetect hidden_size from LayerNormalization node.
        Args:
            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
        Returns:
            int: hidden_size, or 0 if not found
        r$   r   )r   get_initializerr(   r   to_arrayr.   )r   layernorm_nodelayernorm_biasr   r   r   get_hidden_sizeO   s   z#FusionAttentionUnet.get_hidden_sizer5   c                 C   s   |  ||}|dkr| j}| jdkr*|| jkr*| jr*td| j d| d d| _| |}|dkr6| j}| jdkrS|| jkrS| jrStd| j d| d d| _||fS )aF  Detect num_heads and hidden_size.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
            layernorm_node (NodeProto): LayerNormalization node before Q, K, V
        Returns:
            Tuple[int, int]: num_heads and hidden_size
        r   z--num_heads is z. Detected value is z. Using detected value.Fz--hidden_size is )r2   r   r   loggerwarningr7   r   r   )r   r   r5   r   r   r   r   r   r   get_num_heads_and_hidden_size\   s"   
z1FusionAttentionUnet.get_num_heads_and_hidden_sizeq_matmulk_matmulv_matmulr(   outputNc           %   
   C   s8  | j  }|r/|jd |ks|jd |ks|jd |kr.td|jd |jd |jd  dS n+|jd |ksG|jd |jd ksG|jd |krZtd|jd |jd |jd  dS |dkrq|| dkrqtd| d|  dS | j|jd }	| j|jd }
| j|jd }|	r|
r|sdS |	j}t|	}t|
}t|}td|j	 d	|j	 d
|j	 d|  |r|j	|j	ks|j	|j	krdS |j	d }|dkr||krt
d| d| dtt|j	dd }| jr| jd}|}|}|| }t||||||||||||g||d | }| jjddd}| j|d ||j	d |j	d g|d tjd|jd |d g|d g|d}| j| j|j< | j|d tjdgdd|d|gdd tjd|d |d g|d g|d d}| j| j|j< | j||g | j|||g ntj|||fdd}d| }| jd }| j|d! |||g|d n| jd}| jr|j	|j	krdS |j	d }|j	d }||ksJ |j	d }|j	d }|j	d }||kr||ksJ |}|}|| }t||||||||g||d" | }| jjdd#d}| j|d ||j	d |j	d g|d tjd|jd |d g|d g|d}| j| j|j< | j|d tjdgdd|d"|gdd tjd|d |d g|d$ g|d d}| j| j|j< | j||g | j||g tjd|gtj d%} d| }!| j|d& ||!g| d |r| js||d! |d& g}"n%|d g}"n| js|j!d |j!d |j!d |d& g}"n	|j!d |d$ g}"tj|r| jsd nd|"|g|d}#d'|#_"|#j#t$d(|g |r| jsd)nd*%| jrd+n| jrd,nd-}$| &|$ |#S ).  Create an Attention node.

        Args:
            q_matmul (NodeProto): MatMul node in fully connection for Q
            k_matmul (NodeProto): MatMul node in fully connection for K
            v_matmul (NodeProto): MatMul node in fully connection for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        r   RFor self attention, input hidden state for q and k/v shall be same. Got %s, %s, %sNXFor cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %sinput hidden size # is not a multiple of num of heads r!   qw= kw= vw= hidden_size=Input hidden size (,) is not same as weight dimension of q,k,v (:). Please provide a correct input hidden size or pass in 0r      MatMul
MatMul_QKVname_prefix_weightname	data_typedimsvals_outinputsoutputsrR   _reshape_shape   FrR   rS   rT   rU   rawReshape
_qkv_input_reshape)axisr   _qkv_weightr$   	MatMul_KV	_kv_inputdtype	_qkv_biascom.microsoftr   Attention (self attention)MultiHeadAttention ({})self attention with packed qkvcross attention with packed kvcross attention)'r   r(   r8   debugr   r3   rS   r   r4   r.   
ValueErrorr/   r+   prodr   create_node_namedstackreshapeadd_initializerr	   	make_nodethis_graph_namenode_name_to_graph_namerR   r   INT64nodes_to_addextendnodes_to_removestackr   zerosfloat32r>   domain	attributemake_attributeformatincrease_counter)%r   r;   r<   r=   r   r   r(   r>   is_self_attentionq_weightk_weightv_weight
float_typeqwkwvw
qw_in_sizeqw_out_sizeattention_node_namecnh
qkv_weightmatmul_node_namematmul_nodereshape_nodeqkv_weight_dim
kw_in_size
vw_in_sizekw_out_sizevw_out_size	kv_weightqkv_biasqkv_bias_dimattention_inputsattention_nodecounter_namer   r   r   create_attention_node~   sj  *0	


(
.	




2	
	z)FusionAttentionUnet.create_attention_nodeq_matmul_addk_matmul_addv_matmul_addc           F   
   C   s  | j  }| j|dd}	| j|dd}
| j|dd}| |}|du r'dS |\}}| |}|du r6dS |\}}| |}|du rEdS |\}}|r|	jd |ks`|
jd |ks`|jd |krstd|	jd |
jd |jd  dS |jd |ks|jd |ks|jd |krtd|jd |jd |jd  dS nV|	jd |ks|
jd |jd ks|
jd |krtd|	jd |
jd |jd  dS |jd |ks|jd |jd ks|
jd |krtd|jd |jd |jd  dS |dkr|| dkrtd| d	|  dS | j|	jd
 }| j|
jd
 }| j|jd
 }|r/|r/|s1dS |jdkr>td dS t	
|}t	
|}t	
|}td|j d|j d|j d|  |rP|j|jksr|j|jkrtdS |jd }|dkr||krtd| d| dtt|jd
d }| jrN| jd}|}|}|| } t|||| |||| |||| g||d |  }!| jjddd}"| j|"d tj|!jd |!jd
 g|!d tjd|
jd |"d g|"d g|"d}#| j| j|#j< |jd }$| j|$tjdgdd|| gdd | jjd d!d}%tjd |jd |$g|%d g|%d}&| j| j|&j< | jjd d"d}'tjd |jd |$g|'d g|'d}(| j| j|(j< | jjd d#d})tjd |jd |$g|)d g|)d}*| j| j|*j< | jjd$d%d}+tjd$|&jd |(jd |*jd g|+d g|+d},|,jt d&dg | j| j|,j< |,jd }-| j|-tjdgdd|d |  gdd | jjd d'd}.tjd |,jd |-g|.d g|.d}/| j| j|/j< | jjd(d)d}0tjd(|/jd |#jd g|0d g|0d}1| j| j|1j< |0d }2| j|2tjd*gdd|d| gdd tjd |1jd |2g|d+ g|0d, d}3| j| j|3j< | j!|#|&|(|*|,|/|1|3g | j"|	|
||||g ndS | jd}| j#r|j|jkrcdS |jd }4|jd }5|4|5kstJ |jd
 }|jd
 }6|jd
 }7||7kr|6|7ksJ |4}|}|6| } t|||| |||| g||d- |  }8| jjdd.d}"| j|"d tj|8jd |8jd
 g|8d tjd|
jd |"d g|"d g|"d}#| j| j|#j< |jd }9| j|9tjdgdd|| gdd | jjd d"d}'tjd |jd |9g|'d g|'d}(| j| j|(j< | jjd d#d})tjd |jd |9g|)d g|)d}*| j| j|*j< | jjd$d/d}:tjd$|(jd |*jd g|:d g|:d};|;jt d&dg | j| j|;j< |;jd }<| j|<tjdgdd|d- |  gdd | jjd d0d}=tjd |;jd |<g|=d g|=d}>| j| j|>j< | jjd(d1d}?tjd(|>jd |#jd g|?d g|?d}@| j| j|@j< |?d }2| j|2tjd*gdd|d-| gdd tjd |@jd |2g|d2 g|?d, d}3| j| j|3j< | j!|#|(|*|;|>|@|3g | j"|
|||g ndS tj$d|gtj%d3}Ad| }B| j|d4 tj|Bg|Ad |r7| js1dS |d+ g}Cn| j#s=dS |jd |d2 g}Ctj|rQ| jsQd5nd|C|g|d}Dd6|D_&|Djt d7|g |rp| jspd8nd9'| jrxd:n| j#r~d;nd<}E| (|E |DS )=r?   rL   r   Nr@   z_For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %srA   zeFor cross attention, input hidden state for LoRA q and k/v weights shall be different. Got %s, %s, %srB   rC   r!   
   zBweights are in fp16. Please run fp16 conversion after optimizationrD   rE   rF   rG   rH   rI   rJ   r   rK   rM   rN   rP   rQ   rV   rW   rZ   r#   Fr\   r^   Reshape_LoRA_QReshape_LoRA_KReshape_LoRA_Vr"   Concat_LoRA_QKVra   Reshape_LoRA_QKVAddAdd_Weights_QKVr[   r_   r`   r$   rc   Concat_LoRA_KVReshape_LoRA_KVAdd_Weights_KVrd   re   rg   r   rh   r   ri   rj   rk   rl   rm   ))r   r   match_parentmatch_lora_pathr(   r8   rn   r3   rS   r   r4   r.   ro   r/   r+   rp   r   rq   rr   rs   rt   r   FLOATr	   ru   rv   rw   rR   rx   r>   r   rz   r   ry   r{   r   r}   r~   r   r   r   )Fr   r   r   r   r   r   r(   r>   r   r;   r<   r=   q_lora_nodesq_lora_last_nodeq_lora_matmul_1k_lora_nodesk_lora_last_nodek_lora_matmul_1v_lora_nodesv_lora_last_nodev_lora_matmul_1r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   lora_weight_shape_tensor_nameq_lora_reshape_node_nameq_lora_reshape_nodek_lora_reshape_node_namek_lora_reshape_nodev_lora_reshape_node_namev_lora_reshape_nodeqkv_lora_concat_node_nameqkv_lora_concat_node'reshaped_lora_weights_shape_tensor_nameqkv_lora_reshaped_node_nameqkv_lora_reshaped_nodeadd_weights_node_nameadd_weights_nodeshape_tensor_namer   r   r   r   r   r    kv_lora_weight_shape_tensor_namekv_lora_concat_node_namekv_lora_concat_node*reshaped_kv_lora_weights_shape_tensor_namekv_lora_reshaped_node_namekv_lora_reshaped_nodeadd_kv_weights_node_nameadd_kv_weights_noder   r   r   r   r   r   r   r   create_attention_node_lora~  s  


*0
	



(
.

	






2

	

	z.FusionAttentionUnet.create_attention_node_lorac              	   C   s  |  |||r	d S | j|dd}|d u r | js | j|dd}|d u r&d S |jd }|| }d }|D ]}|jdkr>|} nq3|d u rEd S | ||pP| ||}	|	d ur|	\}
}}}}}}|}| |||
\}}|dkrtt	
d d S | j||||||jd |jd d}|d u rd S n]| ||p| ||}	|	d u rd S |	\}
}}}}}}|}| |||
\}}|dkrt	
d d S | j||||||jd |jd d}|d u rd S | |||
\}}|dkrt	
d d S | j| | j| j|j< | j||g d| _d S )Nr   r   r^   *fuse_attention: failed to detect num_headsr(   r>   T)fuse_a1111_fp16r   r   r   r>   r&   match_qkv_torch1match_qkv_torch2r:   r8   rn   r   match_qkv_torch1_loramatch_qkv_torch2_lorar   ry   appendrv   rw   rR   r{   rz   prune_graph)r   normalize_nodeinput_name_to_nodesoutput_name_to_nodenode_before_layernorm
root_inputchildren_nodesskip_addnode	match_qkvr   reshape_qkvtranspose_qkvr   matmul_qmatmul_kmatmul_vattention_last_nodeq_num_headsq_hidden_sizenew_nodematmul_add_qmatmul_add_kmatmul_add_vr   r   r   fuseR  s   


	
	

zFusionAttentionUnet.fusec              
   C   s|  |j d |kr	dnd}| j|g d|dddddg}|du r!dS |\}}}}}}| j|g dg d}	|	du r@td dS |	\}}}}
| j|g dg d	}|dur\|\}}}n| j|g d
g d}|durs|\}}}}ntd dS | j|g dg d}|du rtd dS |\}}}}| j|g dg d}|du rtd dS |\}}}}}d||||||
fS )z.Match Q, K and V paths exported by PyTorch 1.*r   r!   )r   rL   r^   	Transposer^   rL   Nr^   r   r^   rL   r!   r   r   r   &fuse_attention: failed to match v pathSoftmaxMulrL   r   r   r   r   r   r   rL   r   r   r   r   'fuse_attention: failed to match qk path&fuse_attention: failed to match q path)r   r^   r   r^   rL   r!   r   r   r   r   &fuse_attention: failed to match k pathFr(   r   match_parent_pathr8   rn   )r   r   r   another_input	qkv_nodes_r   r   
matmul_qkvv_nodesr   qk_nodes_softmax_qk_mul_qk	matmul_qk	_add_zeroq_nodes_transpose_qr   r   k_nodesr   r   r   r   r     sF   



z$FusionAttentionUnet.match_qkv_torch1c              	   C   s~  |j d |kr	dnd}| j|g d|ddddg}|du r dS |\}}}}}| j|g dg d}	|	du r>td dS |	\}}}
| j|dd	gddg}|durX|\}}ntd
 dS | j|g dg d}|du rvtd dS |\}}}}| j|g dg d}|du rtd dS |\}}}}| j|g dg d}|du s|d |krtd dS d||||||
fS )z.Match Q, K and V paths exported by PyTorch 2.*r   r!   )r   rL   r^   r   rL   N)r   r^   rL   r!   r   r   r   r   rL   r   )r   r   r^   rL   r   Nr   r   r   r!   Nr   r   r   SqrtDivr  CastSliceShaper   r^   Nr   r!   r   r   r   r   r   z*fuse_attention: failed to match mul_q pathTr  )r   r   r   r  r  r  r   r   r  r  r   r  r	  r  r  mul_qr  r   r   r  _mul_kr   mul_q_nodesr   r   r   r     sL   






z$FusionAttentionUnet.match_qkv_torch2c                 C   s  |j d |kr	dnd}| j|g d|ddddddg}|du r"dS |\}}}}}}}| j|g dg d}	|	du rBtd dS |	\}}}}
| j|g dg d	}|dur^|\}}}n| j|g d
g d}|duru|\}}}}ntd dS | j|g dg d}|du rtd dS |\}}}}| j|g dg d}|du rtd dS |\}}}}}d||||||
fS )zJMatch Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*r   r!   )r   r   rL   r^   r   r^   rL   N)r^   r   r^   r   r   +fuse_attention: failed to match LoRA v pathr   r   r   r   ,fuse_attention: failed to match LoRA qk path+fuse_attention: failed to match LoRA q path)r   r^   r   r^   r   r   +fuse_attention: failed to match LoRA k pathFr  )r   r   r   r  r  r  r   r   r  r  r   r  r	  r
  r  r  r  r  r   r   r  r   r   r   r   r     sF   



z)FusionAttentionUnet.match_qkv_torch1_lorac              
   C   s  |j d |kr	dnd}| j|g d|dddddg}|du r!dS |\}}}}}}| j|g dg d}	|	du r@td dS |	\}}}
| j|dd	gddg}|durZ|\}}ntd
 dS | j|g dg d}|du rxtd dS |\}}}}| j|g dg d}|du rtd dS |\}}}}| j|g dg d}|du s|d |krtd dS d||||||
fS )zJMatch Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*r   r!   )r   r   rL   r^   r   rL   N)r   r^   r   r  r  r   rL   r  )r   r   r^   r   r  r   r  r!  r  r  r  z/fuse_attention: failed to match LoRA mul_q pathTr  )r   r   r   r  r  r  r   r   r  r  r   r  r	  r  r  r  r  r   r   r  r  r   r  r   r   r   r   A  sL   






z)FusionAttentionUnet.match_qkv_torch2_loraadd_nodec                 C   s   | j |ddgddg}|d ur|\}}||fS | j |g dg d}|d ur1|\}}}||fS | j |g dg d}|d urK|\}}}}||fS d S )NrL   r!   r   )r   rL   rL   r  )r   r   rL   rL   r   )r   r  )r   r"  
lora_nodeslora_matmul_2_nodelora_matmul_1_nodelora_mul_noder  r   r   r   r   t  s2   

z#FusionAttentionUnet.match_lora_pathc              	   C   s  | j |ddgddg}|du r"| j |ddgddg}|du r"dS |\}}|jd }|| }d}	|D ]}
|
jdkr>|
}	 nq3|	du rEdS | ||	}|du rQdS |\}}}}}}| j |dd}| j |dd}| j |dd}|dur|dur| js||krn||kr||ksdS |jd |jd krdS |}| |dp| |d}|dkrt	
d dS | |}| j||||||jd |jd d	}|du rdS | j| | j| j|j< | j||g d| _dS )
zPFuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extensionr  r   r   Nr^   FTr   r   )r   r  r>   r&   match_qkv_a1111r   r   r(   r2   r8   rn   r7   r   ry   r   rv   rw   rR   r{   rz   r   )r   r   r   r   
entry_path_castr   r   r   r   r   r   r   r   r   r   r   r   cast_qcast_kcast_vr   r   r   r   r   r   r   r     sz   

	

	z#FusionAttentionUnet.fuse_a1111_fp16c              
   C   sN  |j d |kr	dnd}| j|g d|dddddg}|du r!dS |\}}}}}}	| j|	g dg d}
|
du r@td dS |
\}}}}| j|	g dg d	}|dur^|\}}}}}ntd
 dS | j|g dg d}|du r|td dS |\}}}}| j|g dg d}|du rtd dS |\}}}}||||||fS )zKMatch Q, K and V paths exported by A1111 (stable diffusion webui) extensionr   r!   )r   rL   r^   r   r^   EinsumNr   r   r   )r  r  r   r   r-  )r   r   r   r   Nr   r   r   r   r  )r   r   r   r  r  r  r   r   reshape_einsum
einsum_qkvr  r   r  r	  	einsum_qkr  r  r   r   r  r   r   r   r   r'    s@   



z#FusionAttentionUnet.match_qkv_a1111)F)__name__
__module____qualname____doc__r
   r/   boolr   r   r2   r7   r   r:   strr   r   r   r   r   r   r   r   r   r   r'  __classcell__r   r   r   r   r      s    

"
	
  
	
   WZ1403
,Nr   )loggingr   typingr   r   numpyr+   fusion_baser   fusion_utilsr   onnxr   r   r	   
onnx_modelr
   r1  r8   r   r   r   r   r   <module>   s   