o
    ge                    @   s   d dl Z d dlmZmZ d dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ d dlmZ e eZG dd deZG d	d
 d
eZdS )    N)OptionalUnion)FusionAttention)Fusion)FunctionProto	NodeProtoTensorProtohelpernumpy_helper)	OnnxModelc                       s   e Zd ZdZdededef fddZ							dd	ed
edededededededededede	e
 deedf fddZdd Zdd Zdd Z  ZS )FusionRotaryAttentionze
    Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
    modelhidden_size	num_headsc                    s   t  j|||dg dd d S )NT)SimplifiedLayerNormalization SkipSimplifiedLayerNormalizationLayerNormalizationSkipLayerNormalizationAdd)use_multi_head_attentionsearch_op_types)super__init__)selfr   r   r   	__class__ g/var/www/visachat/venv/lib/python3.10/site-packages/onnxruntime/transformers/fusion_rotary_attention.pyr      s   
zFusionRotaryAttention.__init__ Ninputoutputq_rotaryk_rotaryv_matmul	attn_maskadd_qkpast_kpast_v	present_k	present_vscalereturnc                 C   s  | j dksJ | jdkr#| j| j  dkr#td| j d| j   d S | jd}|jd |jd |jd d||||	g}|g}|
rJ|rJ||
|g tj	d|||d}d|_
|jtd| j g |d urq|jtd	|g | jd ur|jtd
t| jg | d |S )Nr   z)fuse_rotary_attention: input hidden size z# is not a multiple of num of heads MultiHeadAttentionr   inputsoutputsnamecom.microsoftr   r*   mask_filter_value)r   r   loggerdebugr   create_node_namer    extendr	   	make_nodedomain	attributemake_attributer2   floatincrease_counter)r   r   r    r!   r"   r#   r$   r%   r&   r'   r(   r)   r*   mha_node_name
mha_inputsmha_outputsmha_noder   r   r   create_mha_node)   sB   

z%FusionRotaryAttention.create_mha_nodec	           1      C   s  | j |dgdg}	| j |dgdg}
|	d u s|
d u rdS |	d |
d }}| j |g dg d}| j |g dg d}| j |g dg d}| j |g dg d}|d u sg|d u sg|d u sg|d u ridS |\}}}|\}}}|jd |ks|jd |krdS |d j|jks|d j|jkrdS | j |dgdg}| j |dgdg}|d u s|d u rdS |d |d }}| j |g d	g d
}| j |g dg d}| j |g dg d}| j |g dg d}|d u s|d u s|d u s|d u rdS |d j|jks"|d j|jks"|d j|jks"|d j|jkr$dS | j |dgdg}|d u r5dS |d }| j |g d	g d
} | j |g dg d}!| d u s[|!d u r]dS | d j|jkso|!d j|jkrqdS | j |dgdg}"|"d u rdS |"d }#| j |#g d	g d
}$| j |#g dg d}%|$d u s|%d u rdS |$d j|jks|%d j|jkrdS |$d }&| d }'|d }(|jd })|&jd |)ks|'jd |)ks|(jd |)krdS | j |g dg d}*| j |g dg d}+|*d ur|*\}},}-n|+d ur|+\}}},}-ndS |-jd dvr$dS | j |,g dg d}.| j |-g dg d}/| j |-dgdg}0|.d u sU|/d u sU|0d u rWdS |.d j|/d jksm|.d j|/d jkrodS |/d jd |0d jd krdS dS )NConcat   Fr   	UnsqueezeGatherShaper   r   r   rC   r   r   )   r   r   )rE   MulrF   rG   r   r   r   r   )rE   r   rF   rG   rC   r   r   r   rJ   )rJ   r   r   r   rB   SlicerO   CastrB   rO   rO   >   r$   attention_mask)rJ   r   rC   r   rE   T)r   match_parent_pathr   r0   r    )1r   reshape_qkv_2reshape_qkv_1reshape_q_2reshape_k_2reshape_v_2reshape_v_1r%   
root_inputconcat_qkv_2_pathconcat_qkv_1_pathconcat_qkv_2concat_qkv_1reshape_qkv_2_path_1reshape_qkv_2_path_2reshape_qkv_1_path_1reshape_qkv_1_path_2_gather_1shape_1gather_2shape_2concat_v_2_pathconcat_v_1_path
concat_v_2
concat_v_1reshape_v_2_path_1reshape_v_2_path_2reshape_v_1_path_1reshape_v_1_path_2concat_k_2_path
concat_k_2reshape_k_2_path_1reshape_k_2_path_2concat_q_2_path
concat_q_2reshape_q_2_path_1reshape_q_2_path_2mul_qmul_kmul_vgather_1_outattn_mask_path_1attn_mask_path_2
slice_qk_2
slice_qk_1slice_qk_2_pathslice_qk_1_path_1slice_qk_1_path_2r   r   r   &check_runtime_shape_paths_for_functiona   s   

 
$
$
0

