# Floating-point graphs.
# Every entry starts from a clean option state (--reset) and requests the
# f32, bf16 and f16 data types through --dt before naming its JSON case.
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/JAX-MHA-inf-fp32.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/JAX-MQA-inf-fp32.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
# This entry additionally varies the kind of op 1 between Multiply and Divide.
--reset --dt=f32,bf16,f16 --op-kind=1:Multiply,1:Divide --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16-v2.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
# f16 inputs + f32 intermediates + f16 outputs
# No --dt override here: the data types come from the JSON file itself.
--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
# Shape sweep over the same JSON; the --case line follows the continued
# --in-shapes list. Tensors 1, 2, 3 and 5 are overridden in every set.
#   - First seven sets: tensors 1/2/3 are 1x16xSx512 and tensor 5 is 1x1xSxS,
#     for S = 32, 64, 128, 256, 512, 1024, 2048.
#   - Last seven sets: tensor 1 is 1x16x1x512, tensors 2/3 are 1x16xSx512 and
#     tensor 5 is 1x1x1xS, for S = 33, 65, 129, 257, 513, 1025, 2049.
# NOTE(review): presumably 1/2/3 are query/key/value and 5 is the mask, with
# the second group modelling a single-token query -- confirm against the JSON.
--reset --in-shapes=1:1x16x32x512+2:1x16x32x512+3:1x16x32x512+5:1x1x32x32,\
                    1:1x16x64x512+2:1x16x64x512+3:1x16x64x512+5:1x1x64x64,\
                    1:1x16x128x512+2:1x16x128x512+3:1x16x128x512+5:1x1x128x128,\
                    1:1x16x256x512+2:1x16x256x512+3:1x16x256x512+5:1x1x256x256,\
                    1:1x16x512x512+2:1x16x512x512+3:1x16x512x512+5:1x1x512x512,\
                    1:1x16x1024x512+2:1x16x1024x512+3:1x16x1024x512+5:1x1x1024x1024,\
                    1:1x16x2048x512+2:1x16x2048x512+3:1x16x2048x512+5:1x1x2048x2048,\
                    1:1x16x1x512+2:1x16x33x512+3:1x16x33x512+5:1x1x1x33,\
                    1:1x16x1x512+2:1x16x65x512+3:1x16x65x512+5:1x1x1x65,\
                    1:1x16x1x512+2:1x16x129x512+3:1x16x129x512+5:1x1x1x129,\
                    1:1x16x1x512+2:1x16x257x512+3:1x16x257x512+5:1x1x1x257,\
                    1:1x16x1x512+2:1x16x513x512+3:1x16x513x512+5:1x1x1x513,\
                    1:1x16x1x512+2:1x16x1025x512+3:1x16x1025x512+5:1x1x1x1025,\
                    1:1x16x1x512+2:1x16x2049x512+3:1x16x2049x512+5:1x1x1x2049
--case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
# fp32 implicit-causal-mask graph with only the listed tensor IDs switched to
# f16. Unlike the bf16 counterpart of the simplified graph, ID 5 is not listed.
--reset --dt=1:f16+2:f16+3:f16+4:f16+6:f16+104:f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json

# bf16 inputs + f32 intermediates + bf16 outputs
# Only the tensor IDs named in --dt are switched to bf16.
--reset
--dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16
--case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json

--reset
--dt=1:bf16+2:bf16+3:bf16+4:bf16+6:bf16+104:bf16
--case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json

# int8 graphs
# No option overrides: each JSON file is run as stored, from a clean state.
--reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json
--reset --case=complex_fusion/mha/MHA-bert_large-inf-int8-bs1.json
--reset --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json
# Variants with compressed (quantized) key and/or value inputs.
--reset --case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-v-int8-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json

