Update benchmarks/benchmark.py
benchmarks/benchmark.py  (+104 −104)

The change comments out the entire FlashMLAVarlenBenchmark class; its small, medium, and large varlen workloads are disabled rather than deleted.
@@ -216,107 +216,107 @@ class FlashMLACausalBenchmark(Benchmark):
         return _verify_mla_decode(self, causal=True)
 
 
-class FlashMLAVarlenBenchmark(Benchmark):
-    seed: int = 42
-
-    # Workload: small (3 sequences, max_seqlen=64)
-    def setup_small(self):
-        H, D = 8, 64
-        seqlens = [32, 48, 64]
-        total = sum(seqlens)
-        self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.cu_seqlens = torch.tensor(
-            [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
-            device="cuda",
-            dtype=torch.int32,
-        )
-        self.max_seqlen = max(seqlens)
-        self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
-
-    def benchmark_small(self):
-        self.out = _extract_output(
-            self.kernel.flash_attn_varlen_func(
-                self.q,
-                self.k,
-                self.v,
-                self.cu_seqlens,
-                self.cu_seqlens,
-                self.max_seqlen,
-                self.max_seqlen,
-            )
-        )
-
-    def verify_small(self) -> torch.Tensor:
-        return _varlen_reference_attention(
-            self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
-        )
-
-    # Workload: medium (5 sequences, max_seqlen=256)
-    def setup_medium(self):
-        H, D = 16, 64
-        seqlens = [128, 192, 256, 200, 150]
-        total = sum(seqlens)
-        self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.cu_seqlens = torch.tensor(
-            [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
-            device="cuda",
-            dtype=torch.int32,
-        )
-        self.max_seqlen = max(seqlens)
-        self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
-
-    def benchmark_medium(self):
-        self.out = _extract_output(
-            self.kernel.flash_attn_varlen_func(
-                self.q,
-                self.k,
-                self.v,
-                self.cu_seqlens,
-                self.cu_seqlens,
-                self.max_seqlen,
-                self.max_seqlen,
-            )
-        )
-
-    def verify_medium(self) -> torch.Tensor:
-        return _varlen_reference_attention(
-            self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
-        )
-
-    # Workload: large (8 sequences, max_seqlen=512)
-    def setup_large(self):
-        H, D = 32, 128
-        seqlens = [256, 384, 512, 448, 320, 480, 400, 512]
-        total = sum(seqlens)
-        self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.cu_seqlens = torch.tensor(
-            [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
-            device="cuda",
-            dtype=torch.int32,
-        )
-        self.max_seqlen = max(seqlens)
-        self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
-
-    def benchmark_large(self):
-        self.out = _extract_output(
-            self.kernel.flash_attn_varlen_func(
-                self.q,
-                self.k,
-                self.v,
-                self.cu_seqlens,
-                self.cu_seqlens,
-                self.max_seqlen,
-                self.max_seqlen,
-            )
-        )
-
-    def verify_large(self) -> torch.Tensor:
-        return _varlen_reference_attention(
-            self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
-        )
+# class FlashMLAVarlenBenchmark(Benchmark):
+#     seed: int = 42
+
+#     # Workload: small (3 sequences, max_seqlen=64)
+#     def setup_small(self):
+#         H, D = 8, 64
+#         seqlens = [32, 48, 64]
+#         total = sum(seqlens)
+#         self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.cu_seqlens = torch.tensor(
+#             [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
+#             device="cuda",
+#             dtype=torch.int32,
+#         )
+#         self.max_seqlen = max(seqlens)
+#         self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
+
+#     def benchmark_small(self):
+#         self.out = _extract_output(
+#             self.kernel.flash_attn_varlen_func(
+#                 self.q,
+#                 self.k,
+#                 self.v,
+#                 self.cu_seqlens,
+#                 self.cu_seqlens,
+#                 self.max_seqlen,
+#                 self.max_seqlen,
+#             )
+#         )
+
+#     def verify_small(self) -> torch.Tensor:
+#         return _varlen_reference_attention(
+#             self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
+#         )
+
+#     # Workload: medium (5 sequences, max_seqlen=256)
+#     def setup_medium(self):
+#         H, D = 16, 64
+#         seqlens = [128, 192, 256, 200, 150]
+#         total = sum(seqlens)
+#         self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.cu_seqlens = torch.tensor(
+#             [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
+#             device="cuda",
+#             dtype=torch.int32,
+#         )
+#         self.max_seqlen = max(seqlens)
+#         self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
+
+#     def benchmark_medium(self):
+#         self.out = _extract_output(
+#             self.kernel.flash_attn_varlen_func(
+#                 self.q,
+#                 self.k,
+#                 self.v,
+#                 self.cu_seqlens,
+#                 self.cu_seqlens,
+#                 self.max_seqlen,
+#                 self.max_seqlen,
+#             )
+#         )
+
+#     def verify_medium(self) -> torch.Tensor:
+#         return _varlen_reference_attention(
+#             self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
+#         )
+
+#     # Workload: large (8 sequences, max_seqlen=512)
+#     def setup_large(self):
+#         H, D = 32, 128
+#         seqlens = [256, 384, 512, 448, 320, 480, 400, 512]
+#         total = sum(seqlens)
+#         self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.cu_seqlens = torch.tensor(
+#             [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
+#             device="cuda",
+#             dtype=torch.int32,
+#         )
+#         self.max_seqlen = max(seqlens)
+#         self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
+
+#     def benchmark_large(self):
+#         self.out = _extract_output(
+#             self.kernel.flash_attn_varlen_func(
+#                 self.q,
+#                 self.k,
+#                 self.v,
+#                 self.cu_seqlens,
+#                 self.cu_seqlens,
+#                 self.max_seqlen,
+#                 self.max_seqlen,
+#             )
+#         )
+
+#     def verify_large(self) -> torch.Tensor:
+#         return _varlen_reference_attention(
+#             self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
+#         )
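For context on the packed varlen layout the disabled class relies on: all sequences are concatenated into a single (total_tokens, H, D) tensor, and `cu_seqlens` holds cumulative offsets with a leading zero, so sequence `i` occupies rows `cu_seqlens[i]:cu_seqlens[i+1]` and `max_seqlen` bounds the longest sequence. A minimal sketch of that convention in plain PyTorch (CPU tensors for illustration; the benchmark builds the same values with `device="cuda"`):

```python
import torch

# Cumulative sequence offsets for the "small" workload above.
# pad(..., (1, 0)) prepends the leading zero, equivalent to the
# [0] + cumsum(...) construction in the commented-out code.
seqlens = [32, 48, 64]
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(torch.tensor(seqlens, dtype=torch.int32), dim=0), (1, 0)
)
assert cu_seqlens.tolist() == [0, 32, 80, 144]
# Sequence i occupies packed rows cu_seqlens[i] : cu_seqlens[i + 1].
```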
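The `verify_*` methods check the kernel against `_varlen_reference_attention`, whose definition lies outside this hunk. As an illustration only (not the file's actual implementation), a reference with the usual packed-varlen semantics can be written as ordinary scaled-dot-product attention applied independently per sequence:

```python
import torch
import torch.nn.functional as F

def varlen_reference_attention_sketch(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=False):
    """Hypothetical stand-in for _varlen_reference_attention (not shown in this diff)."""
    out = torch.empty_like(q)
    for i in range(cu_seqlens_q.numel() - 1):
        qs, qe = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
        ks, ke = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
        # Slice one sequence: (seqlen, H, D) -> (1, H, seqlen, D), as SDPA expects.
        qi = q[qs:qe].transpose(0, 1).unsqueeze(0)
        ki = k[ks:ke].transpose(0, 1).unsqueeze(0)
        vi = v[ks:ke].transpose(0, 1).unsqueeze(0)
        oi = F.scaled_dot_product_attention(qi, ki, vi, is_causal=causal)
        out[qs:qe] = oi.squeeze(0).transpose(0, 1)
    return out
```

Comparing a packed kernel against a per-sequence loop like this is the standard varlen correctness check: any attention leakage across sequence boundaries shows up as a mismatch immediately.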