,z<FusionRotaryAttention.check_runtime_shape_paths_for_functionc                 C   s\  | j |dgdg}|d u rdS |d }| j |g dg d}| j |g dg d}	|d u s4|	d u r6dS |\}
}}|	\}
}}|jd |ksN|jd |krPdS | j |dgdg}|d u r`dS |d }| j |g dg d}| j |g dg d}|d u s|d u rdS |d j|jks|d j|jkrdS | j |dgdg}|d u rdS |d }| j |g dg d}| j |g dg d}|d u s|d u rdS |d j|jks|d j|jkrdS | j |dgdg}|d u rdS |d }| j |g dg d}| j |g dg d}|d u s|d u rdS |d j|jks*|d j|jkr,dS dS )	NrB   rC   Fr   rD   rH   rI   T)r   rS   r   r0   )r   reshape_qkv	reshape_q	reshape_k	reshape_vrZ   concat_qkv_path
concat_qkvreshape_qkv_path_1reshape_qkv_path_2rc   rd   re   rf   rg   concat_v_pathconcat_vreshape_v_path_1reshape_v_path_2concat_k_pathconcat_kreshape_k_path_1reshape_k_path_2concat_q_pathconcat_qreshape_q_path_1reshape_q_path_2r   r   r   #check_runtime_shape_paths_for_nodes   sV   	

  $z9FusionRotaryAttention.check_runtime_shape_paths_for_nodesc           W         sh  |j dvrd S d } j|g dg d} j|g dg d} j|g dg d}|d ur;|\}}	}}
}|}n"|d urH|\}}}}|}n|d urV|\}}}}}|}ntd d S d\}}}d }d } j|g d	g d
} j|g dg d} j|g dg d} jj|g dg dfg dg dfg dg dfg dg dfg dg dfg dg dfg dg dfg dg dfg dg dfg	d d\}}} j|g d g d!}|d ur|\}}}}}}|} j|d"d#gd$d%g}|d u rtd& d S |d$ jd$ }|d' jd$ }|jd$ }nq|d ur4|\}}}}|}|jd$ }|jd$ }nY|d urF|\}}}|}|jd$ }nG|d urkt|d(krk|d$ d)d  \}}}}|}|jd$ }|jd$ }n"|d ur|\}}}}}|}|}|jd$ }|jd$ }ntd* d S  j|g d+g d,}d-\}} |d ur|\}}}} ntd. d S d/\}!}" j|g d0g d}# j|g d1g d}$ j|g d2g d3}% j|g d4g d5}& j|g d6g d7}' j|g d8g d3}( j|g d9g d:})|#d ur|#\}}*}+|*jd$ }!nb|$d ur)|$\}}}*}+|*jd$ }!nQ|%d ur9 	|%d$ jd$ }"nA|&d urI 	|&d$ jd$ }"n1|'d urV|'d$ jd$ }"n$|(d urc|(d$ jd$ }"n|)d urs 	|)d$ jd$ }"ntd; d S d/\},}-d }.d }/d }0 j| g d<g d
}1 j| g d=g d}2 j| g d>g d?}3 jj| g d@g d:fg dAg dBfg dCg dDfg dEg dFfg dGg dHfg dIg dJfg dKg dLfg dIg dMfg dIg dNfg	d d\}}4} j| g dOg dP}5|1d urA|1\}6}}7}}8}9|1}. j|7d"d#gd$d%g}:|:d u r&tdQ d S |:d$ jd$ },|:d' jd$ };|7jd$ }-||;ks@J n|2d urU|2\}}8}}<}9|2}.|8jd$ }-nk|3d uro|3\}}7}8}}<}9|3}.|7jd$ },|7jd$ }-nQ|4d urt|4d(kr|4d$ dRd  \}<}9|4d$ dSdT \}7}8|4}.|7jd$ },|7jd$ }-n$|5d ur|5\	}}7}0}8}/}}<}}9|5}.|7jd$ },|7jd$ }-ntdU d S d }=d }>d }? j| g dVg d,}@ j| g dWg d,}A j| g dXg dY}B|@d ur|@\}C}}D}E|@}=n&|Ad ur|A\}D}}F}E|A}=n|Bd ur|B\}?}D}>}}F}}E|B}=ntdZ d S |Ejd$ |9jd$ kr;|9jd$ |jd$ kr;td[ d S d\}G||kr_ 
|	|
|C|6||||Ejd$ sYtd] d S |	jd$ }GnX|||fv r ||F|<||Ejd$ sztd] d S |jd$ }G|>r|>jd$ n|Ejd$ |Djd$< |/r|/jd$ n|9jd$ |8jd$< |?d u r|8jd^ |8jd$< ||kr|d_d  } fd`da}H|?r|0r jdb}I|Id^ }Jtjdb|0jd$ g|Jg|Idc}K|Kjtddg deg  jdb}L|Ld^ }Mtjdb|?jd$ g|Mg|Ldc}N|Njtddg deg |H|<}O|Od u rtdf d S  jjdgdhdi}Ptjdg|Kjd$ |Ojd$ g|Pd^ g|Pdc}Q jjdgdjdi}Rtjdg|Njd$ |Ojd$ g|Rd^ g|Rdc}S|Q}8|S}D j|O  j|K  j|N  j|Q  j|S  j j|Oj<  j j|Kj<  j j|Nj<  j j|Qj<  j j|Sj<  |Ejd$ |G|D|8||!|"|,||-|}T|Td u rtdk d S  j|T  j j|Tj<  j|d_d   ||kr j|d u r|d d' n|d dR  n|d$ d' g}U|D ]	}V |V|U q j| |.|1kr j|.d dR  nw|.|2kr1 j|.d$   j|.d%   j|.dl  nY|.|3krW j|.d$   j|.d_   j|.dl   j|.dm  n3|.|5krm j|.d$   j|.d_  n|.|4kr|.d$ d' |.d$ d) g}U|.D ]	}V |V|U q|=|@kr j|=d dR  n|=|Akr j|=d_   j|=d%  dn _d S )oN>   r   r   r   )MatMulReshape	Transposer   r   rC   r   r   r   r   )r   r   r   r   rM   )	AllReducer   r   r   r   z0fuse_rotary_attention: failed to match qkv nodes)r   r   r   )r   r   rB   r   r   r   )rC   r   r   rC   r   r   )rB   r   r   r   )rC   rC   r   r   )r   r   r   rI   )r   ExpandrE   rB   r   r   r   )rC   r   r   r   rC   r   r   )r   r   WhereEqualr   rB   rE   rF   rG   rB   r   r   r   )rC   r   rC   r   r   r   r   r   r   r   rC   r   r   )r   r   r   r   rK   ConstantOfShaperG   r   rB   rE   rF   rG   rB   r   r   r   )rC   r   rC   r   rC   r   r   r   r   rC   r   r   r   rC   r   r   )r   r   r   r   rG   r   rB   rE   rF   rG   rB   r   r   r   )rC   r   rC   rC   r   r   r      r   r   r   rC   r   r   )r   r   r   r   rB   rE   rF   rG   rB   r   r   r   )rC   r   rC   rJ   r      r   r   r   rC   r   r   )	r   rB   rE   rF   rG   rB   r   r   r   )	rC   rC   r   r   r   r   rC   r   r   )
r   rB   rE   rK   rF   rG   rB   r   r   r   )
rC   rC   rC   r   r   r   r   rC   r   r   )	rC   rC   rJ   r   r   r   rC   r   r   )	rC   rC   r   r   r   r   rC   r   r   )output_name_to_node)rB   r   r   r   r   )rC   rC   r   r   rC   rO   rE   r   rJ   zDfuse_rotary_attention: failed to match past/present concat in v path	   z-fuse_rotary_attention: failed to match v path)Softmaxr   Divr   rL   NNz/fuse_rotary_attention: failed to match qk nodes)r   r   rN   rP   )r   r   SubrQ   r   rE   rE   )rC   r   rJ   rC   r   r   r   )r   r   rQ   r   rE   rE   )rC   rJ   rC   r   r   r   )r   r   r   r   rQ   r   rE   rE   )rC   r   r   rJ   rC   r   r   r   )r   r   r   rQ   r   rE   rE   )	r   rQ   r   rQ   r   rQ   r   rE   rE   )	rC   r   r   r   r   rC   r   r   r   z;fuse_rotary_attention: failed to match attention mask nodes)r   r   rB   r   RotaryEmbeddingr   )r   r   r   r   r   )r   rB   r   r   r   r   )rC   r   rC   r   r   r   )	r   r   r   rE   rB   r   r   r   r   )r   r   r   r   r   r   rB   rE   rF   rG   rB   r   r   r   r   )rC   r   r   rC   r   r   r   r   r   r   r   rC   r   r   r   )r   r   r   r   r   rK   r   rG   r   rB   rE   rF   rG   rB   r   r   r   r   )rC   r   r   rC   r   rC   r   r   r   r   rC   r   r   r   rC   r   r   r   )r   r   r   r   r   rG   r   rB   rE   rF   rG   rB   r   r   r   r   )rC   r   r   rC   rC   r   r   r   r   r   r   r   rC   r   r   r   )r   r   r   r   r   rB   rE   rF   rG   rB   r   r   r   r   )rC   r   r   rC   rJ   r   r   r   r   r   rC   r   r   r   )r   r   rB   rE   rF   rG   rB   r   r   r   r   )rC   r   rC   r   r   r   r   rC   r   r   r   )r   r   rB   rE   rK   rF   rG   rB   r   r   r   r   )rC   r   rC   rC   r   r   r   r   rC   r   r   r   )rC   r   rC   rJ   r   r   r   rC   r   r   r   )rC   r   rC   r   r   r   r   rC   r   r   r   )	r   rB   rB   r   rO   r   r   r   r   )	rC   r   rC   r   r   r   r   r   rC   zDfuse_rotary_attention: failed to match past/present concat in k pathz.fuse_rotary_attention: failed to match k nodes)r   r   r   r   )r   r   r   r   )rB   r   rO   r   r   r   r   )r   r   r   r   r   r   rC   z.fuse_rotary_attention: failed to match q nodeszKfuse_rotary_attention: failed to find the same root_input for q, k, v pathsr   z;fuse_rotary_attention: failed to verify runtime shape paths	_output_0rC   c           
         s   j | dd}|du rtd dS  j |jd } j |jd }|du s-|du r4td dS |d }|d }|| } j jd	d