# Re-written graphs
# The JSON files used earlier in this file are re-run with parts of the graph
# overridden on the command line: input shapes (--in-shapes), batch size
# (--mb), op kinds (--op-kind) or op attributes (--op-attrs).
--reset --dt=f32,bf16,f16 --in-shapes=4:4x16x32x256+5:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --mb=20 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=3:10x16x384x64+4:10x1x64x384+0:10x1x384x64+1:10x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=4:56x12x128x64+5:56x12x64x128+0:56x12x128x64+1:56x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=2:1x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=0:56x8x1024x80+1:56x8x77x80+2:56x8x77x80 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=0:20x16x384x64+1:20x16x384x64+8:20x16x384x64+5:20x1x1x384 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
# sdpa-plain-simplified-f16: alternative shapes for tensor 5, larger batch
# sizes, and a layout-only override.
--reset --dt=f32,bf16,f16 --in-shapes=5:1x1x384x384,5:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:2x16x384x64+1:2x16x384x64+5:2x1x1x384+8:2x16x384x64  --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:32x16x128x64+1:32x16x128x64+5:32x16x128x128+8:32x16x128x64  --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
# Only a layout tag (acbd) is given for tensors 0, 1 and 8.
# NOTE(review): shapes presumably stay as defined in the JSON -- confirm.
--reset --dt=f32,bf16,f16 --in-shapes=0:acbd+1:acbd+8:acbd  --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
# Op 1 is run as Multiply and as Divide while tensor 3 (respectively tensor 5)
# is given lower-rank shapes.
--reset --dt=f32,bf16,f16 --op-kind=1:Multiply,1:Divide --in-shapes=3:384,3:384x384,3:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --op-kind=1:Multiply,1:Divide --in-shapes=5:384,5:1x384 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
# sdpa-compressed-k-int8-gs32 with attributes of ops 34107656704 and
# 34107654464 rewritten (group_shape / transpose_b / qtype / axis).
--reset --op-attrs=34107656704:group_shape:1x1x1x32+34107654464:transpose_b:1 --in-shapes=0:1x32x32x128+1:1x32x32x4+2:1x32x32x4 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --op-attrs=34107656704:qtype:per_channel*axis:3 --in-shapes=1:32+2:1 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --op-attrs=34107656704:group_shape:1x1x128x1 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
# Implicit-causal-mask graph with the axis attribute of ops 40 and 41 given as
# negative values (-2 and -1).
--reset --dt=f32,bf16,f16 --op-attrs=40:axis:-2+41:axis:-1 --in-shapes=1:1x32x128x64+2:1x32x128x64+3:1x32x128x64 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json

# Re-written int8 graphs
# int8 JSON files re-run with overridden input shapes; no --dt is given, so
# the data types come from the JSON files.
--reset --in-shapes=5:4x16x32x256+4:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json
--reset --in-shapes=4:20x16x384x64+3:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-int8-bs1.json
--reset --in-shapes=5:56x12x128x64+4:56x12x64x128+0:56x12x128x64+1:56x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --in-shapes=2:1x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --in-shapes=4:20x16x384x64+3:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json
# Compressed-KV graph: shapes (tensor 0 also carries the abdc layout tag) and
# the group_shape attribute of ops 0 and 8 are overridden together.
--reset --in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1 --op-attrs=0:group_shape:1x1x96x1+8:group_shape:1x1x1x96 --case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json

# phi3-mini-4k-instruct
# Two compressed-KV graphs are swept over six shape sets each, built around
# the lengths 384, 385, 512, 513, 1024 and 1025 with a feature size of 96.
# In the 385/513/1025 sets tensor 3 is 1x32x1x96; in the others its third
# dimension equals the length.
# NOTE(review): presumably tensor 3 is the query and the odd lengths model a
# single-token step -- confirm against the JSON.
# Tensors 0, 2, 6 and 8 are switched to s8, and group_shape is set to span the
# whole 96-wide dimension.

# Graph with an explicit tensor 5 (1x1xSxS) in every shape set.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+5:1x1x384x384+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1,\
            0:1x32x96x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x96+5:1x1x385x385+6:1x32x385x96+7:1x32x385x1+8:1x32x385x1,\
            0:1x32x96x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x96+5:1x1x512x512+6:1x32x512x96+7:1x32x512x1+8:1x32x512x1,\
            0:1x32x96x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x96+5:1x1x513x513+6:1x32x513x96+7:1x32x513x1+8:1x32x513x1,\
            0:1x32x96x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x96+5:1x1x1024x1024+6:1x32x1024x96+7:1x32x1024x1+8:1x32x1024x1,\
            0:1x32x96x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x96+5:1x1x1025x1025+6:1x32x1025x96+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x96x1+99:group_shape:1x1x1x96
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