d} j |du r] j|t	j
dg|gdd  j jddd}tjd|jd |jd |g|d g|d}	|	jtddg |	S )zDetect num_heads and hidden_size for ONNX model from phi-2
            Args:
                reshape_q (NodeProto): reshape node for q
            Returns:
                hidden_size_concat_node(NodeProto): Concat node to be used by reshape
            rB   rC   NzEfuse_rotary_attention: failed to trace the concat node from reshape_qrJ   r   zMfuse_rotary_attention: failed to get constant nodes of num_heads or head_sizer   Initializerr   name_prefixF)r0   	data_typedimsvalsrawhidden_size_concatoutput_0r-   axis)r   match_parentr3   r4   get_constant_valuer   r5   get_initializeradd_initializerr   INT64r	   r7   r9   r6   r:   )
r   concatnum_head_constant_nodehead_size_constant_nodenum_head_valuehead_size_valuer   hidden_size_initilizerhidden_size_reshape_node_namehidden_size_concat_noder   r   r   create_hidden_size_concat_node  sB   


zBFusionRotaryAttention.fuse.<locals>.create_hidden_size_concat_noder   r-   perm)r   rJ   rC   r   z?fuse_rotary_attention: failed to create hidden_size_concat_noder   concat_k_halfr   concat_q_halfzSfuse_rotary_attention: failed to create multi-head attention with rotary embeddingsr   r   T)op_typer   rS   r3   r4   match_parent_paths_allr   r    lenreshape_add_qkr   r   r0   r5   r	   r7   r9   r6   r:   nodes_to_addappendthis_graph_namenode_name_to_graph_namerA   nodes_to_remove&add_nodes_to_remove_with_nodes_to_keepprune_graph)Wr   normalize_nodeinput_name_to_nodesr   	qkv_nodesqkv_nodes_1qkv_nodes_2qkv_nodes_3rc   rT   rU   
matmul_qkvr   r'   r)   past_seq_lenv_nodesadd_v	v_nodes_1	v_nodes_2	v_nodes_3	v_nodes_4	v_nodes_5rX   r   rY   matmul_vr   transpose_vr   qk_nodesr%   	matmul_qkr$   
add_qk_strattn_mask_nodes_1attn_mask_nodes_2attn_mask_nodes_3attn_mask_nodes_4attn_mask_nodes_5attn_mask_nodes_6attn_mask_nodes_7slice_mask_1slice_mask_2r&   r(   k_nodesslice_kr   	k_nodes_1	k_nodes_2	k_nodes_3	k_nodes_4	k_nodes_5rW   r   rotary_kmatmul_kr   shared_past_seq_lenr   q_nodesslice_qr   	q_nodes_1	q_nodes_2	q_nodes_3rV   rotary_qmatmul_qr   root_outputr   k_transpose_node_namek_tranpose_output_namek_transpose_nodeq_transpose_node_nameq_tranpose_output_nameq_transpose_noder   concat_k_reshape_node_nameconcat_k_reshape_nodeconcat_q_reshape_node_nameconcat_q_reshape_nodenew_nodenodes_to_keep	temp_pathr   r   r   fuseF  sx  


lp





















  %  )














,





  

5






,







zFusionRotaryAttention.fuse)r   r   r   r   r   r   N)__name__
__module____qualname____doc__r   intr   strr   r   r;   r   rA   r   r   r  __classcell__r   r   r   r   r      s^    	


8 Ir   c                
       sh   e Zd Zdef fddZdedefddZdefd	d
Zde	de	de	de	de	f
ddZ
dd Z  ZS )FusionRotaryEmbeddingsr   c                    s*   d| _ t || j | j | j d dg d S )Nr   z.1r   )	base_namer   r   )r   r   r   r   r   r   U  s   $zFusionRotaryEmbeddings.__init__rot_emb_nodefunctionc                    s   g g }}|j D ],}|jdkr4|jg kr4|jd |jv r4|| t|j|jd }||j|  qg }|D ]}|jd j}	| j	
d|	_| j	|	 ||	j q9t||D ]\ }
tt fdd| j	j	jj }|D ]	}t| |
 qoqZ|S )NConstantr   c                    s
    | j v S N)r   )entryextra_outputr   r   <lambda>o  s   
 z?FusionRotaryEmbeddings.reassign_extra_outputs.<locals>.<lambda>)noder   r   r    r   listindexr9   tr   r5   r0   r   zipfiltergraphr   replace_node_input)r   r  r  extra_constantsextra_outputsfn_nodeoutput_indexextra_initializersextra_constantconstant_tensorprotoextra_initializernodes_to_updatenode_to_updater   r#  r   reassign_extra_outputs\  s&   