# Implicit-causal-mask graph: same sweep without tensor 5; the second
# group_shape override targets op 8 instead of op 99.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1,\
            0:1x32x96x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x96+6:1x32x385x96+7:1x32x385x1+8:1x32x385x1,\
            0:1x32x96x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x96+6:1x32x512x96+7:1x32x512x1+8:1x32x512x1,\
            0:1x32x96x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x96+6:1x32x513x96+7:1x32x513x1+8:1x32x513x1,\
            0:1x32x96x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x96+6:1x32x1024x96+7:1x32x1024x1+8:1x32x1024x1,\
            0:1x32x96x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x96+6:1x32x1025x96+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x96x1+8:group_shape:1x1x1x96
--case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json

# llama-2-7b-chat
# Two compressed-KV graphs are swept over six shape sets each, built around
# the lengths 384, 385, 512, 513, 1024 and 1025 with a feature size of 128.
# In the 385/513/1025 sets tensor 3 is 1x32x1x128; in the others its third
# dimension equals the length. group_shape is set to span the whole 128-wide
# dimension.
# Each entry states --reset and the s8 --dt override (tensors 0, 2, 6 and 8)
# explicitly, so it does not depend on options left over from whichever entry
# happens to precede it in the file.

# Graph with an explicit tensor 5 (1x1xSxS) in every shape set.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x128x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x128+5:1x1x384x384+6:1x32x384x128+7:1x32x384x1+8:1x32x384x1,\
            0:1x32x128x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x128+5:1x1x385x385+6:1x32x385x128+7:1x32x385x1+8:1x32x385x1,\
            0:1x32x128x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x128+5:1x1x512x512+6:1x32x512x128+7:1x32x512x1+8:1x32x512x1,\
            0:1x32x128x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x128+5:1x1x513x513+6:1x32x513x128+7:1x32x513x1+8:1x32x513x1,\
            0:1x32x128x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x128+5:1x1x1024x1024+6:1x32x1024x128+7:1x32x1024x1+8:1x32x1024x1,\
            0:1x32x128x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x128+5:1x1x1025x1025+6:1x32x1025x128+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x128x1+99:group_shape:1x1x1x128
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

# Implicit-causal-mask graph: same sweep without tensor 5; the second
# group_shape override targets op 8 instead of op 99.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x128x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x128+6:1x32x384x128+7:1x32x384x1+8:1x32x384x1,\
            0:1x32x128x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x128+6:1x32x385x128+7:1x32x385x1+8:1x32x385x1,\
            0:1x32x128x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x128+6:1x32x512x128+7:1x32x512x1+8:1x32x512x1,\
            0:1x32x128x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x128+6:1x32x513x128+7:1x32x513x1+8:1x32x513x1,\
            0:1x32x128x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x128+6:1x32x1024x128+7:1x32x1024x1+8:1x32x1024x1,\
            0:1x32x128x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x128+6:1x32x1025x128+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x128x1+8:group_shape:1x1x1x128
--case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json

# Tensor IDs: 0 = key, 2 = key zero points, 6 = value, 8 = value zero points.
# Run the int4 compressed-KV graph with those four tensors switched to s8.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

# Same s8 override with the group size raised to 128. The scale and zero-point
# tensors (IDs 1, 2, 7, 8) change shape accordingly.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--op-attrs=0:group_shape:1x1x128x1+99:group_shape:1x1x1x128
--in-shapes=1:1x32x1x32+2:1x32x1x32+7:1x32x32x1+8:1x32x32x1
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

# d_qk != d_v
# Only one input tensor per graph is reshaped, with its last dimension swept
# over 32, 64 and 128.
--reset
--dt=f32,bf16,f16
--in-shapes=8:1x16x384x32,8:1x16x384x64,8:1x16x384x128
--case=complex_fusion/mha/sdpa-plain-simplified-f16.json

--reset
--dt=f32,bf16,f16
--in-shapes=3:1x16x384x32,3:1x16x384x64,3:1x16x384x128
--case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