$
z-FusionRotaryEmbeddings.reassign_extra_outputsr&  c                    s8  | j | j}| j ddgddg}|d ur|\}}ntd d S |jd jd g}tt	fdd| j j j
j}tt	fdd| j j j
j}d	\}	}
t|dkrt|dkr| j |	d u r| j |
d u rt|d jd j }t|d jd j }tj|	tjt|j|  d
}| j || j tj|
tjt|j|  d
}| j || j | j|d |d g ||	|
g j}t|dkrtt	fdd| j j j}t|dksJ | |d  tt	 fdd|}t|dksJ tj | j|||dd}d|_!| j"| |S )Nr   r   r   z.fuse_rotary_embeddings: failed to match MatMulrC   c                       | j d  jd kS )Nr   rJ   r    r   constantr&  r   r   r%        zOFusionRotaryEmbeddings.create_rotary_embeddings_from_function.<locals>.<lambda>c                    r9  )Nr   r   r:  r;  r=  r   r   r%    r>  	cos_cache	sin_cacher0   r   r   r   c                    s   | j  jkS r!  )r0   r   )fnr=  r   r   r%    s    c                    s   |  vS r!  r   )output_name)r/  r   r   r%    s    r.   r/   r0   interleavedr1   )#r   r5   r  rS   r3   r4   r    r   r'  r+  r,  r&  r   r   r
   to_arrayr9   r)  squeezer	   make_tensorr   FLOATshapeflattentolistr   r   r   r6   	functionsr8  r7   r8   r   )r   r&  rotary_emb_node_namematmul_pathreshape_nodematmul_noderotary_emb_inputscos_cache_nodesin_cache_nodecos_cache_namesin_cache_namer@  rA  cos_cache_tensorsin_cache_tensorrotary_emb_outputsfuncrotary_emb_noder   )r/  r&  r   &create_rotary_embeddings_from_functionu  sn   



z=FusionRotaryEmbeddings.create_rotary_embeddings_from_functionrZ   position_ids	cos_slice	sin_slicer    c                    s  | j | j}tt fdd| j j jj}ttfdd| j j jj}d\}	}
t|dkrt|dkr| j |	d u r| j |
d u rt	
|d jd j }t	
|d jd j }|jd }|d d d |d f }|d d d |d f }tj|	tjt|j|  d}| j || j tj|
tjt|j|  d}| j || j | j|d |d g tj| j|||	|
g|g|dd	}d
|_|S )Nc                       | j d  kS Nr   r    r;  )r_  r   r   r%        zLFusionRotaryEmbeddings.create_rotary_embeddings_from_nodes.<locals>.<lambda>c                    ra  rb  rc  r;  )r`  r   r   r%    rd  r?  rC   r   rJ   rB  rE  r1   )r   r5   r  r'  r+  r,  r&  r   r   r
   rG  r9   r)  rH  rK  r	   rI  r   rJ  rL  rM  r   r   r   r6   r7   r8   )r   rZ   r^  r_  r`  r    rO  rT  rU  rV  rW  r@  rA  	head_sizerX  rY  r\  r   )r_  r`  r   #create_rotary_embeddings_from_nodes  sJ   



z:FusionRotaryEmbeddings.create_rotary_embeddings_from_nodesc           %         s  | j |jvr|jdkrd S d  |jdkrct|jdvs"|jd dvr)td d S | |  d u r9td d S | j| t	t
 fdd| jjjj}t|dksVJ | jjjj|d	  np| j|g d
g d}| j|g dg d}|p~|}| j|g dg d}| j|g dg d}	|p|	}
|d u s|
d u rtd d S | j|g dg d}| j|g dg d}|p|}| j|g dg d}| j|g dg d}|p|}|d u s|d u rtd d S |d j|d jks|d j|
d jks|d j|d jks|d j|
d jkr$td d S | j|ddgd	d	g}| j|ddgd	d	g}|p@|}|d u rMtd d S d\}}}| j|g dg d }| j|g d!g d"}| j|g d#g d$}| j|g d%g d&}|d ur|}|d' jd	 }nB|d ur|}|d( jd	 }n3|d ur|}|d' jd	 }|d) jd }n|d ur|}|d( jd	 }|d) jd }ntd* d S d+\}}| j|g dg d,}| j|g d!g d-}| j|g d#g d.}| j|g d%g d/} |d ur|}|d' jd	 }nB|d ur%|}|d( jd	 }n3|d ur;|}|d' jd	 }|d) jd }n| d urQ| }|d( jd	 }|d) jd }ntd* d S |d0kr| j|d) d1gdg}!| j|d) d1gdg}"|!d u s|"d u s|!d	 j|"d	 jkrtd2 d S |"d	 jd	 }ng }!g }"d3\}#}$||kr||ks||kr||kr|d4 j|d4 jks|d j|d jkrtd5 d S no||kr||ks||kr=|| kr=|d j|d jkrtd6 d S | j|d d7d8gdd	g}#| j|d g d9g d:}$|#d u s5|$d u s5| j|#d jd	 d u s5|$d jdkr<td; d S ntd< | |d jd	 ||||jd	   d u r_td d S | |g | |d d  | |d d  | |d d  | |
d d  | |d d  | | | | | |!d d  | |"d d  |#d urt| j|#d	 dkr| |# |$d ur| |$d d  | | j  | j| j j< | j  d=| _d S )>Nr   >   r      rC   >   pospos_idpos_idsposition_idr^  zLfuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding functionz=fuse_rotary_embeddings: failed to create RotaryEmbedding nodec                    s   | j  jd kS rb  )r0   r    r=  r\  r   r   r%    s    z-FusionRotaryEmbeddings.fuse.<locals>.<lambda>r   )rK   rB   NegrO   r   r   )rK   rB   rm  rO   rO   )	rK   rB   rm  rO   rE   r   rF   rG   r   )	rC   r   r   r   rC   r   r   r   r   )	rK   rB   rm  rO   rE   r   rF   rG   rO   z9fuse_rotary_embeddings: failed to match x2 in rotate_half)rK   rB   rO   r   )rC   r   rC   r   )rK   rB   rO   rO   )rK   rB   rO   rE   r   rF   rG   r   )rC   r   rC   rJ   r   r   r   r   )rK   rB   rO   rE   r   rF   rG   rO   z9fuse_rotary_embeddings: failed to match x1 in rotate_halfr   zCfuse_rotary_embeddings: failed to match common input in rotate_halfrK   r   rO   z8fuse_rotary_embeddings: failed to match x in rotate_half)Nr   r   )	rK   rE   rF   Squeezern  rO   rE   rF   rG   )	rC   rC   r   r   r   r   rJ   r   r   )rK   rE   rF   rn  rn  rO   rE   r   )rC   rC   r   r   r   r   rJ   r   )rK   rE   rF   rO   rE   rF   rG   )rC   rC   r   r   rJ   r   r   )rK   rE   rF   rO   rE   r   )rC   rC   r   r   rJ   r   r   r   rJ   z>fuse_rotary_embeddings: failed to match sin path in apply_rope)Nr   )	r   rC   r   r   r   r   rJ   r   r   )r   rC   r   r   r   r   rJ   r   )r   rC   r   r   rJ   r   r   )r   rC   r   r   rJ   r   r   r   zGfuse_rotary_embeddings: failed to match position ids path in apply_roper   r   zdfuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cachezRfuse_rotary_embeddings: failed to match common Add node in sin cache and cos cacherF   rG   )rF   rG   r   rH   zKfuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len pathsz:fuse_rotary_embeddings: failed to match common cache pathsT)r  r   r   r   r3   r4   r]  r   r   r'  r+  r   r,  
value_inforemoverS   r0   find_graph_inputrf  r    add_nodes_to_removeget_childrenr<   r   r   r   r   )%r   r&  r   r   old_shape_inferrotate_half_x2_path_1_1rotate_half_x2_path_1_2rotate_half_x2_path_1rotate_half_x2_path_2_1rotate_half_x2_path_2_2rotate_half_x2_path_2rotate_half_x1_path_1_1rotate_half_x1_path_1_2rotate_half_x1_path_1rotate_half_x1_path_2_1rotate_half_x1_path_2_2rotate_half_x1_path_2x_path_1x_path_2x_pathsin_pathrA  r^  
sin_path_1
sin_path_2
sin_path_3
sin_path_4cos_pathr@  
cos_path_1
cos_path_2
cos_path_3
cos_path_4position_ids_from_sin_pathposition_ids_from_cos_pathpast_seq_len_pathcurr_seq_len_pathr   rl  r   r    s  
























,








$


zFusionRotaryEmbeddings.fuse)r  r  r  r   r   r   r   r8  r]  r  rf  r  r  r   r   r   r   r  T  s     J
8r  )loggingtypingr   r   fusion_attentionr   fusion_baser   onnxr   r   r   r	   r
   
onnx_modelr   	getLoggerr  r3   r   r  r   r   r   r   <module>   s"   
        L