From 5f3d9ada8dabbd9a449f134141f14546f9ce911e Mon Sep 17 00:00:00 2001 From: Mitchell Wortsman Date: Wed, 29 Mar 2023 06:47:08 +0000 Subject: [PATCH] triton-v1 --- bitsandbytes/nn/__init__.py | 1 + bitsandbytes/nn/triton_based_modules.py | 247 ++++++++++++ bitsandbytes/nn/triton_utils/v0/__init__.py | 0 .../nn/triton_utils/v0/fused_gelu_quantize.py | 190 +++++++++ .../v0/int8_matmul_mixed_dequanitze.py | 276 +++++++++++++ .../v0/int8_matmul_rowwise_dequantize.py | 149 +++++++ .../v0/int8_matmul_rowwise_dequantize_bias.py | 160 ++++++++ .../quantize_columnwise_nogroup_transpose.py | 122 ++++++ .../nn/triton_utils/v0/quantize_global.py | 130 +++++++ .../v0/quantize_rowwise_nogroup.py | 174 +++++++++ tests/triton_tests/attn_decomp.py | 363 ++++++++++++++++++ tests/triton_tests/attn_info_ln.jsonl | 20 + tests/triton_tests/full_matrix_decomp.py | 353 +++++++++++++++++ tests/triton_tests/info.jsonl | 142 +++++++ tests/triton_tests/info_mlp.jsonl | 20 + tests/triton_tests/info_mlp_autocast.jsonl | 20 + tests/triton_tests/info_mlp_autocast_ln.jsonl | 23 ++ tests/triton_tests/make_plot_with_info.py | 137 +++++++ tests/triton_tests/mlp.py | 64 +++ tests/triton_tests/mlp_decomp_autocast.py | 166 ++++++++ tests/triton_tests/mlp_decomp_autocast_ln.py | 165 ++++++++ tests/triton_tests/plot1.pdf | Bin 0 -> 34302 bytes tests/triton_tests/plot1.png | Bin 0 -> 121873 bytes tests/triton_tests/plot2.pdf | Bin 0 -> 16044 bytes tests/triton_tests/plot2.png | Bin 0 -> 51996 bytes tests/triton_tests/plot2.py | 69 ++++ tests/triton_tests/plot3.pdf | Bin 0 -> 20122 bytes tests/triton_tests/plot3.png | Bin 0 -> 58335 bytes tests/triton_tests/plot3.py | 193 ++++++++++ tests/triton_tests/rowwise.py | 43 +++ 30 files changed, 3227 insertions(+) create mode 100644 bitsandbytes/nn/triton_based_modules.py create mode 100644 bitsandbytes/nn/triton_utils/v0/__init__.py create mode 100644 bitsandbytes/nn/triton_utils/v0/fused_gelu_quantize.py create mode 100644 
bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py create mode 100644 bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py create mode 100644 bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize_bias.py create mode 100644 bitsandbytes/nn/triton_utils/v0/quantize_columnwise_nogroup_transpose.py create mode 100644 bitsandbytes/nn/triton_utils/v0/quantize_global.py create mode 100644 bitsandbytes/nn/triton_utils/v0/quantize_rowwise_nogroup.py create mode 100644 tests/triton_tests/attn_decomp.py create mode 100644 tests/triton_tests/attn_info_ln.jsonl create mode 100644 tests/triton_tests/full_matrix_decomp.py create mode 100644 tests/triton_tests/info.jsonl create mode 100644 tests/triton_tests/info_mlp.jsonl create mode 100644 tests/triton_tests/info_mlp_autocast.jsonl create mode 100644 tests/triton_tests/info_mlp_autocast_ln.jsonl create mode 100644 tests/triton_tests/make_plot_with_info.py create mode 100644 tests/triton_tests/mlp.py create mode 100644 tests/triton_tests/mlp_decomp_autocast.py create mode 100644 tests/triton_tests/mlp_decomp_autocast_ln.py create mode 100644 tests/triton_tests/plot1.pdf create mode 100644 tests/triton_tests/plot1.png create mode 100644 tests/triton_tests/plot2.pdf create mode 100644 tests/triton_tests/plot2.png create mode 100644 tests/triton_tests/plot2.py create mode 100644 tests/triton_tests/plot3.pdf create mode 100644 tests/triton_tests/plot3.png create mode 100644 tests/triton_tests/plot3.py create mode 100644 tests/triton_tests/rowwise.py diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 8be7674..8e3a598 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -3,3 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed +from .triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py new file mode 100644 index 0000000..9fe0b69 --- /dev/null +++ b/bitsandbytes/nn/triton_based_modules.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +import time + +from .triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup +from .triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose +from .triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias +from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from .triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose +from .triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias +from .triton_utils.v0.fused_gelu_quantize import quantize_rowwise_nogroup_gelu, quantize_rowwise_nogroup_back_gelu + +class _switchback(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + + X = X_3D.view(-1, X_3D.size(-1)) + + ctx.save_for_backward = X, W + X_int8, state_X = quantize_rowwise_nogroup(X) + W_int8, state_W = quantize_rowwise_nogroup(W) + return int8_matmul_rowwise_dequantize_bias( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + X, W = ctx.save_for_backward + + G = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X = grad_W = grad_bias = None + + if ctx.needs_input_grad[0]: + G_int8, state_G = quantize_rowwise_nogroup(G) + W_int8, state_W = quantize_columnwise_nogroup_transpose(W) + grad_X = 
int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + grad_W = torch.matmul(G.t(), X.to(G.dtype)) + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + + return grad_X, grad_W, grad_bias + +class SwitchBackLinear(nn.Linear): + + def prepare_for_eval(self): + state_W = self.weight.abs().max(dim=1, keepdim=True)[0] + W_int8 = (127 * self.weight.float() / state_W).round().to(torch.int8) + state_W = state_W.squeeze() + + self.register_buffer("W_int8", W_int8) + self.register_buffer("state_W", state_W) + + del self.weight + + def forward(self, x): + if self.training: + return _switchback.apply(x, self.weight, self.bias) + else: + if not hasattr(self, "state_W"): + self.prepare_for_eval() + X = x.view(-1, x.size(-1)) + X_int8, state_X = quantize_rowwise_nogroup(X) + return int8_matmul_rowwise_dequantize_bias( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) + + +class _switchback_global(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + + X = X_3D.view(-1, X_3D.size(-1)) + + X_int8, state_X = quantize_rowwise_nogroup(X) + W_int8, state_W = quantize_global(W) + ctx.save_for_backward = X, W + return int8_matmul_mixed_dequanitze_bias( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + + G = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X = grad_W = grad_bias = None + + X, W = ctx.save_for_backward + if ctx.needs_input_grad[0]: + G_int8, state_G = quantize_rowwise_nogroup(G) + W_int8, state_W = quantize_global_transpose(W) + grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + grad_W = torch.matmul(G.t(), X.to(G.dtype)) + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + + return grad_X, grad_W, grad_bias + + + +class SwitchBackGlobalLinear(nn.Linear): + + def 
prepare_for_eval(self): + state_W = self.weight.abs().max() + W_int8 = (127 * self.weight.float() / state_W).round().to(torch.int8) + + self.register_buffer("W_int8", W_int8) + self.register_buffer("state_W", state_W) + + del self.weight + + def forward(self, x): + if self.training: + return _switchback_global.apply(x, self.weight, self.bias) + else: + if not hasattr(self, "state_W"): + self.prepare_for_eval() + X = x.view(-1, x.size(-1)) + X_int8, state_X = quantize_rowwise_nogroup(X) + return int8_matmul_mixed_dequanitze_bias( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) + + + + +class LinearFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None): + X = input.view(-1, input.size(-1)) + + ctx.save_for_backward(X, weight, bias) + output = input.matmul(weight.t()) + if bias is not None: + output += bias.unsqueeze(0).expand_as(output) + return output.view(*input.size()[:-1], -1) + + @staticmethod + def backward(ctx, grad_output_3D): + input, weight, bias = ctx.saved_tensors + + grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1)) + + grad_input = grad_weight = grad_bias = None + + if ctx.needs_input_grad[0]: + grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1) + if ctx.needs_input_grad[1]: + grad_weight = grad_output.t().matmul(input.to(grad_output.dtype)) + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias + +class MyLinear(nn.Linear): + + def forward(self, x): + return LinearFunction.apply(x, self.weight, self.bias) + + + + +class _switchback_mlp(torch.autograd.Function): + + + @staticmethod + def forward(ctx, X_3D, W1, B1, W2, B2): + + X1 = X_3D.view(-1, X_3D.size(-1)) + + X1_int8, state_X1 = quantize_rowwise_nogroup(X1) + W1_int8, state_W1 = quantize_global(W1) + + X2_pre = int8_matmul_mixed_dequanitze_bias( + X1_int8, W1_int8.t(), state_X1, 
state_W1, B1 + ) + + # X2_v1 = torch.nn.functional.gelu(X2) + # X2_int8, state_X2, = quantize_rowwise_nogroup(X2_v1) + X2_int8, state_X2, X2 = quantize_rowwise_nogroup_gelu(X2_pre) + + W2_int8, state_W2 = quantize_global(W2) + + out = int8_matmul_mixed_dequanitze_bias( + X2_int8, W2_int8.t(), state_X2, state_W2, B2 + ) + + ctx.save_for_backward = X1, W1, X2, X2_pre, W2 + + return out.view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + + G2 = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X1 = grad_W1 = grad_B1 = grad_W2 = grad_B2 = None + + X1, W1, X2, X2_pre, W2 = ctx.save_for_backward + + G2_int8, state_G2 = quantize_rowwise_nogroup(G2) + W2_int8, state_W2 = quantize_global_transpose(W2) + + G1 = int8_matmul_mixed_dequanitze(G2_int8, W2_int8.t(), state_G2, state_W2).view( + *G_3D.size()[:-1], -1 + ) + + grad_W2 = torch.matmul(G2.t(), X2.to(G2.dtype)) + grad_B2 = G2.sum(dim=0) + + G1_int8, state_G1, G1 = quantize_rowwise_nogroup_back_gelu(G1, X2_pre) + + if ctx.needs_input_grad[0]: + + W1_int8, state_W1 = quantize_global_transpose(W1) + grad_X1 = int8_matmul_mixed_dequanitze(G1_int8, W1_int8.t(), state_G1, state_W1).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + grad_W1 = torch.matmul(G1.t(), X1.to(G1.dtype)) + if ctx.needs_input_grad[2]: + grad_B1 = G1.sum(dim=0) + + return grad_X1, grad_W1, grad_B1, grad_W2, grad_B2 + + +class SwitchBackGlobalMLP(nn.Module): + + + def __init__(self, dim_in, dim_hidden): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_hidden) + self.linear2 = nn.Linear(dim_hidden, dim_in) + + + def forward(self, x): + return _switchback_mlp.apply(x, self.linear1.weight, self.linear1.bias, self.linear2.weight, self.linear2.bias) + \ No newline at end of file diff --git a/bitsandbytes/nn/triton_utils/v0/__init__.py b/bitsandbytes/nn/triton_utils/v0/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bitsandbytes/nn/triton_utils/v0/fused_gelu_quantize.py 
b/bitsandbytes/nn/triton_utils/v0/fused_gelu_quantize.py new file mode 100644 index 0000000..50451cb --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/fused_gelu_quantize.py @@ -0,0 +1,190 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +tl.libdevice + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_rowwise_nogroup_gelu( + x_ptr, + output_ptr, + output_maxs, + output_fp16, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + cdf = 0.5 * (1.0 + tl.libdevice.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) + x_new = x * cdf + + tl.store(output_fp16 + offsets, x_new, mask=row_mask) + + abs_x = tl.abs(x_new) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. 
* (x_new / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_rowwise_nogroup_gelu(x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise_nogroup_gelu[grid](x, output, output_maxs, output_fp16, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs, output_fp16 + + + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_rowwise_nogroup_back_gelu( + x_ptr, + in_ptr, + output_ptr, + output_maxs, + output_fp16, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x_out = tl.load(x_ptr + offsets, mask=row_mask) + x_in = tl.load(in_ptr + offsets, mask=row_mask) + + cdf = 0.5 * (1.0 + tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in))) + intermediate = tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in)) + dcdf = 0.5 * (0.7978845608 + 0.1070322243 * x_in * x_in) * (1 - intermediate * intermediate) + x 
= x_out * (cdf + x_in * dcdf) + + tl.store(output_fp16 + offsets, x, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_rowwise_nogroup_back_gelu(x: torch.Tensor, y : torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise_nogroup_back_gelu[grid](x, y, output, output_maxs, output_fp16, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs, output_fp16 + + + +# if __name__ == '__main__': +# torch.manual_seed(0) + +# x = torch.randn(1280, 768).cuda().to(torch.float16) +# out = quantize_rowwise_nogroup(x) + +# x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8) +# max2 = x.abs().max(1)[0] + +# print(torch.allclose(out[1], max2)) +# print( (x_real == out[0]).float().mean() ) + +# # for i in range(x.shape[0]): +# # print( (x_real[i, :] == out[0][i, :]).float().mean() ) + +# # print(out[0]) +# # print(x_real) +# # import pdb; pdb.set_trace() +# # print(out[2]) +# # print(out[2][:10]) +# sums = x.sum(dim=0) +# #print(sums[:10]) +# #print( (sums == out[2]).float().mean() ) + +# import pdb; pdb.set_trace() +# # import pdb; pdb.set_trace() +# # exit() + +# # repeat = 16 + +# # for _ in range(8): +# # out = quantize_rowwise_nogroup(x) + +# # triton_graph = torch.cuda.CUDAGraph() +# # with torch.cuda.graph(triton_graph): +# # out = quantize_rowwise_nogroup(x) + +# # triton_graph.replay() + +# # torch.cuda.synchronize() +# # start = time.time() +# # for _ in 
range(repeat): +# # triton_graph.replay() +# # torch.cuda.synchronize() +# # end = time.time() + +# # print(out[0]) +# # print(out[1]) +# # print(x / x.abs().max(dim=1, keepdim=True)[0]) +# # max1 = out[1] +# # max2 = x.abs().max(1)[0] +# # print(max1, max2) +# # print(torch.allclose(max1, max2)) + +# #print(f"time: {(end - start) / repeat * 1000:.3f} ms") diff --git a/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py b/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py new file mode 100644 index 0000000..2ecfcb8 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py @@ -0,0 +1,276 @@ +import torch + +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + 
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: 
tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr) + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) 
+ acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_mixed_dequanitze(a, b, state_x, state_w): + device = a.device + divfactor = 1. / (127. * 127.) + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid](a, b, c, state_x, state_w, M, N, K, divfactor, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c + + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + 
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel_bias(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: 
tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr) + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) 
+ acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_mixed_dequanitze_bias(a, b, state_x, state_w, bias): + device = a.device + divfactor = 1. / (127. * 127.) + has_bias = 0 if bias is None else 1 + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel_bias[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py new file mode 100644 index 0000000..fa0b516 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py @@ -0,0 +1,149 @@ +import torch + +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(name): + 
return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 
'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + 
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr + rbn)[None, :] + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_rowwise_dequantize(a, b, state_x, state_w): + divfactor = 1. / (127. * 127.) 
+ + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid](a, b, c, state_x, state_w, M, N, K, divfactor, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize_bias.py b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize_bias.py new file mode 100644 index 0000000..5f524c1 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize_bias.py @@ -0,0 +1,160 @@ +import torch + +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + 
+@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, 
num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr + rbn)[None, :] + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = 
tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_rowwise_dequantize_bias(a, b, state_x, state_w, bias): + + #print(bias) + divfactor = 1. / (127. * 127.) + + has_bias = 0 if bias is None else 1 + + if bias is not None: + bias = bias.contiguous() + + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_nogroup_transpose.py b/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_nogroup_transpose.py new file mode 100644 index 0000000..fa3a9a9 --- /dev/null +++ 
b/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_nogroup_transpose.py @@ -0,0 +1,122 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_columnwise_nogroup_transpose( + x_ptr, + output_ptr, + output_maxs, + n_elements, + M : tl.constexpr, N : tl.constexpr, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid + p2_arange = tl.arange(0, P2) + p2_arange_mask = p2_arange < M + arange = p2_arange * N + offsets = block_start + arange + x = tl.load(x_ptr + offsets, mask=p2_arange_mask) + abs_x = tl.abs(x) + max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. 
* (x / max_val)) + + new_start = pid * M + new_offsets = new_start + p2_arange + tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_columnwise_nogroup_transpose(x: torch.Tensor): + M, N = x.shape + output = torch.empty(N, M, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(M)))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_columnwise_nogroup_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) + return output, output_maxs + + + +if __name__ == '__main__': + torch.manual_seed(0) + + x = torch.randn(1280, 768).cuda().to(torch.float16) + out = quantize_columnwise_nogroup_transpose(x) + + + x_real = x.t().float() + x_real_int8 = (127. * x_real / x_real.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8) + maxs = x_real.abs().max(dim=1, keepdim=True)[0].half() + + #print(out[0][2,:]) + + print((out[0] == x_real_int8).float().mean()) + print((out[1] == maxs[:, 0]).float().mean()) + + # print(out[0]) + # print(out[1]) + + # print(out[0][2,:]) + # print(x_real[2, :]) + + # print((out[0] != x_real).nonzero()) + + #import pdb; pdb.set_trace() + # repeat = 16 + + # for _ in range(8): + # out = quantize_columnwise_nogroup_transpose(x) + + # triton_graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(triton_graph): + # out = quantize_columnwise_nogroup_transpose(x) + + # triton_graph.replay() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # triton_graph.replay() + # torch.cuda.synchronize() + # end = time.time() + + # print(out[0]) + # print(out[1]) + # print(x / x.abs().max(dim=0, keepdim=True)[0]) + # x_real = (127 * (x / x.abs().max(dim=0, keepdim=True)[0])).round().to(torch.int8) + # max1 = out[1] + # max2 = x.abs().max(0)[0] + # 
print(max1, max2) + # import pdb; pdb.set_trace() + # print(torch.allclose(max1, max2)) + + # print(f"time: {(end - start) / repeat * 1000:.3f} ms") diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_global.py b/bitsandbytes/nn/triton_utils/v0/quantize_global.py new file mode 100644 index 0000000..6d23aac --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/quantize_global.py @@ -0,0 +1,130 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), + triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), + + ], + key=['n_elements'] +) +@triton.jit +def _quantize_global( + x_ptr, + absmax_inv_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + output = tl.libdevice.llrint(127. * (x * absmax_inv)) + tl.store(output_ptr + offsets, output, mask=mask) + +def quantize_global(x: torch.Tensor): + absmax = x.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_global[grid](x, absmax_inv, output, n_elements) + return output, absmax + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + + # ... 
+ ], + key=['M', 'N'] +) +@triton.jit +def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, + BLOCK_M : tl.constexpr, + BLOCK_N : tl.constexpr, + GROUP_M : tl.constexpr): + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // group_size + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) + mask = (rm < M)[:, None] & (rn < N)[None, :] + a = tl.load(A, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + + # rematerialize to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + + output = tl.libdevice.llrint(127. 
* (a * absmax_inv)) + + tl.store(B, output, mask=mask) + +def quantize_global_transpose(input): + absmax = input.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + M, N = input.shape + out = torch.empty(N, M, device='cuda', dtype=torch.int8) + + assert out.size(0) == N and out.size(1) == M + assert input.stride(0) == 1 or input.stride(1) == 1 + assert out.stride(0) == 1 or out.stride(1) == 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) + return out, absmax + +if __name__ == '__main__': + + + w = torch.randn(768, 1280).cuda().to(torch.float16) + W_int8, state_w = quantize_global(w) + r_state_w = w.abs().max() + r_W_int8 = ((127 * w.float()) / state_w).round().to(torch.int8) + print((r_W_int8 == W_int8).float().mean()) + + # print(r_W_int8) + # print(W_int8) + exit() + repeat = 16 + + for _ in range(8): + out = quantize_global(w) + + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph): + out = quantize_global(w) + + triton_graph.replay() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + triton_graph.replay() + torch.cuda.synchronize() + end = time.time() + + print(f"time: {(end - start) / repeat * 1000:.3f} ms") diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_rowwise_nogroup.py b/bitsandbytes/nn/triton_utils/v0/quantize_rowwise_nogroup.py new file mode 100644 index 0000000..7e63f74 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/quantize_rowwise_nogroup.py @@ -0,0 +1,174 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# TODO: autotune this better. 
+@triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_rowwise_nogroup( + x_ptr, + output_ptr, + output_maxs, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_rowwise_nogroup(x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise_nogroup[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _experimental_quantize_rowwise_nogroup( + x_ptr, + output_ptr, + bias_grad_ptr, + output_maxs, + n_elements, + M: tl.constexpr, N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + P2: 
tl.constexpr, + P2M: tl.constexpr, +): + pid = tl.program_id(axis=0) + if pid < M: + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + else: + real_pid = pid - M + arange_new = tl.arange(0, P2M) + mask_new = arange_new < M + offsets_new = real_pid + arange_new * N + new_x = tl.load(x_ptr + offsets_new, mask=mask_new) + s = tl.sum(tl.where(mask_new, new_x, 0).to(tl.float32), axis=0) + tl.store(bias_grad_ptr + real_pid, s) + +def experimental_quantize_rowwise_nogroup(x: torch.Tensor): + M, N = x.shape + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + bias_grad = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + P2M = int(2 ** (math.ceil(math.log2(x.shape[0])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0] + x.shape[1],) + _experimental_quantize_rowwise_nogroup[grid](x, output, bias_grad, output_maxs, n_elements, M, N, BLOCK_SIZE=x.shape[1], P2=P2, P2M=P2M) + return output, output_maxs, bias_grad + + +if __name__ == '__main__': + torch.manual_seed(0) + + x = torch.randn(1280, 768).cuda().to(torch.float16) + out = quantize_rowwise_nogroup(x) + + x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8) + max2 = x.abs().max(1)[0] + + print(torch.allclose(out[1], max2)) + print( (x_real == out[0]).float().mean() ) + + # for i in range(x.shape[0]): + # print( (x_real[i, :] == out[0][i, :]).float().mean() ) + + # print(out[0]) + # print(x_real) + # import pdb; pdb.set_trace() + # 
print(out[2]) + # print(out[2][:10]) + sums = x.sum(dim=0) + #print(sums[:10]) + #print( (sums == out[2]).float().mean() ) + + import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() + # exit() + + # repeat = 16 + + # for _ in range(8): + # out = quantize_rowwise_nogroup(x) + + # triton_graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(triton_graph): + # out = quantize_rowwise_nogroup(x) + + # triton_graph.replay() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # triton_graph.replay() + # torch.cuda.synchronize() + # end = time.time() + + # print(out[0]) + # print(out[1]) + # print(x / x.abs().max(dim=1, keepdim=True)[0]) + # max1 = out[1] + # max2 = x.abs().max(1)[0] + # print(max1, max2) + # print(torch.allclose(max1, max2)) + + #print(f"time: {(end - start) / repeat * 1000:.3f} ms") diff --git a/tests/triton_tests/attn_decomp.py b/tests/triton_tests/attn_decomp.py new file mode 100644 index 0000000..9e8ed28 --- /dev/null +++ b/tests/triton_tests/attn_decomp.py @@ -0,0 +1,363 @@ + +import torch +import json +from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear +import time + +# class AttentionOld(torch.nn.Module): +# def __init__( +# self, +# dim, +# num_heads=8, +# qkv_bias=True, +# scaled_cosine=False, +# scale_heads=False, +# attn_drop=0., +# proj_drop=0., +# linear_module=torch.nn.Linear, +# ): +# super().__init__() +# self.scaled_cosine = scaled_cosine +# self.scale_heads = scale_heads +# assert dim % num_heads == 0, 'dim should be divisible by num_heads' +# self.num_heads = num_heads +# self.head_dim = dim // num_heads +# self.scale = self.head_dim ** -0.5 + +# self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias) + +# self.attn_drop = torch.nn.Dropout(attn_drop) +# if self.scale_heads: +# self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1))) +# else: +# self.head_scale = None +# self.out_proj = linear_module(dim, dim) +# 
self.out_drop = torch.nn.Dropout(proj_drop) + +# def forward(self, x, attn_mask = None): +# L, N, C = x.shape + +# q, k, v = self.in_proj_linear(x).chunk(3, dim=-1) + +# q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) +# k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) +# v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + +# q = q * self.scale +# attn = torch.bmm(q, k.transpose(-1, -2)) + +# if attn_mask is not None: +# if attn_mask.dtype == torch.bool: +# new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) +# new_attn_mask.masked_fill_(attn_mask, float("-inf")) +# attn_mask = new_attn_mask +# attn += attn_mask + +# attn = attn.softmax(dim=-1) +# attn = self.attn_drop(attn) + +# x = torch.bmm(attn, v) +# x = x.transpose(0, 1).reshape(L, N, C) + +# x = self.out_proj(x) +# x = self.out_drop(x) +# return x + +class Attention(torch.nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + attn_drop=0., + proj_drop=0., + linear_module=torch.nn.Linear, + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.ln = torch.nn.LayerNorm(dim) + + self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias) + + self.attn_drop = torch.nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = linear_module(dim, dim) + self.out_drop = torch.nn.Dropout(proj_drop) + + def forward(self, x, attn_mask = None): + q, k, v = self.in_proj_linear(self.ln(x)).chunk(3, dim=-1) + x = torch.compile(torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)) + x = self.out_proj(x) + return x + +if __name__ == '__main__': + + + for dim 
in [1024, 1280, 1408, 1664, 2048]: + for batch in [2**14, 2**15, 2**16, 2**17]: + + # if dim != 4096 or batch != 2**17: + # continue + + x1 = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True) + qu = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True) + ke = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True) + va = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True) + + standard = Attention(dim).cuda() + my_standard = Attention(dim, linear_module=MyLinear).cuda() + sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda() + standard_compiled = torch.compile(standard) + ln_model = torch.nn.Sequential( + torch.nn.LayerNorm(dim), + torch.nn.LayerNorm(dim), + ).cuda() + ln_model_compiled = torch.compile( + ln_model + ) + gelu_model = torch.nn.Sequential( + torch.nn.GELU(), + ).cuda() + gelu_model_compiled = torch.compile( + gelu_model + ) + + + print('Model part 2') + + repeat = 32 + + info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim} + + + k = 'attn' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_attn = torch.nn.functional.scaled_dot_product_attention(qu, ke, va) + ((2 ** 16) * out_attn).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_attn = torch.nn.functional.scaled_dot_product_attention(qu, ke, va) + ((2 ** 16) * out_attn).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'ln' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out = ln_model(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out = ln_model(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 
1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'ln_compiled' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out = ln_model_compiled(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out = ln_model_compiled(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'gelu' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out = gelu_model(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out = gelu_model(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'gelu_compiled' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out = gelu_model_compiled(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out = gelu_model_compiled(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + x1.grad.zero_() + + k = 'standard' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_standard = standard(x1) + ((2 ** 16) * out_standard).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_standard = standard(x1) + ((2 ** 16) * out_standard).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + 
print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'my_standard' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_my_standard = my_standard(x1) + ((2 ** 16) * out_my_standard).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_my_standard = my_standard(x1) + ((2 ** 16) * out_my_standard).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + # + # + + x1.grad.zero_() + + + k = 'standard_compiled' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_standard_compiled = standard_compiled(x1) + ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_standard_compiled = standard_compiled(x1) + ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + + k = 'sb' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + info_json = json.dumps(info) + + + with open("tests/triton_tests/attn_info_ln.jsonl", "a") as file: + file.write(info_json + "\n") + + + #exit() + + # err_fused = (out_standard - out_fused).abs().mean() + # err_sb = (out_standard - out_sb).abs().mean() + # print('OUT', err_fused, err_sb) + + # err_fused = (standard[d].weight.grad - 
fused_mlp.linear2.weight.grad).abs().mean() + # err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean() + + # print('GW2', err_fused, err_sb) + + # err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean() + # err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean() + + # print('GW1', err_fused, err_sb) + + # err_fused = (x1.grad - x2.grad).abs().mean() + # err_sb = (x1.grad - x3.grad).abs().mean() + + # print('GX1', err_fused, err_sb) + + # import pdb; pdb.set_trace() + + + # # NO GELU, ST GRADIENTS, EVERYTHING FINE. \ No newline at end of file diff --git a/tests/triton_tests/attn_info_ln.jsonl b/tests/triton_tests/attn_info_ln.jsonl new file mode 100644 index 0000000..c2f239b --- /dev/null +++ b/tests/triton_tests/attn_info_ln.jsonl @@ -0,0 +1,20 @@ +{"repeat": 32, "batch_size": 16384, "dim": 1024, "attn": 2.1414458751678467, "ln": 1.6365647315979004, "ln_compiled": 1.799367368221283, "gelu": 1.0930374264717102, "gelu_compiled": 1.094818115234375, "standard": 4.159651696681976, "my_standard": 4.696495831012726, "standard_compiled": 3.675594925880432, "sb": 4.1465312242507935} +{"repeat": 32, "batch_size": 32768, "dim": 1024, "attn": 4.100345075130463, "ln": 3.1594187021255493, "ln_compiled": 3.437422215938568, "gelu": 2.109348773956299, "gelu_compiled": 2.11450457572937, "standard": 7.706902921199799, "my_standard": 8.799396455287933, "standard_compiled": 6.735652685165405, "sb": 7.66376405954361} +{"repeat": 32, "batch_size": 65536, "dim": 1024, "attn": 7.953710854053497, "ln": 6.236426532268524, "ln_compiled": 6.746955215930939, "gelu": 4.164382815361023, "gelu_compiled": 4.171714186668396, "standard": 14.894917607307434, "my_standard": 17.042435705661774, "standard_compiled": 12.985721230506897, "sb": 14.6140456199646} +{"repeat": 32, "batch_size": 131072, "dim": 1024, "attn": 15.638880431652069, "ln": 12.333884835243225, "ln_compiled": 13.272866606712341, "gelu": 8.228793740272522, "gelu_compiled": 
8.243747055530548, "standard": 29.425136744976044, "my_standard": 35.08377820253372, "standard_compiled": 25.69487690925598, "sb": 28.760001063346863} +{"repeat": 32, "batch_size": 16384, "dim": 1280, "attn": 2.627238631248474, "ln": 2.0098239183425903, "ln_compiled": 2.4197474122047424, "gelu": 1.3455823063850403, "gelu_compiled": 1.35069340467453, "standard": 5.554787814617157, "my_standard": 6.2290579080581665, "standard_compiled": 5.132324993610382, "sb": 5.4178386926651} +{"repeat": 32, "batch_size": 32768, "dim": 1280, "attn": 5.0596073269844055, "ln": 3.903590142726898, "ln_compiled": 4.719957709312439, "gelu": 2.6203468441963196, "gelu_compiled": 2.627365291118622, "standard": 10.546617209911346, "my_standard": 11.850126087665558, "standard_compiled": 9.685918688774109, "sb": 10.088451206684113} +{"repeat": 32, "batch_size": 65536, "dim": 1280, "attn": 9.845800697803497, "ln": 7.711298763751984, "ln_compiled": 9.292080998420715, "gelu": 5.172915756702423, "gelu_compiled": 5.180932581424713, "standard": 21.371990442276, "my_standard": 23.921720683574677, "standard_compiled": 19.669152796268463, "sb": 20.267993211746216} +{"repeat": 32, "batch_size": 131072, "dim": 1280, "attn": 19.375711679458618, "ln": 15.333592891693115, "ln_compiled": 18.245264887809753, "gelu": 10.264746844768524, "gelu_compiled": 10.283775627613068, "standard": 41.79700464010239, "my_standard": 45.84744572639465, "standard_compiled": 38.35208714008331, "sb": 38.35364431142807} +{"repeat": 32, "batch_size": 16384, "dim": 1408, "attn": 2.9110386967658997, "ln": 2.1998360753059387, "ln_compiled": 2.581551671028137, "gelu": 1.4731436967849731, "gelu_compiled": 1.478634774684906, "standard": 6.764143705368042, "my_standard": 7.331632077693939, "standard_compiled": 6.24605268239975, "sb": 6.325609982013702} +{"repeat": 32, "batch_size": 32768, "dim": 1408, "attn": 5.542516708374023, "ln": 4.289716482162476, "ln_compiled": 5.065307021141052, "gelu": 2.8742849826812744, "gelu_compiled": 
2.882353961467743, "standard": 12.749537825584412, "my_standard": 13.79828155040741, "standard_compiled": 11.728867888450623, "sb": 11.642806231975555} +{"repeat": 32, "batch_size": 65536, "dim": 1408, "attn": 10.80312579870224, "ln": 8.471302688121796, "ln_compiled": 9.96796041727066, "gelu": 5.681410431861877, "gelu_compiled": 5.6905597448349, "standard": 25.19702911376953, "my_standard": 27.226239442825317, "standard_compiled": 23.22910726070404, "sb": 22.682294249534607} +{"repeat": 32, "batch_size": 131072, "dim": 1408, "attn": 21.284908056259155, "ln": 16.85701310634613, "ln_compiled": 19.643358886241913, "gelu": 11.292420327663422, "gelu_compiled": 11.314474046230316, "standard": 50.06787180900574, "my_standard": 54.29378151893616, "standard_compiled": 44.58653926849365, "sb": 45.359253883361816} +{"repeat": 32, "batch_size": 16384, "dim": 1664, "attn": 3.382459282875061, "ln": 2.6206374168395996, "ln_compiled": 2.9666870832443237, "gelu": 1.7263293266296387, "gelu_compiled": 1.7317384481430054, "standard": 8.414775133132935, "my_standard": 9.117811918258667, "standard_compiled": 7.7542513608932495, "sb": 7.70898163318634} +{"repeat": 32, "batch_size": 32768, "dim": 1664, "attn": 6.468378007411957, "ln": 5.125559866428375, "ln_compiled": 5.791269242763519, "gelu": 3.3864825963974, "gelu_compiled": 3.3920034766197205, "standard": 16.016244888305664, "my_standard": 17.25083589553833, "standard_compiled": 14.60808515548706, "sb": 14.347739517688751} +{"repeat": 32, "batch_size": 65536, "dim": 1664, "attn": 12.645229697227478, "ln": 10.13532280921936, "ln_compiled": 11.427387595176697, "gelu": 6.6957250237464905, "gelu_compiled": 6.711684167385101, "standard": 31.792201101779938, "my_standard": 34.31189805269241, "standard_compiled": 29.10037338733673, "sb": 28.3128023147583} +{"repeat": 32, "batch_size": 131072, "dim": 1664, "attn": 24.970605969429016, "ln": 20.182937383651733, "ln_compiled": 22.7489173412323, "gelu": 13.326868414878845, "gelu_compiled": 
13.345755636692047, "standard": 63.46555054187775, "my_standard": 70.19880414009094, "standard_compiled": 56.40875548124313, "sb": 56.22846633195877} +{"repeat": 32, "batch_size": 16384, "dim": 2048, "attn": 4.080049693584442, "ln": 3.2655522227287292, "ln_compiled": 3.3329352736473083, "gelu": 2.108432352542877, "gelu_compiled": 2.114713191986084, "standard": 11.370822787284851, "my_standard": 12.234866619110107, "standard_compiled": 10.377615690231323, "sb": 10.209612548351288} +{"repeat": 32, "batch_size": 32768, "dim": 2048, "attn": 7.74645060300827, "ln": 6.418220698833466, "ln_compiled": 6.55733048915863, "gelu": 4.163652658462524, "gelu_compiled": 4.171028733253479, "standard": 21.39316499233246, "my_standard": 23.04024249315262, "standard_compiled": 19.431106746196747, "sb": 18.732361495494843} +{"repeat": 32, "batch_size": 65536, "dim": 2048, "attn": 15.235155820846558, "ln": 12.684382498264313, "ln_compiled": 12.895286083221436, "gelu": 8.228868246078491, "gelu_compiled": 8.242718875408173, "standard": 42.55136102437973, "my_standard": 45.82635313272476, "standard_compiled": 38.663335144519806, "sb": 36.76284849643707} +{"repeat": 32, "batch_size": 131072, "dim": 2048, "attn": 30.24454414844513, "ln": 25.25731921195984, "ln_compiled": 25.67601203918457, "gelu": 16.384944319725037, "gelu_compiled": 16.409948468208313, "standard": 84.26841348409653, "my_standard": 91.10662341117859, "standard_compiled": 76.89539343118668, "sb": 71.73164188861847} diff --git a/tests/triton_tests/full_matrix_decomp.py b/tests/triton_tests/full_matrix_decomp.py new file mode 100644 index 0000000..de37b95 --- /dev/null +++ b/tests/triton_tests/full_matrix_decomp.py @@ -0,0 +1,353 @@ +import json + +import time +import torch +import torch.nn as nn +import bitsandbytes.nn as bnn +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear + +from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup +from 
bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose +from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias +from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose +from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias + +# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. +# not that big of an issue. + +def get_time_standard_fwd(k, v): + + x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() + g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() + + ##### time matmul 1 + for _ in range(repeat // 2): + g.t().matmul(x) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + g.t().matmul(x) + + torch.cuda.synchronize() + end = time.time() + print(f"time {k}: {(end - start) / repeat * 1000:.3f} ms") + return (end - start) / repeat * 1000 + +if __name__ == '__main__': + torch.manual_seed(0) + #for (dim, wm) in [(1024, 4), (1280, 4), (1408, 4.3637), (1664, 4.9231), (2048, 4), (4096, 4), (8096, 4)] + for (dim, wm) in [(1408, 4), (1664, 4),]: + + for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: + #for batch_size in [256*256, 256*512]: + + for switch in [False, True]: + + + # hparams + repeat = 64 + batch_size = batch_size + dim_out = dim * wm + dim_in = dim + if switch: + dim_out = dim + dim_in = wm * dim + + dim_in = round(dim_in) + dim_out = round(dim_out) + + + # simulate forward pass + x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() + g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() + w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() + + x_int8 = x.clone().to(torch.int8) + 
g_int8 = g.clone().to(torch.int8) + w_int8 = w.clone().to(torch.int8) + wt_int8 = w.t().contiguous().clone().to(torch.int8) + state_x_rowwise = x.max(dim=1)[0] + state_g_rowwise = g.max(dim=1)[0] + state_w_columnwise = w.max(dim=0)[0] + state_w_rowwise = w.max(dim=1)[0] + state_w_global = w.max() + + info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} + + k = 'standard_fwd' + for _ in range(repeat // 2): + x.matmul(w.t()) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + x.matmul(w.t()) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'standard_gw' + for _ in range(repeat // 2): + g.t().matmul(x) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + g.t().matmul(x) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'standard_gx' + for _ in range(repeat // 2): + g.matmul(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + g.matmul(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + + k = 'rowwise_fwd' + for _ in range(repeat // 2): + int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'rowwise_bwd' + for _ in range(repeat // 2): + int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + 
int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'global_fwd' + for _ in range(repeat // 2): + int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'global_bwd' + for _ in range(repeat // 2): + int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'x_quantize_rowwise' + for _ in range(repeat // 2): + quantize_rowwise_nogroup(x) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_rowwise_nogroup(x) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'g_quantize_rowwise' + for _ in range(repeat // 2): + quantize_rowwise_nogroup(g) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_rowwise_nogroup(g) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'w_quantize_rowwise' + for _ in range(repeat // 2): + quantize_rowwise_nogroup(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_rowwise_nogroup(w) + + torch.cuda.synchronize() + end = 
time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'w_quantize_colwise_transpose' + for _ in range(repeat // 2): + quantize_columnwise_nogroup_transpose(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_columnwise_nogroup_transpose(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'w_quantize_global' + for _ in range(repeat // 2): + quantize_global(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_global(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'w_quantize_global_transpose' + for _ in range(repeat // 2): + quantize_global_transpose(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_global_transpose(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'cast_x' + for _ in range(repeat // 2): + newx = x.to(torch.int8) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + newx = x.to(torch.int8) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + + k = 'cast_g' + for _ in range(repeat // 2): + newx = g.to(torch.int8) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + newx = g.to(torch.int8) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + + k = 'cast_w' + for _ in range(repeat // 2): + newx = w.to(torch.int8) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + newx = w.to(torch.int8) + + torch.cuda.synchronize() + end = time.time() + ms = (end - 
start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] + time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] + time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] + + print('TOTAL STANDARD', time_standard) + print('TOTAL ROWWISE', time_rowwise) + print('TOTAL GLOBAL', time_global) + + print('speedup', -100*(time_global - time_standard)/time_standard) + + info['time_standard'] = time_standard + info['time_rowwise'] = time_rowwise + info['time_global'] = time_global + + + + info_json = json.dumps(info) + + + with open("tests/triton_tests/info.jsonl", "a") as file: + file.write(info_json + "\n") \ No newline at end of file diff --git a/tests/triton_tests/info.jsonl b/tests/triton_tests/info.jsonl new file mode 100644 index 0000000..879a65f --- /dev/null +++ b/tests/triton_tests/info.jsonl @@ -0,0 +1,142 @@ +{"repeat": 64, "batch_size": 1024, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.047907233238220215, "standard_gw": 0.04326179623603821, "standard_gx": 0.042986124753952026, "rowwise_fwd": 0.03902614116668701, "rowwise_bwd": 0.038955360651016235, "global_fwd": 0.03974884748458862, "global_bwd": 0.0391639769077301, "x_quantize_rowwise": 0.02619624137878418, "g_quantize_rowwise": 0.02695620059967041, "w_quantize_rowwise": 0.02631545066833496, "w_quantize_colwise_transpose": 0.08677691221237183, "w_quantize_global": 0.07359683513641357, "w_quantize_global_transpose": 0.08226558566093445, "cast_x": 0.007815659046173096, "cast_g": 0.016041100025177002, "cast_w": 0.01600012183189392, "time_standard": 0.13415515422821045, "time_rowwise": 
0.28748810291290283, "time_global": 0.33118948340415955} +{"repeat": 64, "batch_size": 1024, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.04236400127410889, "standard_gw": 0.04898756742477417, "standard_gx": 0.04731118679046631, "rowwise_fwd": 0.03933534026145935, "rowwise_bwd": 0.03947317600250244, "global_fwd": 0.03688037395477295, "global_bwd": 0.039167702198028564, "x_quantize_rowwise": 0.02533942461013794, "g_quantize_rowwise": 0.02516806125640869, "w_quantize_rowwise": 0.02528354525566101, "w_quantize_colwise_transpose": 0.0903792679309845, "w_quantize_global": 0.0997595489025116, "w_quantize_global_transpose": 0.10209530591964722, "cast_x": 0.01626834273338318, "cast_g": 0.011973083019256592, "cast_w": 0.016044825315475464, "time_standard": 0.13866275548934937, "time_rowwise": 0.2939663827419281, "time_global": 0.37739798426628113} +{"repeat": 64, "batch_size": 2048, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.07753819227218628, "standard_gw": 0.08026883006095886, "standard_gx": 0.0906921923160553, "rowwise_fwd": 0.0630207359790802, "rowwise_bwd": 0.058263540267944336, "global_fwd": 0.06167963147163391, "global_bwd": 0.05801767110824585, "x_quantize_rowwise": 0.034205615520477295, "g_quantize_rowwise": 0.03341957926750183, "w_quantize_rowwise": 0.03244727849960327, "w_quantize_colwise_transpose": 0.08665025234222412, "w_quantize_global": 0.09483471512794495, "w_quantize_global_transpose": 0.10108202695846558, "cast_x": 0.012032687664031982, "cast_g": 0.03752484917640686, "cast_w": 0.01605972647666931, "time_standard": 0.24849921464920044, "time_rowwise": 0.3882758319377899, "time_global": 0.46350806951522827} +{"repeat": 64, "batch_size": 2048, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.09099021553993225, "standard_gw": 0.0799819827079773, "standard_gx": 0.07644668221473694, "rowwise_fwd": 0.05840510129928589, "rowwise_bwd": 0.06359070539474487, "global_fwd": 
0.057831406593322754, "global_bwd": 0.06148591637611389, "x_quantize_rowwise": 0.03434717655181885, "g_quantize_rowwise": 0.03361701965332031, "w_quantize_rowwise": 0.03209337592124939, "w_quantize_colwise_transpose": 0.09028613567352295, "w_quantize_global": 0.0944770872592926, "w_quantize_global_transpose": 0.0994168221950531, "cast_x": 0.03769621253013611, "cast_g": 0.012010335922241211, "cast_w": 0.01600012183189392, "time_standard": 0.24741888046264648, "time_rowwise": 0.39232149720191956, "time_global": 0.4611574113368988} +{"repeat": 64, "batch_size": 4096, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.14450401067733765, "standard_gw": 0.14326348900794983, "standard_gx": 0.14762207865715027, "rowwise_fwd": 0.10525062680244446, "rowwise_bwd": 0.09800493717193604, "global_fwd": 0.10229647159576416, "global_bwd": 0.09718164801597595, "x_quantize_rowwise": 0.03429874777793884, "g_quantize_rowwise": 0.04567950963973999, "w_quantize_rowwise": 0.03365054726600647, "w_quantize_colwise_transpose": 0.08654966950416565, "w_quantize_global": 0.09663775563240051, "w_quantize_global_transpose": 0.10383129119873047, "cast_x": 0.01605972647666931, "cast_g": 0.08305534720420837, "cast_w": 0.01624971628189087, "time_standard": 0.43538957834243774, "time_rowwise": 0.5466975271701813, "time_global": 0.6231889128684998} +{"repeat": 64, "batch_size": 4096, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.14496594667434692, "standard_gw": 0.1412704586982727, "standard_gx": 0.14446303248405457, "rowwise_fwd": 0.10041892528533936, "rowwise_bwd": 0.10674074292182922, "global_fwd": 0.09856373071670532, "global_bwd": 0.10319426655769348, "x_quantize_rowwise": 0.045571476221084595, "g_quantize_rowwise": 0.03273040056228638, "w_quantize_rowwise": 0.033464282751083374, "w_quantize_colwise_transpose": 0.09154900908470154, "w_quantize_global": 0.0964440405368805, "w_quantize_global_transpose": 0.1031048595905304, "cast_x": 
0.0835023820400238, "cast_g": 0.016242265701293945, "cast_w": 0.016283243894577026, "time_standard": 0.4306994378566742, "time_rowwise": 0.5517452955245972, "time_global": 0.6208792328834534} +{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28106942772865295, "standard_gw": 0.2841465175151825, "standard_gx": 0.301852822303772, "rowwise_fwd": 0.19879266619682312, "rowwise_bwd": 0.16228482127189636, "global_fwd": 0.19488856196403503, "global_bwd": 0.1607760787010193, "x_quantize_rowwise": 0.033974647521972656, "g_quantize_rowwise": 0.08221715688705444, "w_quantize_rowwise": 0.03248825669288635, "w_quantize_colwise_transpose": 0.08646398782730103, "w_quantize_global": 0.0939294695854187, "w_quantize_global_transpose": 0.09895861148834229, "cast_x": 0.03753975033760071, "cast_g": 0.15900656580924988, "cast_w": 0.01603737473487854, "time_standard": 0.8670687675476074, "time_rowwise": 0.8803680539131165, "time_global": 0.9488910436630249} +{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.26415660977363586, "standard_gw": 0.2679601311683655, "standard_gx": 0.30617788434028625, "rowwise_fwd": 0.180121511220932, "rowwise_bwd": 0.21555647253990173, "global_fwd": 0.17506256699562073, "global_bwd": 0.2116672694683075, "x_quantize_rowwise": 0.08289515972137451, "g_quantize_rowwise": 0.033795833587646484, "w_quantize_rowwise": 0.03366544842720032, "w_quantize_colwise_transpose": 0.09965524077415466, "w_quantize_global": 0.09595602750778198, "w_quantize_global_transpose": 0.1024976372718811, "cast_x": 0.1602955162525177, "cast_g": 0.03787502646446228, "cast_w": 0.016216188669204712, "time_standard": 0.8382946252822876, "time_rowwise": 0.9136497974395752, "time_global": 0.9698346257209778} +{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5719438195228577, "standard_gw": 0.524863600730896, 
"standard_gx": 0.6005167961120605, "rowwise_fwd": 0.3750324249267578, "rowwise_bwd": 0.28166547417640686, "global_fwd": 0.3674700856208801, "global_bwd": 0.2798214554786682, "x_quantize_rowwise": 0.04655122756958008, "g_quantize_rowwise": 0.1555122435092926, "w_quantize_rowwise": 0.03437697887420654, "w_quantize_colwise_transpose": 0.08634477853775024, "w_quantize_global": 0.09759142994880676, "w_quantize_global_transpose": 0.10081753134727478, "cast_x": 0.0828765332698822, "cast_g": 0.31184032559394836, "cast_w": 0.016063451766967773, "time_standard": 1.6973242163658142, "time_rowwise": 1.5043467283248901, "time_global": 1.5726275742053986} +{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5423910915851593, "standard_gw": 0.5674734711647034, "standard_gx": 0.5907565355300903, "rowwise_fwd": 0.3149174153804779, "rowwise_bwd": 0.3899820148944855, "global_fwd": 0.2909451723098755, "global_bwd": 0.3783814609050751, "x_quantize_rowwise": 0.15584751963615417, "g_quantize_rowwise": 0.04688650369644165, "w_quantize_rowwise": 0.031463801860809326, "w_quantize_colwise_transpose": 0.09072571992874146, "w_quantize_global": 0.09774044156074524, "w_quantize_global_transpose": 0.10405108332633972, "cast_x": 0.3111511468887329, "cast_g": 0.08282437920570374, "cast_w": 0.015992671251296997, "time_standard": 1.700621098279953, "time_rowwise": 1.5972964465618134, "time_global": 1.6413256525993347} +{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2115389108657837, "standard_gw": 1.1259466409683228, "standard_gx": 1.1027492582798004, "rowwise_fwd": 0.7407031953334808, "rowwise_bwd": 0.5539208650588989, "global_fwd": 0.7214657962322235, "global_bwd": 0.5515590310096741, "x_quantize_rowwise": 0.08765608072280884, "g_quantize_rowwise": 0.3022328019142151, "w_quantize_rowwise": 0.03347545862197876, "w_quantize_colwise_transpose": 0.08694455027580261, 
"w_quantize_global": 0.09706243872642517, "w_quantize_global_transpose": 0.10102614760398865, "cast_x": 0.1592189073562622, "cast_g": 0.6166175007820129, "cast_w": 0.01607835292816162, "time_standard": 3.440234810113907, "time_rowwise": 2.930879592895508, "time_global": 2.986948937177658} +{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1010989546775818, "standard_gw": 1.1352524161338806, "standard_gx": 1.1676251888275146, "rowwise_fwd": 0.5864761769771576, "rowwise_bwd": 0.7485374808311462, "global_fwd": 0.5547590553760529, "global_bwd": 0.7249303162097931, "x_quantize_rowwise": 0.3021731972694397, "g_quantize_rowwise": 0.08751824498176575, "w_quantize_rowwise": 0.033952295780181885, "w_quantize_colwise_transpose": 0.09011104702949524, "w_quantize_global": 0.09443238377571106, "w_quantize_global_transpose": 0.10376051068305969, "cast_x": 0.6167255342006683, "cast_g": 0.15922263264656067, "cast_w": 0.016070902347564697, "time_standard": 3.403976559638977, "time_rowwise": 2.984020859003067, "time_global": 3.0028261244297028} +{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.472013235092163, "standard_gw": 2.218998968601227, "standard_gx": 2.2116564214229584, "rowwise_fwd": 1.466125249862671, "rowwise_bwd": 1.0577328503131866, "global_fwd": 1.431729644536972, "global_bwd": 1.0476894676685333, "x_quantize_rowwise": 0.16929209232330322, "g_quantize_rowwise": 0.5952082574367523, "w_quantize_rowwise": 0.032100826501846313, "w_quantize_colwise_transpose": 0.08670613169670105, "w_quantize_global": 0.09590759873390198, "w_quantize_global_transpose": 0.10358169674873352, "cast_x": 0.31175464391708374, "cast_g": 1.2264922261238098, "cast_w": 0.016067177057266235, "time_standard": 6.902668625116348, "time_rowwise": 5.626164376735687, "time_global": 5.662407726049423} +{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, 
"switch": true, "standard_fwd": 2.181064337491989, "standard_gw": 2.2256113588809967, "standard_gx": 2.3229196667671204, "rowwise_fwd": 1.0886266827583313, "rowwise_bwd": 1.4654062688350677, "global_fwd": 1.0472461581230164, "global_bwd": 1.433148980140686, "x_quantize_rowwise": 0.5954094231128693, "g_quantize_rowwise": 0.16921386122703552, "w_quantize_rowwise": 0.03442913293838501, "w_quantize_colwise_transpose": 0.09007751941680908, "w_quantize_global": 0.09575113654136658, "w_quantize_global_transpose": 0.10503828525543213, "cast_x": 1.2264810502529144, "cast_g": 0.3119036555290222, "cast_w": 0.01605600118637085, "time_standard": 6.729595363140106, "time_rowwise": 5.668774247169495, "time_global": 5.671419203281403} +{"repeat": 64, "batch_size": 1024, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.08157268166542053, "standard_gw": 0.07601454854011536, "standard_gx": 0.09059160947799683, "rowwise_fwd": 0.053066760301589966, "rowwise_bwd": 0.04787370562553406, "global_fwd": 0.05243346095085144, "global_bwd": 0.04809349775314331, "x_quantize_rowwise": 0.02571195363998413, "g_quantize_rowwise": 0.025898218154907227, "w_quantize_rowwise": 0.02714991569519043, "w_quantize_colwise_transpose": 0.19773468375205994, "w_quantize_global": 0.07273256778717041, "w_quantize_global_transpose": 0.08068978786468506, "cast_x": 0.008046627044677734, "cast_g": 0.0252649188041687, "cast_w": 0.0393986701965332, "time_standard": 0.24817883968353271, "time_rowwise": 0.4534497857093811, "time_global": 0.38157403469085693} +{"repeat": 64, "batch_size": 1024, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.09134411811828613, "standard_gw": 0.07602199912071228, "standard_gx": 0.09555742144584656, "rowwise_fwd": 0.047691166400909424, "rowwise_bwd": 0.05320459604263306, "global_fwd": 0.04759058356285095, "global_bwd": 0.0521540641784668, "x_quantize_rowwise": 0.025313347578048706, "g_quantize_rowwise": 0.025119632482528687, 
"w_quantize_rowwise": 0.0269375741481781, "w_quantize_colwise_transpose": 0.1857280731201172, "w_quantize_global": 0.07451698184013367, "w_quantize_global_transpose": 0.08009746670722961, "cast_x": 0.02547726035118103, "cast_g": 0.007897615432739258, "cast_w": 0.039536505937576294, "time_standard": 0.26292353868484497, "time_rowwise": 0.44001638889312744, "time_global": 0.3808140754699707} +{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.940010607242584, "standard_gw": 4.434864968061447, "standard_gx": 4.4097937643527985, "rowwise_fwd": 2.9467344284057617, "rowwise_bwd": 2.09181010723114, "global_fwd": 2.8806477785110474, "global_bwd": 2.0816922187805176, "x_quantize_rowwise": 0.33279508352279663, "g_quantize_rowwise": 1.1817067861557007, "w_quantize_rowwise": 0.03306567668914795, "w_quantize_colwise_transpose": 0.08666515350341797, "w_quantize_global": 0.0957287847995758, "w_quantize_global_transpose": 0.10242313146591187, "cast_x": 0.6165988743305206, "cast_g": 2.446405589580536, "cast_w": 0.016100704669952393, "time_standard": 13.78466933965683, "time_rowwise": 11.107642203569412, "time_global": 11.109858751296997} +{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.293464124202728, "standard_gw": 4.461295902729034, "standard_gx": 4.638340324163437, "rowwise_fwd": 2.116892486810684, "rowwise_bwd": 2.9479674994945526, "global_fwd": 2.0760856568813324, "global_bwd": 2.8755851089954376, "x_quantize_rowwise": 1.1818408966064453, "g_quantize_rowwise": 0.33276528120040894, "w_quantize_rowwise": 0.03287568688392639, "w_quantize_colwise_transpose": 0.09038299322128296, "w_quantize_global": 0.09598955512046814, "w_quantize_global_transpose": 0.100649893283844, "cast_x": 2.4467408657073975, "cast_g": 0.6165951490402222, "cast_w": 0.016082078218460083, "time_standard": 13.3931003510952, "time_rowwise": 11.164020746946335, "time_global": 
11.12421229481697} +{"repeat": 64, "batch_size": 2048, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.1699887216091156, "standard_gw": 0.14045089483261108, "standard_gx": 0.17407909035682678, "rowwise_fwd": 0.10082125663757324, "rowwise_bwd": 0.08344277739524841, "global_fwd": 0.09941309690475464, "global_bwd": 0.08352473378181458, "x_quantize_rowwise": 0.025317072868347168, "g_quantize_rowwise": 0.03849714994430542, "w_quantize_rowwise": 0.02596527338027954, "w_quantize_colwise_transpose": 0.19767135381698608, "w_quantize_global": 0.07257238030433655, "w_quantize_global_transpose": 0.08127838373184204, "cast_x": 0.012032687664031982, "cast_g": 0.06345659494400024, "cast_w": 0.03953278064727783, "time_standard": 0.48451870679855347, "time_rowwise": 0.612165778875351, "time_global": 0.5410537123680115} +{"repeat": 64, "batch_size": 2048, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.14855340123176575, "standard_gw": 0.15553459525108337, "standard_gx": 0.16282498836517334, "rowwise_fwd": 0.09259581565856934, "rowwise_bwd": 0.11080875992774963, "global_fwd": 0.09166449308395386, "global_bwd": 0.10796263813972473, "x_quantize_rowwise": 0.03939121961593628, "g_quantize_rowwise": 0.025227665901184082, "w_quantize_rowwise": 0.027202069759368896, "w_quantize_colwise_transpose": 0.1940988004207611, "w_quantize_global": 0.07397681474685669, "w_quantize_global_transpose": 0.08178502321243286, "cast_x": 0.065632164478302, "cast_g": 0.01268833875656128, "cast_w": 0.04057586193084717, "time_standard": 0.46691298484802246, "time_rowwise": 0.6448589265346527, "time_global": 0.5755424499511719} +{"repeat": 64, "batch_size": 4096, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.32291561365127563, "standard_gw": 0.2875030040740967, "standard_gx": 0.3379322588443756, "rowwise_fwd": 0.19295886158943176, "rowwise_bwd": 0.16265735030174255, "global_fwd": 0.19031018018722534, 
"global_bwd": 0.16187503933906555, "x_quantize_rowwise": 0.02730637788772583, "g_quantize_rowwise": 0.06797909736633301, "w_quantize_rowwise": 0.02642720937728882, "w_quantize_colwise_transpose": 0.19745901226997375, "w_quantize_global": 0.07253512740135193, "w_quantize_global_transpose": 0.08047744631767273, "cast_x": 0.022336840629577637, "cast_g": 0.1209154725074768, "cast_w": 0.039268285036087036, "time_standard": 0.9483508765697479, "time_rowwise": 0.9622909128665924, "time_global": 0.8879862725734711} +{"repeat": 64, "batch_size": 4096, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.3019683063030243, "standard_gw": 0.288400799036026, "standard_gx": 0.3154948353767395, "rowwise_fwd": 0.18264353275299072, "rowwise_bwd": 0.2075284719467163, "global_fwd": 0.17072632908821106, "global_bwd": 0.1960061490535736, "x_quantize_rowwise": 0.06893649697303772, "g_quantize_rowwise": 0.02561509609222412, "w_quantize_rowwise": 0.026594847440719604, "w_quantize_colwise_transpose": 0.18575787544250488, "w_quantize_global": 0.07266923785209656, "w_quantize_global_transpose": 0.08060410618782043, "cast_x": 0.12182071805000305, "cast_g": 0.022590160369873047, "cast_w": 0.04000961780548096, "time_standard": 0.9058639407157898, "time_rowwise": 0.9854771196842194, "time_global": 0.9029582142829895} +{"repeat": 64, "batch_size": 8192, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.6489232182502747, "standard_gw": 0.5987770855426788, "standard_gx": 0.6644465029239655, "rowwise_fwd": 0.35867467522621155, "rowwise_bwd": 0.31855329871177673, "global_fwd": 0.353105366230011, "global_bwd": 0.31349435448646545, "x_quantize_rowwise": 0.03382191061973572, "g_quantize_rowwise": 0.12668967247009277, "w_quantize_rowwise": 0.02681836485862732, "w_quantize_colwise_transpose": 0.19756704568862915, "w_quantize_global": 0.07336586713790894, "w_quantize_global_transpose": 0.08036196231842041, "cast_x": 0.0583939254283905, "cast_g": 
0.23520365357398987, "cast_w": 0.03935396671295166, "time_standard": 1.912146806716919, "time_rowwise": 1.660902053117752, "time_global": 1.579616218805313} +{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.5789436399936676, "standard_gw": 0.6130896508693695, "standard_gx": 0.6558857858181, "rowwise_fwd": 0.3464221954345703, "rowwise_bwd": 0.3650560975074768, "global_fwd": 0.3174394369125366, "global_bwd": 0.35758689045906067, "x_quantize_rowwise": 0.12686848640441895, "g_quantize_rowwise": 0.034302473068237305, "w_quantize_rowwise": 0.02745911478996277, "w_quantize_colwise_transpose": 0.1847483217716217, "w_quantize_global": 0.07192790508270264, "w_quantize_global_transpose": 0.08050352334976196, "cast_x": 0.23534893989562988, "cast_g": 0.05846098065376282, "cast_w": 0.03949552774429321, "time_standard": 1.847919076681137, "time_rowwise": 1.6979463398456573, "time_global": 1.6017183661460876} +{"repeat": 64, "batch_size": 1024, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.0573769211769104, "standard_gw": 0.061042606830596924, "standard_gx": 0.0783093273639679, "rowwise_fwd": 0.046797096729278564, "rowwise_bwd": 0.04620850086212158, "global_fwd": 0.04521384835243225, "global_bwd": 0.04425644874572754, "x_quantize_rowwise": 0.03257766366004944, "g_quantize_rowwise": 0.03449246287345886, "w_quantize_rowwise": 0.033657997846603394, "w_quantize_colwise_transpose": 0.1426301896572113, "w_quantize_global": 0.09257346391677856, "w_quantize_global_transpose": 0.10266527533531189, "cast_x": 0.011991709470748901, "cast_g": 0.020314007997512817, "cast_w": 0.027321279048919678, "time_standard": 0.19672885537147522, "time_rowwise": 0.39740651845932007, "time_global": 0.41282176971435547} +{"repeat": 64, "batch_size": 1024, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.07858872413635254, "standard_gw": 0.06122514605522156, "standard_gx": 
0.05758553743362427, "rowwise_fwd": 0.04598498344421387, "rowwise_bwd": 0.04618242383003235, "global_fwd": 0.04597380757331848, "global_bwd": 0.046450644731521606, "x_quantize_rowwise": 0.03332272171974182, "g_quantize_rowwise": 0.033274292945861816, "w_quantize_rowwise": 0.0337548553943634, "w_quantize_colwise_transpose": 0.14807656407356262, "w_quantize_global": 0.09948387742042542, "w_quantize_global_transpose": 0.10120868682861328, "cast_x": 0.020120292901992798, "cast_g": 0.011488795280456543, "cast_w": 0.027466565370559692, "time_standard": 0.19739940762519836, "time_rowwise": 0.40182098746299744, "time_global": 0.420939177274704} +{"repeat": 64, "batch_size": 16384, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 1.3515166938304901, "standard_gw": 1.1536777019500732, "standard_gx": 1.224767416715622, "rowwise_fwd": 0.6912238895893097, "rowwise_bwd": 0.5562454462051392, "global_fwd": 0.67867711186409, "global_bwd": 0.5518943071365356, "x_quantize_rowwise": 0.06204098463058472, "g_quantize_rowwise": 0.24417787790298462, "w_quantize_rowwise": 0.025238841772079468, "w_quantize_colwise_transpose": 0.19756704568862915, "w_quantize_global": 0.07240846753120422, "w_quantize_global_transpose": 0.08046254515647888, "cast_x": 0.11138245463371277, "cast_g": 0.4637613892555237, "cast_w": 0.03935769200325012, "time_standard": 3.7299618124961853, "time_rowwise": 2.9301717877388, "time_global": 2.8433389961719513} +{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 1.2090615928173065, "standard_gw": 1.1396333575248718, "standard_gx": 1.2223869562149048, "rowwise_fwd": 0.5849376320838928, "rowwise_bwd": 0.6985403597354889, "global_fwd": 0.5565173923969269, "global_bwd": 0.6789751350879669, "x_quantize_rowwise": 0.2445802092552185, "g_quantize_rowwise": 0.06200745701789856, "w_quantize_rowwise": 0.027727335691452026, "w_quantize_colwise_transpose": 0.18501654267311096, 
"w_quantize_global": 0.07182732224464417, "w_quantize_global_transpose": 0.08069723844528198, "cast_x": 0.4638172686100006, "cast_g": 0.11136755347251892, "cast_w": 0.039517879486083984, "time_standard": 3.571081906557083, "time_rowwise": 2.9424428939819336, "time_global": 2.834238111972809} +{"repeat": 64, "batch_size": 32768, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 2.683013677597046, "standard_gw": 2.2987723350524902, "standard_gx": 2.4510622024536133, "rowwise_fwd": 1.359008252620697, "rowwise_bwd": 1.1018887162208557, "global_fwd": 1.3311207294464111, "global_bwd": 1.0954029858112335, "x_quantize_rowwise": 0.11804327368736267, "g_quantize_rowwise": 0.479232519865036, "w_quantize_rowwise": 0.026308000087738037, "w_quantize_colwise_transpose": 0.1975223422050476, "w_quantize_global": 0.07223710417747498, "w_quantize_global_transpose": 0.08019432425498962, "cast_x": 0.2161264419555664, "cast_g": 0.9207837283611298, "cast_w": 0.03929063677787781, "time_standard": 7.432848215103149, "time_rowwise": 5.580775439739227, "time_global": 5.475003272294998} +{"repeat": 64, "batch_size": 2048, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.11088326573371887, "standard_gw": 0.10994821786880493, "standard_gx": 0.12367218732833862, "rowwise_fwd": 0.07392093539237976, "rowwise_bwd": 0.07127970457077026, "global_fwd": 0.0730752944946289, "global_bwd": 0.07089227437973022, "x_quantize_rowwise": 0.03361701965332031, "g_quantize_rowwise": 0.03525242209434509, "w_quantize_rowwise": 0.03341585397720337, "w_quantize_colwise_transpose": 0.14318525791168213, "w_quantize_global": 0.09704753756523132, "w_quantize_global_transpose": 0.10221078991889954, "cast_x": 0.012002885341644287, "cast_g": 0.05240738391876221, "cast_w": 0.027313828468322754, "time_standard": 0.3445036709308624, "time_rowwise": 0.5006194114685059, "time_global": 0.5220435559749603} +{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 6144, 
"wm": 4.3637, "switch": true, "standard_fwd": 2.4625882506370544, "standard_gw": 2.421922981739044, "standard_gx": 2.380847930908203, "rowwise_fwd": 1.1231191456317902, "rowwise_bwd": 1.360483467578888, "global_fwd": 1.0947436094284058, "global_bwd": 1.3314113020896912, "x_quantize_rowwise": 0.4795975983142853, "g_quantize_rowwise": 0.11777132749557495, "w_quantize_rowwise": 0.02699345350265503, "w_quantize_colwise_transpose": 0.18484890460968018, "w_quantize_global": 0.07201358675956726, "w_quantize_global_transpose": 0.0803135335445404, "cast_x": 0.920858234167099, "cast_g": 0.21616369485855103, "cast_w": 0.03937259316444397, "time_standard": 7.265359163284302, "time_rowwise": 5.714736878871918, "time_global": 5.597773939371109} +{"repeat": 64, "batch_size": 2048, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.12437254190444946, "standard_gw": 0.11018291115760803, "standard_gx": 0.10970607399940491, "rowwise_fwd": 0.07167831063270569, "rowwise_bwd": 0.07583573460578918, "global_fwd": 0.07314234972000122, "global_bwd": 0.07501617074012756, "x_quantize_rowwise": 0.035624951124191284, "g_quantize_rowwise": 0.0333636999130249, "w_quantize_rowwise": 0.03264099359512329, "w_quantize_colwise_transpose": 0.14795735478401184, "w_quantize_global": 0.09621679782867432, "w_quantize_global_transpose": 0.10380148887634277, "cast_x": 0.05278363823890686, "cast_g": 0.01249462366104126, "cast_w": 0.02767890691757202, "time_standard": 0.3442615270614624, "time_rowwise": 0.5072839558124542, "time_global": 0.5273483693599701} +{"repeat": 64, "batch_size": 4096, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.21922588348388672, "standard_gw": 0.20731613039970398, "standard_gx": 0.23101642727851868, "rowwise_fwd": 0.1423358917236328, "rowwise_bwd": 0.1195073127746582, "global_fwd": 0.1401938498020172, "global_bwd": 0.11940300464630127, "x_quantize_rowwise": 0.03353878855705261, "g_quantize_rowwise": 0.06387382745742798, 
"w_quantize_rowwise": 0.03428757190704346, "w_quantize_colwise_transpose": 0.14376267790794373, "w_quantize_global": 0.09389594197273254, "w_quantize_global_transpose": 0.10196119546890259, "cast_x": 0.020060688257217407, "cast_g": 0.10236725211143494, "cast_w": 0.02732500433921814, "time_standard": 0.6575584411621094, "time_rowwise": 0.7446222007274628, "time_global": 0.7601827383041382} +{"repeat": 64, "batch_size": 4096, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.20026043057441711, "standard_gw": 0.21172687411308289, "standard_gx": 0.2276189625263214, "rowwise_fwd": 0.12956932187080383, "rowwise_bwd": 0.15310943126678467, "global_fwd": 0.12427568435668945, "global_bwd": 0.14432892203330994, "x_quantize_rowwise": 0.06471946835517883, "g_quantize_rowwise": 0.03309175372123718, "w_quantize_rowwise": 0.03242120146751404, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.09280815720558167, "w_quantize_global_transpose": 0.10265037417411804, "cast_x": 0.10267645120620728, "cast_g": 0.020150095224380493, "cast_w": 0.027399510145187378, "time_standard": 0.6396062672138214, "time_rowwise": 0.7719770073890686, "time_global": 0.773601233959198} +{"repeat": 64, "batch_size": 65536, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 5.324859172105789, "standard_gw": 4.977177828550339, "standard_gx": 4.468705505132675, "rowwise_fwd": 2.7004145085811615, "rowwise_bwd": 2.121664583683014, "global_fwd": 2.648312598466873, "global_bwd": 2.111390233039856, "x_quantize_rowwise": 0.22934377193450928, "g_quantize_rowwise": 0.9496547281742096, "w_quantize_rowwise": 0.02555176615715027, "w_quantize_colwise_transpose": 0.1977868378162384, "w_quantize_global": 0.0727437436580658, "w_quantize_global_transpose": 0.08098781108856201, "cast_x": 0.4259459674358368, "cast_g": 1.8352754414081573, "cast_w": 0.039637088775634766, "time_standard": 14.770742505788803, "time_rowwise": 11.201594024896622, 
"time_global": 11.069610714912415} +{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.49151480197906494, "standard_gw": 0.4681535065174103, "standard_gx": 0.42366236448287964, "rowwise_fwd": 0.2766512334346771, "rowwise_bwd": 0.2083033323287964, "global_fwd": 0.2709813416004181, "global_bwd": 0.20718947052955627, "x_quantize_rowwise": 0.034555792808532715, "g_quantize_rowwise": 0.11969730257987976, "w_quantize_rowwise": 0.03300607204437256, "w_quantize_colwise_transpose": 0.14345720410346985, "w_quantize_global": 0.09280070662498474, "w_quantize_global_transpose": 0.10214745998382568, "cast_x": 0.052288174629211426, "cast_g": 0.19747763872146606, "cast_w": 0.027339905500411987, "time_standard": 1.3833306729793549, "time_rowwise": 1.2838244438171387, "time_global": 1.2955255806446075} +{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.39635971188545227, "standard_gw": 0.44353678822517395, "standard_gx": 0.4724152386188507, "rowwise_fwd": 0.22813305258750916, "rowwise_bwd": 0.2868436276912689, "global_fwd": 0.2119205892086029, "global_bwd": 0.2749413251876831, "x_quantize_rowwise": 0.12082979083061218, "g_quantize_rowwise": 0.03444403409957886, "w_quantize_rowwise": 0.03444403409957886, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09495392441749573, "w_quantize_global_transpose": 0.1009330153465271, "cast_x": 0.19745156168937683, "cast_g": 0.05227327346801758, "cast_w": 0.027336180210113525, "time_standard": 1.312311738729477, "time_rowwise": 1.294981688261032, "time_global": 1.2815594673156738} +{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0207034647464752, "standard_gw": 0.897720456123352, "standard_gx": 0.8374936878681183, "rowwise_fwd": 0.5457103252410889, "rowwise_bwd": 0.4088357090950012, "global_fwd": 0.5308091640472412, "global_bwd": 
0.40555745363235474, "x_quantize_rowwise": 0.05984678864479065, "g_quantize_rowwise": 0.2306811511516571, "w_quantize_rowwise": 0.0334717333316803, "w_quantize_colwise_transpose": 0.14356523752212524, "w_quantize_global": 0.09340420365333557, "w_quantize_global_transpose": 0.09996071457862854, "cast_x": 0.10207295417785645, "cast_g": 0.3880411386489868, "cast_w": 0.027671456336975098, "time_standard": 2.7559176087379456, "time_rowwise": 2.3198314011096954, "time_global": 2.31797993183136} +{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 4.502948373556137, "standard_gw": 4.418112337589264, "standard_gx": 4.748217761516571, "rowwise_fwd": 2.1329298615455627, "rowwise_bwd": 2.6968345046043396, "global_fwd": 2.102244645357132, "global_bwd": 2.6461556553840637, "x_quantize_rowwise": 0.9493157267570496, "g_quantize_rowwise": 0.2290569245815277, "w_quantize_rowwise": 0.02551451325416565, "w_quantize_colwise_transpose": 0.18491223454475403, "w_quantize_global": 0.07426366209983826, "w_quantize_global_transpose": 0.08058920502662659, "cast_x": 1.8352717161178589, "cast_g": 0.425681471824646, "cast_w": 0.039402395486831665, "time_standard": 13.669278472661972, "time_rowwise": 10.636676102876663, "time_global": 10.499738156795502} +{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8179470896720886, "standard_gw": 0.8687414228916168, "standard_gx": 0.9276494383811951, "rowwise_fwd": 0.4481859505176544, "rowwise_bwd": 0.5557462573051453, "global_fwd": 0.4100687801837921, "global_bwd": 0.5317367613315582, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 0.05963817238807678, "w_quantize_rowwise": 0.033523887395858765, "w_quantize_colwise_transpose": 0.14462321996688843, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10088086128234863, "cast_x": 0.3879927098751068, "cast_g": 0.10205060243606567, "cast_w": 
0.02714991569519043, "time_standard": 2.6143379509449005, "time_rowwise": 2.3406408727169037, "time_global": 2.295881509780884} +{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0698904991149902, "standard_gw": 1.7200261354446411, "standard_gx": 1.663345843553543, "rowwise_fwd": 1.0664835572242737, "rowwise_bwd": 0.8059032261371613, "global_fwd": 1.0454729199409485, "global_bwd": 0.801432877779007, "x_quantize_rowwise": 0.1127384603023529, "g_quantize_rowwise": 0.4529319703578949, "w_quantize_rowwise": 0.03398582339286804, "w_quantize_colwise_transpose": 0.14343857765197754, "w_quantize_global": 0.09441003203392029, "w_quantize_global_transpose": 0.09993091225624084, "cast_x": 0.19744038581848145, "cast_g": 0.769149512052536, "cast_w": 0.02734735608100891, "time_standard": 5.453262478113174, "time_rowwise": 4.335507750511169, "time_global": 4.3269433081150055} +{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.758193761110306, "standard_gw": 1.6880109906196594, "standard_gx": 1.8163062632083893, "rowwise_fwd": 0.8343160152435303, "rowwise_bwd": 1.073598861694336, "global_fwd": 0.8045099675655365, "global_bwd": 1.0492689907550812, "x_quantize_rowwise": 0.453021377325058, "g_quantize_rowwise": 0.11304020881652832, "w_quantize_rowwise": 0.0337064266204834, "w_quantize_colwise_transpose": 0.1452416181564331, "w_quantize_global": 0.09451434016227722, "w_quantize_global_transpose": 0.0998079776763916, "cast_x": 0.769101083278656, "cast_g": 0.19731372594833374, "cast_w": 0.027332454919815063, "time_standard": 6.2625110149383545, "time_rowwise": 4.340935498476028, "time_global": 4.302173852920532} +{"repeat": 64, "batch_size": 131072, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 10.728541761636734, "standard_gw": 9.228862822055817, "standard_gx": 8.837487548589706, "rowwise_fwd": 5.4414160549640656, "rowwise_bwd": 
4.186157137155533, "global_fwd": 5.329187959432602, "global_bwd": 4.150416702032089, "x_quantize_rowwise": 0.4517659544944763, "g_quantize_rowwise": 1.890372484922409, "w_quantize_rowwise": 0.027563422918319702, "w_quantize_colwise_transpose": 0.1980513334274292, "w_quantize_global": 0.0733695924282074, "w_quantize_global_transpose": 0.08009746670722961, "cast_x": 0.8449330925941467, "cast_g": 3.6641769111156464, "cast_w": 0.03945454955101013, "time_standard": 28.794892132282257, "time_rowwise": 21.42418920993805, "time_global": 21.20407298207283} +{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.127204418182373, "standard_gw": 3.359321504831314, "standard_gx": 5.557261407375336, "rowwise_fwd": 2.1365806460380554, "rowwise_bwd": 1.6042962670326233, "global_fwd": 2.0923763513565063, "global_bwd": 1.5939176082611084, "x_quantize_rowwise": 0.21954253315925598, "g_quantize_rowwise": 0.8971206843852997, "w_quantize_rowwise": 0.03357976675033569, "w_quantize_colwise_transpose": 0.1431293785572052, "w_quantize_global": 0.10574981570243835, "w_quantize_global_transpose": 0.10281801223754883, "cast_x": 0.38795173168182373, "cast_g": 1.5318207442760468, "cast_w": 0.027142465114593506, "time_standard": 13.043787330389023, "time_rowwise": 8.39357078075409, "time_global": 8.370846509933472} +{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.576469004154205, "standard_gw": 3.361724317073822, "standard_gx": 3.6300085484981537, "rowwise_fwd": 1.6183294355869293, "rowwise_bwd": 2.1462254226207733, "global_fwd": 1.5953555703163147, "global_bwd": 2.0915642380714417, "x_quantize_rowwise": 0.8973218500614166, "g_quantize_rowwise": 0.2197064459323883, "w_quantize_rowwise": 0.03402307629585266, "w_quantize_colwise_transpose": 0.14822185039520264, "w_quantize_global": 0.09706616401672363, "w_quantize_global_transpose": 0.10339170694351196, "cast_x": 
1.5312805771827698, "cast_g": 0.3879964351654053, "cast_w": 0.0269375741481781, "time_standard": 12.568201869726181, "time_rowwise": 8.425552397966385, "time_global": 8.366130292415619} +{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 8.900497108697891, "standard_gw": 9.188394993543625, "standard_gx": 9.503517299890518, "rowwise_fwd": 4.189815372228622, "rowwise_bwd": 5.426768213510513, "global_fwd": 4.155576229095459, "global_bwd": 5.329132080078125, "x_quantize_rowwise": 1.8885880708694458, "g_quantize_rowwise": 0.45193731784820557, "w_quantize_rowwise": 0.025987625122070312, "w_quantize_colwise_transpose": 0.1842118799686432, "w_quantize_global": 0.07349997758865356, "w_quantize_global_transpose": 0.08074194192886353, "cast_x": 3.6639943718910217, "cast_g": 0.8447282016277313, "cast_w": 0.03973767161369324, "time_standard": 27.592409402132034, "time_rowwise": 21.355703473091125, "time_global": 21.167870610952377} +{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.2329623401165, "standard_gw": 6.799045950174332, "standard_gx": 6.893906742334366, "rowwise_fwd": 4.252739250659943, "rowwise_bwd": 3.2025352120399475, "global_fwd": 4.176046699285507, "global_bwd": 3.173377364873886, "x_quantize_rowwise": 0.43221935629844666, "g_quantize_rowwise": 1.7872042953968048, "w_quantize_rowwise": 0.03328174352645874, "w_quantize_colwise_transpose": 0.1431480050086975, "w_quantize_global": 0.09707733988761902, "w_quantize_global_transpose": 0.10161846876144409, "cast_x": 0.7692091166973114, "cast_g": 3.057178109884262, "cast_w": 0.027302652597427368, "time_standard": 21.9259150326252, "time_rowwise": 16.65017381310463, "time_global": 16.56658947467804} +{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.278409510850906, "standard_gw": 6.815284490585327, "standard_gx": 7.280956953763962, 
"rowwise_fwd": 3.206692636013031, "rowwise_bwd": 4.246953874826431, "global_fwd": 3.1801797449588776, "global_bwd": 4.169579595327377, "x_quantize_rowwise": 1.7862766981124878, "g_quantize_rowwise": 0.4329495131969452, "w_quantize_rowwise": 0.03413483500480652, "w_quantize_colwise_transpose": 0.14493241906166077, "w_quantize_global": 0.09881332516670227, "w_quantize_global_transpose": 0.10376423597335815, "cast_x": 3.057088702917099, "cast_g": 0.7693544030189514, "cast_w": 0.027261674404144287, "time_standard": 25.374650955200195, "time_rowwise": 16.66722446680069, "time_global": 16.586847603321075} +{"repeat": 64, "batch_size": 1024, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 0.11636316776275635, "standard_gw": 0.11816620826721191, "standard_gx": 0.11482089757919312, "rowwise_fwd": 0.08482113480567932, "rowwise_bwd": 0.06284937262535095, "global_fwd": 0.08296221494674683, "global_bwd": 0.061664730310440063, "x_quantize_rowwise": 0.026706606149673462, "g_quantize_rowwise": 0.025641173124313354, "w_quantize_rowwise": 0.03740563988685608, "w_quantize_colwise_transpose": 0.2965778112411499, "w_quantize_global": 0.11304393410682678, "w_quantize_global_transpose": 0.12390688061714172, "cast_x": 0.008635222911834717, "cast_g": 0.037532299757003784, "cast_w": 0.06856024265289307, "time_standard": 0.3493502736091614, "time_rowwise": 0.652167946100235, "time_global": 0.5520917475223541} +{"repeat": 64, "batch_size": 1024, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.11609122157096863, "standard_gw": 0.11704489588737488, "standard_gx": 0.11566653847694397, "rowwise_fwd": 0.06706640124320984, "rowwise_bwd": 0.09074807167053223, "global_fwd": 0.06621330976486206, "global_bwd": 0.0859871506690979, "x_quantize_rowwise": 0.027574598789215088, "g_quantize_rowwise": 0.02520531415939331, "w_quantize_rowwise": 0.04095584154129028, "w_quantize_colwise_transpose": 0.37036463618278503, "w_quantize_global": 
0.11350959539413452, "w_quantize_global_transpose": 0.12202560901641846, "cast_x": 0.03780052065849304, "cast_g": 0.00860169529914856, "cast_w": 0.06864592432975769, "time_standard": 0.3488026559352875, "time_rowwise": 0.7389597594738007, "time_global": 0.5575604736804962} +{"repeat": 64, "batch_size": 2048, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 0.22610649466514587, "standard_gw": 0.2229548990726471, "standard_gx": 0.22150203585624695, "rowwise_fwd": 0.1421608030796051, "rowwise_bwd": 0.10771304368972778, "global_fwd": 0.13930723071098328, "global_bwd": 0.10715052485466003, "x_quantize_rowwise": 0.02812594175338745, "g_quantize_rowwise": 0.04733726382255554, "w_quantize_rowwise": 0.03758445382118225, "w_quantize_colwise_transpose": 0.29515475034713745, "w_quantize_global": 0.11344626545906067, "w_quantize_global_transpose": 0.12392178177833557, "cast_x": 0.013589859008789062, "cast_g": 0.08285418152809143, "cast_w": 0.06850436329841614, "time_standard": 0.6705634295940399, "time_rowwise": 0.8810311555862427, "time_global": 0.7822439074516296} +{"repeat": 64, "batch_size": 2048, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.20173192024230957, "standard_gw": 0.2351999282836914, "standard_gx": 0.24710968136787415, "rowwise_fwd": 0.12035667896270752, "rowwise_bwd": 0.153418630361557, "global_fwd": 0.11473894119262695, "global_bwd": 0.14553219079971313, "x_quantize_rowwise": 0.04762038588523865, "g_quantize_rowwise": 0.02557411789894104, "w_quantize_rowwise": 0.04055723547935486, "w_quantize_colwise_transpose": 0.32641738653182983, "w_quantize_global": 0.1138448715209961, "w_quantize_global_transpose": 0.12255832552909851, "cast_x": 0.08405372500419617, "cast_g": 0.013835728168487549, "cast_w": 0.06961449980735779, "time_standard": 0.6840415298938751, "time_rowwise": 0.9491443634033203, "time_global": 0.8050687611103058} +{"repeat": 64, "batch_size": 4096, "dim_out": 8192, "dim_in": 1664, "wm": 
4.9231, "switch": false, "standard_fwd": 0.48126280307769775, "standard_gw": 0.46824291348457336, "standard_gx": 0.45252591371536255, "rowwise_fwd": 0.2749897539615631, "rowwise_bwd": 0.2111680805683136, "global_fwd": 0.2689175307750702, "global_bwd": 0.2104043960571289, "x_quantize_rowwise": 0.02676248550415039, "g_quantize_rowwise": 0.0842660665512085, "w_quantize_rowwise": 0.037495046854019165, "w_quantize_colwise_transpose": 0.2952851355075836, "w_quantize_global": 0.11366978287696838, "w_quantize_global_transpose": 0.12461841106414795, "cast_x": 0.0283755362033844, "cast_g": 0.1590624451637268, "cast_w": 0.06854161620140076, "time_standard": 1.4020316302776337, "time_rowwise": 1.3982094824314117, "time_global": 1.2968815863132477} +{"repeat": 64, "batch_size": 4096, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.4076175391674042, "standard_gw": 0.45526400208473206, "standard_gx": 0.4996545612812042, "rowwise_fwd": 0.238761305809021, "rowwise_bwd": 0.2913624048233032, "global_fwd": 0.2149641513824463, "global_bwd": 0.2717897295951843, "x_quantize_rowwise": 0.0845976173877716, "g_quantize_rowwise": 0.0266246497631073, "w_quantize_rowwise": 0.04038959741592407, "w_quantize_colwise_transpose": 0.33299997448921204, "w_quantize_global": 0.11374801397323608, "w_quantize_global_transpose": 0.12202560901641846, "cast_x": 0.15895813703536987, "cast_g": 0.028312206268310547, "cast_w": 0.06841868162155151, "time_standard": 1.3625361025333405, "time_rowwise": 1.4699995517730713, "time_global": 1.2890137732028961} +{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 1.02214515209198, "standard_gw": 0.9412020444869995, "standard_gx": 0.883936882019043, "rowwise_fwd": 0.5209781229496002, "rowwise_bwd": 0.41617080569267273, "global_fwd": 0.5089044570922852, "global_bwd": 0.4142932593822479, "x_quantize_rowwise": 0.03763660788536072, "g_quantize_rowwise": 0.15798211097717285, 
"w_quantize_rowwise": 0.0375211238861084, "w_quantize_colwise_transpose": 0.2973228693008423, "w_quantize_global": 0.11317431926727295, "w_quantize_global_transpose": 0.12396648526191711, "cast_x": 0.0685863196849823, "cast_g": 0.311531126499176, "cast_w": 0.0685080885887146, "time_standard": 2.8472840785980225, "time_rowwise": 2.4088136851787567, "time_global": 2.2971592843532562} +{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.8539073169231415, "standard_gw": 0.9352751076221466, "standard_gx": 0.9567439556121826, "rowwise_fwd": 0.4599541425704956, "rowwise_bwd": 0.531073659658432, "global_fwd": 0.42063742876052856, "global_bwd": 0.5125999450683594, "x_quantize_rowwise": 0.1581348478794098, "g_quantize_rowwise": 0.03755837678909302, "w_quantize_rowwise": 0.04056468605995178, "w_quantize_colwise_transpose": 0.3295913338661194, "w_quantize_global": 0.11314079165458679, "w_quantize_global_transpose": 0.12153387069702148, "cast_x": 0.3114752471446991, "cast_g": 0.06850063800811768, "cast_w": 0.06839632987976074, "time_standard": 2.7459263801574707, "time_rowwise": 2.492152154445648, "time_global": 2.2988803684711456} +{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 2.0550191402435303, "standard_gw": 1.7850138247013092, "standard_gx": 1.7571337521076202, "rowwise_fwd": 1.026798039674759, "rowwise_bwd": 0.8242167532444, "global_fwd": 1.0042376816272736, "global_bwd": 0.8189938962459564, "x_quantize_rowwise": 0.0688992440700531, "g_quantize_rowwise": 0.3054179251194, "w_quantize_rowwise": 0.03757700324058533, "w_quantize_colwise_transpose": 0.2973712980747223, "w_quantize_global": 0.11324509978294373, "w_quantize_global_transpose": 0.12398511171340942, "cast_x": 0.13050436973571777, "cast_g": 0.6165280938148499, "cast_w": 0.06848573684692383, "time_standard": 5.59716671705246, "time_rowwise": 4.345294088125229, "time_global": 
4.2197927832603455} +{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 1.79310142993927, "standard_gw": 1.7801076173782349, "standard_gx": 1.9140169024467468, "rowwise_fwd": 0.8629709482192993, "rowwise_bwd": 1.0353922843933105, "global_fwd": 0.8200556039810181, "global_bwd": 1.002725213766098, "x_quantize_rowwise": 0.30517578125, "g_quantize_rowwise": 0.06880238652229309, "w_quantize_rowwise": 0.040318816900253296, "w_quantize_colwise_transpose": 0.3413744270801544, "w_quantize_global": 0.11326000094413757, "w_quantize_global_transpose": 0.12197345495223999, "cast_x": 0.6162337958812714, "cast_g": 0.13053417205810547, "cast_w": 0.06848946213722229, "time_standard": 5.487225949764252, "time_rowwise": 4.4341422617435455, "time_global": 4.212100058794022} +{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 4.0736086666584015, "standard_gw": 3.595758229494095, "standard_gx": 3.7020929157733917, "rowwise_fwd": 2.0306408405303955, "rowwise_bwd": 1.635722815990448, "global_fwd": 1.9890740513801575, "global_bwd": 1.627359539270401, "x_quantize_rowwise": 0.13131648302078247, "g_quantize_rowwise": 0.6001107394695282, "w_quantize_rowwise": 0.03781542181968689, "w_quantize_colwise_transpose": 0.2975836396217346, "w_quantize_global": 0.11357292532920837, "w_quantize_global_transpose": 0.12416765093803406, "cast_x": 0.2544410526752472, "cast_g": 1.2265890836715698, "cast_w": 0.06866827607154846, "time_standard": 11.371459811925888, "time_rowwise": 8.32894816994667, "time_global": 8.181359618902206} +{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 3.525231033563614, "standard_gw": 3.489706665277481, "standard_gx": 3.9937011897563934, "rowwise_fwd": 1.6627348959445953, "rowwise_bwd": 2.0311400294303894, "global_fwd": 1.6270726919174194, "global_bwd": 1.988884061574936, 
"x_quantize_rowwise": 0.5999915301799774, "g_quantize_rowwise": 0.1310594379901886, "w_quantize_rowwise": 0.04043802618980408, "w_quantize_colwise_transpose": 0.32950565218925476, "w_quantize_global": 0.11298432946205139, "w_quantize_global_transpose": 0.12201443314552307, "cast_x": 1.2257546186447144, "cast_g": 0.25444477796554565, "cast_w": 0.06848573684692383, "time_standard": 11.008638888597488, "time_rowwise": 8.28457623720169, "time_global": 8.071713149547577} +{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 8.123598992824554, "standard_gw": 8.085217326879501, "standard_gx": 7.293816655874252, "rowwise_fwd": 4.07782569527626, "rowwise_bwd": 3.196723759174347, "global_fwd": 4.001103341579437, "global_bwd": 3.1843744218349457, "x_quantize_rowwise": 0.2560615539550781, "g_quantize_rowwise": 1.1893659830093384, "w_quantize_rowwise": 0.037297606468200684, "w_quantize_colwise_transpose": 0.29668211936950684, "w_quantize_global": 0.11358782649040222, "w_quantize_global_transpose": 0.12476742267608643, "cast_x": 0.5020052194595337, "cast_g": 2.4454034864902496, "cast_w": 0.0684782862663269, "time_standard": 23.502632975578308, "time_rowwise": 17.139174044132233, "time_global": 16.95447787642479} +{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 6.932958960533142, "standard_gw": 7.0609524846076965, "standard_gx": 7.460080087184906, "rowwise_fwd": 3.1809918582439423, "rowwise_bwd": 4.078391939401627, "global_fwd": 3.185112029314041, "global_bwd": 3.99089977145195, "x_quantize_rowwise": 1.1891834437847137, "g_quantize_rowwise": 0.25588274002075195, "w_quantize_rowwise": 0.0406019389629364, "w_quantize_colwise_transpose": 0.3389529883861542, "w_quantize_global": 0.11313334107398987, "w_quantize_global_transpose": 0.12241676449775696, "cast_x": 2.4446770548820496, "cast_g": 0.5022138357162476, "cast_w": 0.06857141852378845, "time_standard": 
21.453991532325745, "time_rowwise": 16.14495739340782, "time_global": 15.9175805747509} +{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 16.38999581336975, "standard_gw": 15.075922012329102, "standard_gx": 14.479495584964752, "rowwise_fwd": 8.128684014081955, "rowwise_bwd": 6.41091912984848, "global_fwd": 7.977847009897232, "global_bwd": 6.362702697515488, "x_quantize_rowwise": 0.5057230591773987, "g_quantize_rowwise": 2.3681968450546265, "w_quantize_rowwise": 0.037435442209243774, "w_quantize_colwise_transpose": 0.29555708169937134, "w_quantize_global": 0.11360272765159607, "w_quantize_global_transpose": 0.12426823377609253, "cast_x": 0.997692346572876, "cast_g": 4.8848651349544525, "cast_w": 0.0685565173625946, "time_standard": 45.945413410663605, "time_rowwise": 32.82243758440018, "time_global": 32.528262585401535} +{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 14.838922768831253, "standard_gw": 15.112213790416718, "standard_gx": 14.869242906570435, "rowwise_fwd": 6.402213126420975, "rowwise_bwd": 8.132629096508026, "global_fwd": 6.36359304189682, "global_bwd": 7.9823993146419525, "x_quantize_rowwise": 2.367999404668808, "g_quantize_rowwise": 0.5056969821453094, "w_quantize_rowwise": 0.04053488373756409, "w_quantize_colwise_transpose": 0.3559887409210205, "w_quantize_global": 0.1136288046836853, "w_quantize_global_transpose": 0.125102698802948, "cast_x": 4.880473017692566, "cast_g": 0.9965412318706512, "cast_w": 0.06855279207229614, "time_standard": 44.820379465818405, "time_rowwise": 32.91727602481842, "time_global": 32.57063403725624} +{"repeat": 64, "batch_size": 1024, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.15426427125930786, "standard_gw": 0.14531239867210388, "standard_gx": 0.1703128218650818, "rowwise_fwd": 0.09618699550628662, "rowwise_bwd": 0.10633841156959534, "global_fwd": 
0.09483471512794495, "global_bwd": 0.10636076331138611, "x_quantize_rowwise": 0.02434849739074707, "g_quantize_rowwise": 0.026009976863861084, "w_quantize_rowwise": 0.04366040229797363, "w_quantize_colwise_transpose": 0.34148991107940674, "w_quantize_global": 0.13587623834609985, "w_quantize_global_transpose": 0.14698877930641174, "cast_x": 0.009745359420776367, "cast_g": 0.03773719072341919, "cast_w": 0.08277222514152527, "time_standard": 0.46988949179649353, "time_rowwise": 0.7833465933799744, "time_global": 0.6797313690185547} +{"repeat": 64, "batch_size": 1024, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.16738846898078918, "standard_gw": 0.14199689030647278, "standard_gx": 0.15476346015930176, "rowwise_fwd": 0.11660531163215637, "rowwise_bwd": 0.1050308346748352, "global_fwd": 0.11050701141357422, "global_bwd": 0.09868666529655457, "x_quantize_rowwise": 0.02781301736831665, "g_quantize_rowwise": 0.024966895580291748, "w_quantize_rowwise": 0.047437846660614014, "w_quantize_colwise_transpose": 0.5995631217956543, "w_quantize_global": 0.1362822949886322, "w_quantize_global_transpose": 0.14807283878326416, "cast_x": 0.0377558171749115, "cast_g": 0.00973045825958252, "cast_w": 0.0828281044960022, "time_standard": 0.4641488194465637, "time_rowwise": 1.063413918018341, "time_global": 0.6883256137371063} +{"repeat": 64, "batch_size": 2048, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.2727396786212921, "standard_gw": 0.2711080014705658, "standard_gx": 0.3120154142379761, "rowwise_fwd": 0.16424059867858887, "rowwise_bwd": 0.17686933279037476, "global_fwd": 0.161685049533844, "global_bwd": 0.17517060041427612, "x_quantize_rowwise": 0.025484710931777954, "g_quantize_rowwise": 0.047635287046432495, "w_quantize_rowwise": 0.04380941390991211, "w_quantize_colwise_transpose": 0.3401711583137512, "w_quantize_global": 0.13605505228042603, "w_quantize_global_transpose": 0.14705583453178406, "cast_x": 
0.01584365963935852, "cast_g": 0.08274242281913757, "cast_w": 0.08281320333480835, "time_standard": 0.855863094329834, "time_rowwise": 1.0693185031414032, "time_global": 0.9641945362091064} +{"repeat": 64, "batch_size": 2048, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.28916075825691223, "standard_gw": 0.29472261667251587, "standard_gx": 0.30096620321273804, "rowwise_fwd": 0.19618868827819824, "rowwise_bwd": 0.17556175589561462, "global_fwd": 0.18328800797462463, "global_bwd": 0.16647577285766602, "x_quantize_rowwise": 0.047441571950912476, "g_quantize_rowwise": 0.026609748601913452, "w_quantize_rowwise": 0.04766508936882019, "w_quantize_colwise_transpose": 0.6060972809791565, "w_quantize_global": 0.1363418996334076, "w_quantize_global_transpose": 0.14806538820266724, "cast_x": 0.08295103907585144, "cast_g": 0.015836209058761597, "cast_w": 0.08285045623779297, "time_standard": 0.8848495781421661, "time_rowwise": 1.3942867517471313, "time_global": 1.0029450058937073} +{"repeat": 64, "batch_size": 4096, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.6430819630622864, "standard_gw": 0.5622953176498413, "standard_gx": 0.5780421197414398, "rowwise_fwd": 0.318676233291626, "rowwise_bwd": 0.29438361525535583, "global_fwd": 0.31290948390960693, "global_bwd": 0.290747731924057, "x_quantize_rowwise": 0.027455389499664307, "g_quantize_rowwise": 0.08405372500419617, "w_quantize_rowwise": 0.04369765520095825, "w_quantize_colwise_transpose": 0.34110620617866516, "w_quantize_global": 0.1360774040222168, "w_quantize_global_transpose": 0.14697015285491943, "cast_x": 0.037614256143569946, "cast_g": 0.15922263264656067, "cast_w": 0.08288025856018066, "time_standard": 1.7834194004535675, "time_rowwise": 1.671668142080307, "time_global": 1.560509204864502} +{"repeat": 64, "batch_size": 4096, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.551275908946991, "standard_gw": 0.591665506362915, 
"standard_gx": 0.6067268550395966, "rowwise_fwd": 0.33493712544441223, "rowwise_bwd": 0.32918527722358704, "global_fwd": 0.29528141021728516, "global_bwd": 0.31659379601478577, "x_quantize_rowwise": 0.08441135287284851, "g_quantize_rowwise": 0.025656074285507202, "w_quantize_rowwise": 0.04745647311210632, "w_quantize_colwise_transpose": 0.5993843078613281, "w_quantize_global": 0.1359879970550537, "w_quantize_global_transpose": 0.14815106987953186, "cast_x": 0.15932321548461914, "cast_g": 0.037439167499542236, "cast_w": 0.08288398385047913, "time_standard": 1.7496682703495026, "time_rowwise": 2.0126961171627045, "time_global": 1.5977472066879272} +{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2295916676521301, "standard_gw": 1.116037368774414, "standard_gx": 1.1164769530296326, "rowwise_fwd": 0.603698194026947, "rowwise_bwd": 0.5168020725250244, "global_fwd": 0.5922466516494751, "global_bwd": 0.5151033401489258, "x_quantize_rowwise": 0.0437907874584198, "g_quantize_rowwise": 0.157918781042099, "w_quantize_rowwise": 0.044032931327819824, "w_quantize_colwise_transpose": 0.34073740243911743, "w_quantize_global": 0.13559311628341675, "w_quantize_global_transpose": 0.14679506421089172, "cast_x": 0.08263811469078064, "cast_g": 0.3115162253379822, "cast_w": 0.08287280797958374, "time_standard": 3.4621059894561768, "time_rowwise": 2.8230175375938416, "time_global": 2.707485109567642} +{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.090865582227707, "standard_gw": 1.1468492448329926, "standard_gx": 1.1166594922542572, "rowwise_fwd": 0.5559474229812622, "rowwise_bwd": 0.6105974316596985, "global_fwd": 0.5200020968914032, "global_bwd": 0.592011958360672, "x_quantize_rowwise": 0.15802308917045593, "g_quantize_rowwise": 0.04357844591140747, "w_quantize_rowwise": 0.04709511995315552, "w_quantize_colwise_transpose": 0.5969703197479248, 
"w_quantize_global": 0.13620033860206604, "w_quantize_global_transpose": 0.148136168718338, "cast_x": 0.31115859746932983, "cast_g": 0.08263811469078064, "cast_w": 0.08268281817436218, "time_standard": 3.3543743193149567, "time_rowwise": 3.159061074256897, "time_global": 2.744801342487335} +{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4665743112564087, "standard_gw": 2.1993443369865417, "standard_gx": 2.1993033587932587, "rowwise_fwd": 1.192428171634674, "rowwise_bwd": 1.023314893245697, "global_fwd": 1.1711902916431427, "global_bwd": 1.0202191770076752, "x_quantize_rowwise": 0.08077174425125122, "g_quantize_rowwise": 0.30520185828208923, "w_quantize_rowwise": 0.043783336877822876, "w_quantize_colwise_transpose": 0.339999794960022, "w_quantize_global": 0.13628602027893066, "w_quantize_global_transpose": 0.14696642756462097, "cast_x": 0.15902891755104065, "cast_g": 0.6164535880088806, "cast_w": 0.08285418152809143, "time_standard": 6.865222007036209, "time_rowwise": 5.184844136238098, "time_global": 5.059979856014252} +{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1861791610717773, "standard_gw": 2.157818526029587, "standard_gx": 2.321537584066391, "rowwise_fwd": 1.0536126792430878, "rowwise_bwd": 1.1971630156040192, "global_fwd": 1.02127343416214, "global_bwd": 1.1707991361618042, "x_quantize_rowwise": 0.30522048473358154, "g_quantize_rowwise": 0.08065253496170044, "w_quantize_rowwise": 0.04741176962852478, "w_quantize_colwise_transpose": 0.5979575216770172, "w_quantize_global": 0.1362040638923645, "w_quantize_global_transpose": 0.14854222536087036, "cast_x": 0.6162486970424652, "cast_g": 0.1591891050338745, "cast_w": 0.08288398385047913, "time_standard": 6.665535271167755, "time_rowwise": 5.439836531877518, "time_global": 5.020510405302048} +{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": 
false, "standard_fwd": 4.891645163297653, "standard_gw": 4.233300685882568, "standard_gx": 4.2071714997291565, "rowwise_fwd": 2.3616664111614227, "rowwise_bwd": 1.9419342279434204, "global_fwd": 2.3244209587574005, "global_bwd": 1.9598640501499176, "x_quantize_rowwise": 0.15483051538467407, "g_quantize_rowwise": 0.6008371710777283, "w_quantize_rowwise": 0.043839216232299805, "w_quantize_colwise_transpose": 0.3400743007659912, "w_quantize_global": 0.1362822949886322, "w_quantize_global_transpose": 0.14691054821014404, "cast_x": 0.31141936779022217, "cast_g": 1.2254081666469574, "cast_w": 0.08280202746391296, "time_standard": 13.332117348909378, "time_rowwise": 9.676482528448105, "time_global": 9.556446224451065} +{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.267625510692596, "standard_gw": 4.237007349729538, "standard_gx": 4.666488617658615, "rowwise_fwd": 1.9670464098453522, "rowwise_bwd": 2.362079918384552, "global_fwd": 1.9469596445560455, "global_bwd": 2.32585147023201, "x_quantize_rowwise": 0.6000921130180359, "g_quantize_rowwise": 0.15481188893318176, "w_quantize_rowwise": 0.04725530743598938, "w_quantize_colwise_transpose": 0.5976222455501556, "w_quantize_global": 0.13619661331176758, "w_quantize_global_transpose": 0.14815852046012878, "cast_x": 1.2261345982551575, "cast_g": 0.3117173910140991, "cast_w": 0.08279457688331604, "time_standard": 13.17112147808075, "time_rowwise": 9.965915232896805, "time_global": 9.549077600240707} +{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.787477552890778, "standard_gw": 8.533861488103867, "standard_gx": 8.979786187410355, "rowwise_fwd": 4.741787910461426, "rowwise_bwd": 3.871854394674301, "global_fwd": 4.674319177865982, "global_bwd": 3.9110779762268066, "x_quantize_rowwise": 0.3025829792022705, "g_quantize_rowwise": 1.1898204684257507, "w_quantize_rowwise": 0.043705105781555176, 
"w_quantize_colwise_transpose": 0.33997371792793274, "w_quantize_global": 0.13592839241027832, "w_quantize_global_transpose": 0.14724954962730408, "cast_x": 0.6160177290439606, "cast_g": 2.4440810084342957, "cast_w": 0.08280575275421143, "time_standard": 27.301125228405, "time_rowwise": 19.023586064577103, "time_global": 18.89484003186226} +{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.461769670248032, "standard_gw": 8.428700268268585, "standard_gx": 9.447630494832993, "rowwise_fwd": 3.881257027387619, "rowwise_bwd": 4.7471001744270325, "global_fwd": 3.9101652801036835, "global_bwd": 4.662122577428818, "x_quantize_rowwise": 1.1892355978488922, "g_quantize_rowwise": 0.3024376928806305, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.5982778966426849, "w_quantize_global": 0.13624131679534912, "w_quantize_global_transpose": 0.1484602689743042, "cast_x": 2.4463236331939697, "cast_g": 0.6163865327835083, "cast_w": 0.08278340101242065, "time_standard": 26.33810043334961, "time_rowwise": 19.194088876247406, "time_global": 18.777363002300262} +{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 19.699689000844955, "standard_gw": 16.89574122428894, "standard_gx": 17.907552421092987, "rowwise_fwd": 9.453803300857544, "rowwise_bwd": 7.8153833746910095, "global_fwd": 9.313825517892838, "global_bwd": 7.8215524554252625, "x_quantize_rowwise": 0.5986690521240234, "g_quantize_rowwise": 2.368006855249405, "w_quantize_rowwise": 0.043682754039764404, "w_quantize_colwise_transpose": 0.3406330943107605, "w_quantize_global": 0.13626739382743835, "w_quantize_global_transpose": 0.14715641736984253, "cast_x": 1.2262165546417236, "cast_g": 4.8834048211574554, "cast_w": 0.08272379636764526, "time_standard": 54.50298264622688, "time_rowwise": 37.51591965556145, "time_global": 37.28121891617775} +{"repeat": 64, "batch_size": 131072, 
"dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.66700127720833, "standard_gw": 18.56840029358864, "standard_gx": 18.049821257591248, "rowwise_fwd": 7.742393761873245, "rowwise_bwd": 9.479016065597534, "global_fwd": 7.806576788425446, "global_bwd": 9.328477084636688, "x_quantize_rowwise": 2.368297427892685, "g_quantize_rowwise": 0.5978643894195557, "w_quantize_rowwise": 0.047303736209869385, "w_quantize_colwise_transpose": 0.5982741713523865, "w_quantize_global": 0.13678893446922302, "w_quantize_global_transpose": 0.1488029956817627, "cast_x": 4.880513995885849, "cast_g": 1.2248307466506958, "cast_w": 0.08270144462585449, "time_standard": 55.285222828388214, "time_rowwise": 39.401549845933914, "time_global": 38.955207914114} +{"repeat": 64, "batch_size": 1024, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 0.529509037733078, "standard_gw": 0.5781911313533783, "standard_gx": 0.6095841526985168, "rowwise_fwd": 0.2811029553413391, "rowwise_bwd": 0.3345906734466553, "global_fwd": 0.27928128838539124, "global_bwd": 0.33126771450042725, "x_quantize_rowwise": 0.025760382413864136, "g_quantize_rowwise": 0.06494298577308655, "w_quantize_rowwise": 0.15570968389511108, "w_quantize_colwise_transpose": 1.6086548566818237, "w_quantize_global": 0.481434166431427, "w_quantize_global_transpose": 0.505443662405014, "cast_x": 0.01582130789756775, "cast_g": 0.08295103907585144, "cast_w": 0.311531126499176, "time_standard": 1.7172843217849731, "time_rowwise": 3.048952668905258, "time_global": 2.2663213312625885} +{"repeat": 64, "batch_size": 1024, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 0.5729459226131439, "standard_gw": 0.5789846181869507, "standard_gx": 0.5775243043899536, "rowwise_fwd": 0.36711618304252625, "rowwise_bwd": 0.2913735806941986, "global_fwd": 0.33703818917274475, "global_bwd": 0.2821236848831177, "x_quantize_rowwise": 0.064849853515625, "g_quantize_rowwise": 0.025060027837753296, 
"w_quantize_rowwise": 0.22537633776664734, "w_quantize_colwise_transpose": 3.6401040852069855, "w_quantize_global": 0.4818551242351532, "w_quantize_global_transpose": 0.5101114511489868, "cast_x": 0.08286535739898682, "cast_g": 0.015828758478164673, "cast_w": 0.3114677965641022, "time_standard": 1.7294548451900482, "time_rowwise": 5.192864686250687, "time_global": 2.2800229489803314} +{"repeat": 64, "batch_size": 2048, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 1.1735819280147552, "standard_gw": 1.121576875448227, "standard_gx": 1.1242404580116272, "rowwise_fwd": 0.5535706877708435, "rowwise_bwd": 0.5567893385887146, "global_fwd": 0.5486570298671722, "global_bwd": 0.551365315914154, "x_quantize_rowwise": 0.02710893750190735, "g_quantize_rowwise": 0.11784210801124573, "w_quantize_rowwise": 0.15565752983093262, "w_quantize_colwise_transpose": 1.607745885848999, "w_quantize_global": 0.4824437201023102, "w_quantize_global_transpose": 0.5060508847236633, "cast_x": 0.03808736801147461, "cast_g": 0.15912577509880066, "cast_w": 0.31150132417678833, "time_standard": 3.4193992614746094, "time_rowwise": 4.14029136300087, "time_global": 3.35504487156868} +{"repeat": 64, "batch_size": 2048, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 1.1169910430908203, "standard_gw": 1.1065900325775146, "standard_gx": 1.1815577745437622, "rowwise_fwd": 0.5917288362979889, "rowwise_bwd": 0.5614385008811951, "global_fwd": 0.5646944046020508, "global_bwd": 0.5500949919223785, "x_quantize_rowwise": 0.118207186460495, "g_quantize_rowwise": 0.025041401386260986, "w_quantize_rowwise": 0.22566691040992737, "w_quantize_colwise_transpose": 3.635551780462265, "w_quantize_global": 0.4815608263015747, "w_quantize_global_transpose": 0.509701669216156, "cast_x": 0.15912950038909912, "cast_g": 0.03797560930252075, "cast_w": 0.3114044666290283, "time_standard": 3.405138850212097, "time_rowwise": 6.264224648475647, "time_global": 
3.3558905124664307} +{"repeat": 64, "batch_size": 4096, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 2.3259930312633514, "standard_gw": 2.1472275257110596, "standard_gx": 2.213582396507263, "rowwise_fwd": 1.0509602725505829, "rowwise_bwd": 0.9888559579849243, "global_fwd": 1.0398179292678833, "global_bwd": 0.9887740015983582, "x_quantize_rowwise": 0.04647299647331238, "g_quantize_rowwise": 0.22570788860321045, "w_quantize_rowwise": 0.1554824411869049, "w_quantize_colwise_transpose": 1.610085368156433, "w_quantize_global": 0.48134103417396545, "w_quantize_global_transpose": 0.5054809153079987, "cast_x": 0.08297711610794067, "cast_g": 0.3115646541118622, "cast_w": 0.31159818172454834, "time_standard": 6.686802953481674, "time_rowwise": 6.224792450666428, "time_global": 5.434822291135788} +{"repeat": 64, "batch_size": 4096, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 2.19760462641716, "standard_gw": 2.2860951721668243, "standard_gx": 2.290956676006317, "rowwise_fwd": 1.0311491787433624, "rowwise_bwd": 1.0555200278759003, "global_fwd": 0.9858310222625732, "global_bwd": 1.0394863784313202, "x_quantize_rowwise": 0.22591277956962585, "g_quantize_rowwise": 0.046234577894210815, "w_quantize_rowwise": 0.22603943943977356, "w_quantize_colwise_transpose": 3.628809005022049, "w_quantize_global": 0.4819147288799286, "w_quantize_global_transpose": 0.5104243755340576, "cast_x": 0.3114528954029083, "cast_g": 0.08296966552734375, "cast_w": 0.3116317093372345, "time_standard": 6.7746564745903015, "time_rowwise": 8.499760180711746, "time_global": 5.575899034738541} +{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.633370786905289, "standard_gw": 4.397690296173096, "standard_gx": 4.286538809537888, "rowwise_fwd": 2.089906483888626, "rowwise_bwd": 1.9657425582408905, "global_fwd": 2.0679645240306854, "global_bwd": 1.9629858434200287, "x_quantize_rowwise": 
0.08271634578704834, "g_quantize_rowwise": 0.43905526399612427, "w_quantize_rowwise": 0.1551508903503418, "w_quantize_colwise_transpose": 1.6106180846691132, "w_quantize_global": 0.48185884952545166, "w_quantize_global_transpose": 0.506274402141571, "cast_x": 0.15918537974357605, "cast_g": 0.6163418292999268, "cast_w": 0.311531126499176, "time_standard": 13.317599892616272, "time_rowwise": 10.74087992310524, "time_global": 9.938545525074005} +{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.424266517162323, "standard_gw": 4.391487687826157, "standard_gx": 4.61186096072197, "rowwise_fwd": 1.9874684512615204, "rowwise_bwd": 2.093140035867691, "global_fwd": 1.9647255539894104, "global_bwd": 2.06940621137619, "x_quantize_rowwise": 0.43999403715133667, "g_quantize_rowwise": 0.08271634578704834, "w_quantize_rowwise": 0.22581592202186584, "w_quantize_colwise_transpose": 3.631964325904846, "w_quantize_global": 0.4821456968784332, "w_quantize_global_transpose": 0.5102343857288361, "cast_x": 0.6164386868476868, "cast_g": 0.1591108739376068, "cast_w": 0.31154975295066833, "time_standard": 13.42761516571045, "time_rowwise": 12.852586805820465, "time_global": 9.940709918737411} +{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.229827672243118, "standard_gw": 8.319318294525146, "standard_gx": 8.652344346046448, "rowwise_fwd": 4.163607954978943, "rowwise_bwd": 3.778301179409027, "global_fwd": 4.121184349060059, "global_bwd": 3.7708766758441925, "x_quantize_rowwise": 0.1553669571876526, "g_quantize_rowwise": 0.8715838193893433, "w_quantize_rowwise": 0.15540048480033875, "w_quantize_colwise_transpose": 1.6092769801616669, "w_quantize_global": 0.4813969135284424, "w_quantize_global_transpose": 0.5070343613624573, "cast_x": 0.31150132417678833, "cast_g": 1.2259706854820251, "cast_w": 0.311482697725296, "time_standard": 26.201490312814713, "time_rowwise": 
19.052855670452118, "time_global": 18.226761370897293} +{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 8.577890694141388, "standard_gw": 9.073298424482346, "standard_gx": 9.210295975208282, "rowwise_fwd": 3.7784352898597717, "rowwise_bwd": 4.165928810834885, "global_fwd": 3.7702471017837524, "global_bwd": 4.121150821447372, "x_quantize_rowwise": 0.868629664182663, "g_quantize_rowwise": 0.1554340124130249, "w_quantize_rowwise": 0.22614002227783203, "w_quantize_colwise_transpose": 3.6367811262607574, "w_quantize_global": 0.4828609526157379, "w_quantize_global_transpose": 0.510137528181076, "cast_x": 1.2258104979991913, "cast_g": 0.31299516558647156, "cast_w": 0.3114677965641022, "time_standard": 26.861485093832016, "time_rowwise": 21.90464735031128, "time_global": 18.981758505105972} +{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.52763444185257, "standard_gw": 17.835520207881927, "standard_gx": 17.375655472278595, "rowwise_fwd": 8.35346058011055, "rowwise_bwd": 7.584303617477417, "global_fwd": 8.300606161355972, "global_bwd": 7.550913840532303, "x_quantize_rowwise": 0.3016740083694458, "g_quantize_rowwise": 1.7321519553661346, "w_quantize_rowwise": 0.15538185834884644, "w_quantize_colwise_transpose": 1.6110800206661224, "w_quantize_global": 0.4815198481082916, "w_quantize_global_transpose": 0.5066357553005219, "cast_x": 0.6163753569126129, "cast_g": 2.4452805519104004, "cast_w": 0.31156837940216064, "time_standard": 53.73881012201309, "time_rowwise": 37.573572248220444, "time_global": 36.7090217769146} +{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 18.073823302984238, "standard_gw": 16.71283319592476, "standard_gx": 18.46104860305786, "rowwise_fwd": 7.542364299297333, "rowwise_bwd": 8.374195545911789, "global_fwd": 7.5644850730896, "global_bwd": 8.26016440987587, 
"x_quantize_rowwise": 1.7326027154922485, "g_quantize_rowwise": 0.30233338475227356, "w_quantize_rowwise": 0.2259574830532074, "w_quantize_colwise_transpose": 3.634512424468994, "w_quantize_global": 0.48204511404037476, "w_quantize_global_transpose": 0.5093887448310852, "cast_x": 2.445656806230545, "cast_g": 0.6163381040096283, "cast_w": 0.31144917011260986, "time_standard": 53.24770510196686, "time_rowwise": 38.524799048900604, "time_global": 35.56385263800621} +{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.123402416706085, "standard_gw": 32.68447890877724, "standard_gx": 34.13737937808037, "rowwise_fwd": 16.65867120027542, "rowwise_bwd": 15.004873275756836, "global_fwd": 16.536589711904526, "global_bwd": 14.949381351470947, "x_quantize_rowwise": 0.5952902138233185, "g_quantize_rowwise": 3.4581348299980164, "w_quantize_rowwise": 0.15559792518615723, "w_quantize_colwise_transpose": 1.6055963933467865, "w_quantize_global": 0.48203766345977783, "w_quantize_global_transpose": 0.5048215389251709, "cast_x": 1.2256354093551636, "cast_g": 4.875503480434418, "cast_w": 0.3110244870185852, "time_standard": 102.94526070356369, "time_rowwise": 70.16264274716377, "time_global": 69.210734218359} +{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.0223146378994, "standard_gw": 32.84081444144249, "standard_gx": 35.984884947538376, "rowwise_fwd": 15.018381178379059, "rowwise_bwd": 16.69919490814209, "global_fwd": 14.942582696676254, "global_bwd": 16.529250890016556, "x_quantize_rowwise": 3.442291170358658, "g_quantize_rowwise": 0.5951747298240662, "w_quantize_rowwise": 0.22576376795768738, "w_quantize_colwise_transpose": 3.621157258749008, "w_quantize_global": 0.48135966062545776, "w_quantize_global_transpose": 0.5095489323139191, "cast_x": 4.875205457210541, "cast_g": 1.2237727642059326, "cast_w": 0.3110431134700775, "time_standard": 
103.84801402688026, "time_rowwise": 72.44277745485306, "time_global": 69.3410225212574} +{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 72.33698666095734, "standard_gw": 71.31465151906013, "standard_gx": 69.32922825217247, "rowwise_fwd": 33.37707370519638, "rowwise_bwd": 30.1642008125782, "global_fwd": 33.002063632011414, "global_bwd": 30.003495514392853, "x_quantize_rowwise": 1.1819563806056976, "g_quantize_rowwise": 6.896954029798508, "w_quantize_rowwise": 0.15557929873466492, "w_quantize_colwise_transpose": 1.6083605587482452, "w_quantize_global": 0.48125162720680237, "w_quantize_global_transpose": 0.5055665969848633, "cast_x": 2.442535012960434, "cast_g": 9.750165045261383, "cast_w": 0.31094998121261597, "time_standard": 212.98086643218994, "time_rowwise": 144.69877630472183, "time_global": 143.38593930006027} +{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 70.24158909916878, "standard_gw": 72.03734293580055, "standard_gx": 72.01339676976204, "rowwise_fwd": 30.072908848524094, "rowwise_bwd": 33.376410603523254, "global_fwd": 29.965493828058243, "global_bwd": 33.01112726330757, "x_quantize_rowwise": 6.894122809171677, "g_quantize_rowwise": 1.1817142367362976, "w_quantize_rowwise": 0.22567808628082275, "w_quantize_colwise_transpose": 3.616899251937866, "w_quantize_global": 0.4819147288799286, "w_quantize_global_transpose": 0.5107112228870392, "cast_x": 9.750377386808395, "cast_g": 2.4411343038082123, "cast_w": 0.31099095940589905, "time_standard": 214.29232880473137, "time_rowwise": 147.40507677197456, "time_global": 144.0824270248413} +{"repeat": 64, "batch_size": 65536, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 138.23134452104568, "standard_gw": 131.48364424705505, "standard_gx": 141.09868183732033, "rowwise_fwd": 65.38830325007439, "rowwise_bwd": 58.39048698544502, "global_fwd": 65.2194656431675, 
"global_bwd": 58.58004465699196, "x_quantize_rowwise": 1.1899955570697784, "g_quantize_rowwise": 6.623774766921997, "w_quantize_rowwise": 0.5935952067375183, "w_quantize_colwise_transpose": 24.08137544989586, "w_quantize_global": 1.740824431180954, "w_quantize_global_transpose": 1.8664970993995667, "cast_x": 2.413548529148102, "cast_g": 9.63655486702919, "cast_w": 1.1956281960010529, "time_standard": 410.81367060542107, "time_rowwise": 287.7511754631996, "time_global": 266.7042464017868} +{"repeat": 64, "batch_size": 65536, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 141.08363911509514, "standard_gw": 133.26667994260788, "standard_gx": 136.0350362956524, "rowwise_fwd": 58.49892646074295, "rowwise_bwd": 65.34496694803238, "global_fwd": 58.73573571443558, "global_bwd": 65.30505418777466, "x_quantize_rowwise": 6.648071110248566, "g_quantize_rowwise": 1.1903978884220123, "w_quantize_rowwise": 0.8329600095748901, "w_quantize_colwise_transpose": 15.297897160053253, "w_quantize_global": 1.7403066158294678, "w_quantize_global_transpose": 1.8791332840919495, "cast_x": 9.636614471673965, "cast_g": 2.4122819304466248, "cast_w": 1.1954344809055328, "time_standard": 410.3853553533554, "time_rowwise": 281.07989951968193, "time_global": 268.7653787434101} +{"repeat": 64, "batch_size": 1024, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 2.535879611968994, "standard_gw": 2.249978482723236, "standard_gx": 2.2262558341026306, "rowwise_fwd": 1.085665076971054, "rowwise_bwd": 1.069542020559311, "global_fwd": 1.0830685496330261, "global_bwd": 1.0597631335258484, "x_quantize_rowwise": 0.02650916576385498, "g_quantize_rowwise": 0.1200847327709198, "w_quantize_rowwise": 0.5937665700912476, "w_quantize_colwise_transpose": 23.926906287670135, "w_quantize_global": 1.7397291958332062, "w_quantize_global_transpose": 1.8652454018592834, "cast_x": 0.03688782453536987, "cast_g": 0.15725940465927124, "cast_w": 1.1969134211540222, 
"time_standard": 7.012113928794861, "time_rowwise": 29.07245233654976, "time_global": 8.144378662109375} +{"repeat": 64, "batch_size": 1024, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 2.245493233203888, "standard_gw": 2.2966675460338593, "standard_gx": 2.216015011072159, "rowwise_fwd": 1.1000856757164001, "rowwise_bwd": 1.0902360081672668, "global_fwd": 1.0597333312034607, "global_bwd": 1.0812543332576752, "x_quantize_rowwise": 0.11992454528808594, "g_quantize_rowwise": 0.026784837245941162, "w_quantize_rowwise": 0.8310377597808838, "w_quantize_colwise_transpose": 15.30550792813301, "w_quantize_global": 1.7401352524757385, "w_quantize_global_transpose": 1.8841177225112915, "cast_x": 0.1573599874973297, "cast_g": 0.03676116466522217, "cast_w": 1.195952296257019, "time_standard": 6.758175790309906, "time_rowwise": 20.770244300365448, "time_global": 8.208617568016052} +{"repeat": 64, "batch_size": 2048, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 4.197858273983002, "standard_gw": 4.288379102945328, "standard_gx": 4.155721515417099, "rowwise_fwd": 2.0567886531352997, "rowwise_bwd": 1.9073635339736938, "global_fwd": 2.0506344735622406, "global_bwd": 1.9086338579654694, "x_quantize_rowwise": 0.04758685827255249, "g_quantize_rowwise": 0.22284314036369324, "w_quantize_rowwise": 0.5935467779636383, "w_quantize_colwise_transpose": 23.935042321681976, "w_quantize_global": 1.7397813498973846, "w_quantize_global_transpose": 1.8662959337234497, "cast_x": 0.08194148540496826, "cast_g": 0.3077872097492218, "cast_w": 1.1968687176704407, "time_standard": 12.641958892345428, "time_rowwise": 33.05155038833618, "time_global": 12.124154716730118} +{"repeat": 64, "batch_size": 2048, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 4.126541316509247, "standard_gw": 4.309836775064468, "standard_gx": 4.117351025342941, "rowwise_fwd": 1.9266381859779358, "rowwise_bwd": 2.0577237010002136, "global_fwd": 
1.908630132675171, "global_bwd": 2.0505934953689575, "x_quantize_rowwise": 0.22304058074951172, "g_quantize_rowwise": 0.04766136407852173, "w_quantize_rowwise": 0.8306317031383514, "w_quantize_colwise_transpose": 15.309855341911316, "w_quantize_global": 1.7415396869182587, "w_quantize_global_transpose": 1.8827766180038452, "cast_x": 0.30782073736190796, "cast_g": 0.08186325430870056, "cast_w": 1.1955127120018005, "time_standard": 12.553729116916656, "time_rowwise": 24.70538765192032, "time_global": 12.164078652858734} +{"repeat": 64, "batch_size": 4096, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 8.298952132463455, "standard_gw": 8.345257490873337, "standard_gx": 8.647706359624863, "rowwise_fwd": 4.106882959604263, "rowwise_bwd": 3.8046911358833313, "global_fwd": 4.09451499581337, "global_bwd": 3.8078874349594116, "x_quantize_rowwise": 0.08447840809822083, "g_quantize_rowwise": 0.4291348159313202, "w_quantize_rowwise": 0.5934201180934906, "w_quantize_colwise_transpose": 23.843105882406235, "w_quantize_global": 1.7399191856384277, "w_quantize_global_transpose": 1.8653236329555511, "cast_x": 0.1577921211719513, "cast_g": 0.6089024245738983, "cast_w": 1.1952444911003113, "time_standard": 25.291915982961655, "time_rowwise": 41.2069708108902, "time_global": 20.366515964269638} +{"repeat": 64, "batch_size": 4096, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 8.323360234498978, "standard_gw": 8.433796465396881, "standard_gx": 8.236430585384369, "rowwise_fwd": 3.8114115595817566, "rowwise_bwd": 4.106346517801285, "global_fwd": 3.8080140948295593, "global_bwd": 4.094675183296204, "x_quantize_rowwise": 0.4288516938686371, "g_quantize_rowwise": 0.08437782526016235, "w_quantize_rowwise": 0.8310228586196899, "w_quantize_colwise_transpose": 15.306610614061356, "w_quantize_global": 1.741155982017517, "w_quantize_global_transpose": 1.8809586763381958, "cast_x": 0.6091706454753876, "cast_g": 0.157233327627182, "cast_w": 
1.1953115463256836, "time_standard": 24.993587285280228, "time_rowwise": 33.00241753458977, "time_global": 20.471829921007156} +{"repeat": 64, "batch_size": 8192, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 16.656354069709778, "standard_gw": 17.066240310668945, "standard_gx": 17.252348363399506, "rowwise_fwd": 8.220307528972626, "rowwise_bwd": 7.2372183203697205, "global_fwd": 8.2036592066288, "global_bwd": 7.236208766698837, "x_quantize_rowwise": 0.15832111239433289, "g_quantize_rowwise": 0.8406005799770355, "w_quantize_rowwise": 0.5935393273830414, "w_quantize_colwise_transpose": 23.86143058538437, "w_quantize_global": 1.7401576042175293, "w_quantize_global_transpose": 1.8653534352779388, "cast_x": 0.3079026937484741, "cast_g": 1.209162175655365, "cast_w": 1.1951625347137451, "time_standard": 50.97494274377823, "time_rowwise": 57.97765776515007, "time_global": 37.11054101586342} +{"repeat": 64, "batch_size": 8192, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 17.398890107870102, "standard_gw": 18.470749258995056, "standard_gx": 16.520217061042786, "rowwise_fwd": 7.235266268253326, "rowwise_bwd": 8.207589387893677, "global_fwd": 7.235914468765259, "global_bwd": 8.204508572816849, "x_quantize_rowwise": 0.8409880101680756, "g_quantize_rowwise": 0.15821680426597595, "w_quantize_rowwise": 0.8324198424816132, "w_quantize_colwise_transpose": 15.305522829294205, "w_quantize_global": 1.7396919429302216, "w_quantize_global_transpose": 1.8805749714374542, "cast_x": 1.2103468179702759, "cast_g": 0.30729547142982483, "cast_w": 1.1953599750995636, "time_standard": 52.389856427907944, "time_rowwise": 51.05075240135193, "time_global": 38.53064402937889} +{"repeat": 64, "batch_size": 16384, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 33.533211797475815, "standard_gw": 33.00020843744278, "standard_gx": 34.614477306604385, "rowwise_fwd": 16.364943236112595, "rowwise_bwd": 
14.551006257534027, "global_fwd": 16.33496955037117, "global_bwd": 14.513172209262848, "x_quantize_rowwise": 0.3053396940231323, "g_quantize_rowwise": 1.6693994402885437, "w_quantize_rowwise": 0.5936138331890106, "w_quantize_colwise_transpose": 23.89485388994217, "w_quantize_global": 1.741711050271988, "w_quantize_global_transpose": 1.8656104803085327, "cast_x": 0.6089657545089722, "cast_g": 2.4122074246406555, "cast_w": 1.1951886117458344, "time_standard": 101.14789754152298, "time_rowwise": 90.37936478853226, "time_global": 69.430410861969} +{"repeat": 64, "batch_size": 16384, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 33.65536406636238, "standard_gw": 33.02193805575371, "standard_gx": 33.10496360063553, "rowwise_fwd": 14.54489678144455, "rowwise_bwd": 16.36252924799919, "global_fwd": 14.50401172041893, "global_bwd": 16.33254438638687, "x_quantize_rowwise": 1.6695670783519745, "g_quantize_rowwise": 0.3054291009902954, "w_quantize_rowwise": 0.83121657371521, "w_quantize_colwise_transpose": 15.305932611227036, "w_quantize_global": 1.7382949590682983, "w_quantize_global_transpose": 1.880194991827011, "cast_x": 2.412091940641403, "cast_g": 0.6079599261283875, "cast_w": 1.1950358748435974, "time_standard": 99.78226572275162, "time_rowwise": 82.04150944948196, "time_global": 69.45198029279709} +{"repeat": 64, "batch_size": 32768, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 67.96638667583466, "standard_gw": 67.99514591693878, "standard_gx": 69.66376304626465, "rowwise_fwd": 33.51752087473869, "rowwise_bwd": 29.131878167390823, "global_fwd": 32.65715390443802, "global_bwd": 29.13403883576393, "x_quantize_rowwise": 0.6002038717269897, "g_quantize_rowwise": 3.3336542546749115, "w_quantize_rowwise": 0.5934685468673706, "w_quantize_colwise_transpose": 23.92345294356346, "w_quantize_global": 1.7405375838279724, "w_quantize_global_transpose": 1.8656738102436066, "cast_x": 1.2112446129322052, "cast_g": 
4.81804832816124, "cast_w": 1.1952146887779236, "time_standard": 205.6252956390381, "time_rowwise": 159.09532457590103, "time_global": 137.3264081776142} +{"repeat": 64, "batch_size": 32768, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 68.2341456413269, "standard_gw": 65.5074268579483, "standard_gx": 67.13805347681046, "rowwise_fwd": 29.153641313314438, "rowwise_bwd": 32.71844983100891, "global_fwd": 29.124341905117035, "global_bwd": 32.65979886054993, "x_quantize_rowwise": 3.3318176865577698, "g_quantize_rowwise": 0.6004795432090759, "w_quantize_rowwise": 0.8309967815876007, "w_quantize_colwise_transpose": 15.305690467357635, "w_quantize_global": 1.7405711114406586, "w_quantize_global_transpose": 1.8802620470523834, "cast_x": 4.8183538019657135, "cast_g": 1.2096390128135681, "cast_w": 1.1951103806495667, "time_standard": 200.87962597608566, "time_rowwise": 147.44850248098373, "time_global": 134.84469801187515} +{"repeat": 64, "batch_size": 1024, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.07764250040054321, "standard_gw": 0.07398426532745361, "standard_gx": 0.08482858538627625, "rowwise_fwd": 0.05266070365905762, "rowwise_bwd": 0.04478543996810913, "global_fwd": 0.052012503147125244, "global_bwd": 0.044364482164382935, "x_quantize_rowwise": 0.02640858292579651, "g_quantize_rowwise": 0.02539902925491333, "w_quantize_rowwise": 0.026457011699676514, "w_quantize_colwise_transpose": 0.17770379781723022, "w_quantize_global": 0.07440149784088135, "w_quantize_global_transpose": 0.08142739534378052, "cast_x": 0.008150935173034668, "cast_g": 0.022415071725845337, "cast_w": 0.03479421138763428, "time_standard": 0.23645535111427307, "time_rowwise": 0.42739883065223694, "time_global": 0.3779977560043335} +{"repeat": 64, "batch_size": 1024, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.08524581789970398, "standard_gw": 0.07383152842521667, "standard_gx": 0.07564574480056763, 
"rowwise_fwd": 0.04478171467781067, "rowwise_bwd": 0.052671879529953, "global_fwd": 0.04452839493751526, "global_bwd": 0.05219504237174988, "x_quantize_rowwise": 0.025328248739242554, "g_quantize_rowwise": 0.027123838663101196, "w_quantize_rowwise": 0.025607645511627197, "w_quantize_colwise_transpose": 0.17121434211730957, "w_quantize_global": 0.07916614413261414, "w_quantize_global_transpose": 0.08177384734153748, "cast_x": 0.022619962692260742, "cast_g": 0.008556991815567017, "cast_w": 0.034421682357788086, "time_standard": 0.23472309112548828, "time_rowwise": 0.42055919766426086, "time_global": 0.3839470446109772} +{"repeat": 64, "batch_size": 2048, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.13731792569160461, "standard_gw": 0.13414397835731506, "standard_gx": 0.14049187302589417, "rowwise_fwd": 0.10158121585845947, "rowwise_bwd": 0.07804110646247864, "global_fwd": 0.09908527135848999, "global_bwd": 0.07766112685203552, "x_quantize_rowwise": 0.026516616344451904, "g_quantize_rowwise": 0.03666803240776062, "w_quantize_rowwise": 0.024981796741485596, "w_quantize_colwise_transpose": 0.17706677317619324, "w_quantize_global": 0.07443130016326904, "w_quantize_global_transpose": 0.07870793342590332, "cast_x": 0.01224130392074585, "cast_g": 0.05828961730003357, "cast_w": 0.03501400351524353, "time_standard": 0.41195377707481384, "time_rowwise": 0.5789995193481445, "time_global": 0.5272142589092255} +{"repeat": 64, "batch_size": 2048, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.14651194214820862, "standard_gw": 0.14011189341545105, "standard_gx": 0.140264630317688, "rowwise_fwd": 0.081576406955719, "rowwise_bwd": 0.10671466588973999, "global_fwd": 0.08158013224601746, "global_bwd": 0.10219961404800415, "x_quantize_rowwise": 0.03775954246520996, "g_quantize_rowwise": 0.026103109121322632, "w_quantize_rowwise": 0.02656877040863037, "w_quantize_colwise_transpose": 0.17822161316871643, "w_quantize_global": 
0.07506832480430603, "w_quantize_global_transpose": 0.07928535342216492, "cast_x": 0.05893409252166748, "cast_g": 0.012326985597610474, "cast_w": 0.03498047590255737, "time_standard": 0.42688846588134766, "time_rowwise": 0.5970560014247894, "time_global": 0.5421079695224762} +{"repeat": 64, "batch_size": 4096, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.2734065055847168, "standard_gw": 0.25558844208717346, "standard_gx": 0.29174983501434326, "rowwise_fwd": 0.173322856426239, "rowwise_bwd": 0.1515895128250122, "global_fwd": 0.17048418521881104, "global_bwd": 0.1506991684436798, "x_quantize_rowwise": 0.025950372219085693, "g_quantize_rowwise": 0.0653192400932312, "w_quantize_rowwise": 0.027138739824295044, "w_quantize_colwise_transpose": 0.17699971795082092, "w_quantize_global": 0.07373467087745667, "w_quantize_global_transpose": 0.07901713252067566, "cast_x": 0.02214685082435608, "cast_g": 0.11127442121505737, "cast_w": 0.03481656312942505, "time_standard": 0.8207447826862335, "time_rowwise": 0.8759088814258575, "time_global": 0.8207932114601135} +{"repeat": 64, "batch_size": 4096, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.27839839458465576, "standard_gw": 0.2537444233894348, "standard_gx": 0.28207898139953613, "rowwise_fwd": 0.16542896628379822, "rowwise_bwd": 0.18540024757385254, "global_fwd": 0.15722215175628662, "global_bwd": 0.17368420958518982, "x_quantize_rowwise": 0.06661936640739441, "g_quantize_rowwise": 0.027049332857131958, "w_quantize_rowwise": 0.025507062673568726, "w_quantize_colwise_transpose": 0.1741349697113037, "w_quantize_global": 0.07463246583938599, "w_quantize_global_transpose": 0.07879361510276794, "cast_x": 0.11301413178443909, "cast_g": 0.023346394300460815, "cast_w": 0.03505498170852661, "time_standard": 0.8142217993736267, "time_rowwise": 0.8978843688964844, "time_global": 0.8317455649375916} +{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, 
"switch": false, "standard_fwd": 0.5755424499511719, "standard_gw": 0.5219094455242157, "standard_gx": 0.5992203950881958, "rowwise_fwd": 0.33193081617355347, "rowwise_bwd": 0.295441597700119, "global_fwd": 0.32791122794151306, "global_bwd": 0.2906434237957001, "x_quantize_rowwise": 0.0337548553943634, "g_quantize_rowwise": 0.1225881278514862, "w_quantize_rowwise": 0.024937093257904053, "w_quantize_colwise_transpose": 0.17729029059410095, "w_quantize_global": 0.0730752944946289, "w_quantize_global_transpose": 0.07835403084754944, "cast_x": 0.058166682720184326, "cast_g": 0.21592900156974792, "cast_w": 0.03454089164733887, "time_standard": 1.6966722905635834, "time_rowwise": 1.5078522264957428, "time_global": 1.4482364058494568} +{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5104020237922668, "standard_gw": 0.5302242934703827, "standard_gx": 0.5842559039592743, "rowwise_fwd": 0.32220035791397095, "rowwise_bwd": 0.3576017916202545, "global_fwd": 0.2939775586128235, "global_bwd": 0.3313682973384857, "x_quantize_rowwise": 0.12369826436042786, "g_quantize_rowwise": 0.03423169255256653, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.16975775361061096, "w_quantize_global": 0.0768713653087616, "w_quantize_global_transpose": 0.08094683289527893, "cast_x": 0.21589547395706177, "cast_g": 0.05825608968734741, "cast_w": 0.03466010093688965, "time_standard": 1.6248822212219238, "time_rowwise": 1.5642158687114716, "time_global": 1.4713183045387268} +{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.194491982460022, "standard_gw": 1.0553859174251556, "standard_gx": 1.0726377367973328, "rowwise_fwd": 0.636763870716095, "rowwise_bwd": 0.5154944956302643, "global_fwd": 0.6281323730945587, "global_bwd": 0.5117170512676239, "x_quantize_rowwise": 0.062175095081329346, "g_quantize_rowwise": 0.23643672466278076, 
"w_quantize_rowwise": 0.025566667318344116, "w_quantize_colwise_transpose": 0.17768144607543945, "w_quantize_global": 0.07302314043045044, "w_quantize_global_transpose": 0.07866695523262024, "cast_x": 0.11140108108520508, "cast_g": 0.42498111724853516, "cast_w": 0.034831464290618896, "time_standard": 3.3225156366825104, "time_rowwise": 2.7095042169094086, "time_global": 2.645537257194519} +{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.0797791182994843, "standard_gw": 1.062549650669098, "standard_gx": 1.104947179555893, "rowwise_fwd": 0.5390122532844543, "rowwise_bwd": 0.6449781358242035, "global_fwd": 0.5145668983459473, "global_bwd": 0.6276033818721771, "x_quantize_rowwise": 0.23603439331054688, "g_quantize_rowwise": 0.062234699726104736, "w_quantize_rowwise": 0.02781301736831665, "w_quantize_colwise_transpose": 0.1703314483165741, "w_quantize_global": 0.07431954145431519, "w_quantize_global_transpose": 0.08028373122215271, "cast_x": 0.4249885678291321, "cast_g": 0.1113303005695343, "cast_w": 0.0348016619682312, "time_standard": 3.247275948524475, "time_rowwise": 2.742953598499298, "time_global": 2.657592296600342} +{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.392485737800598, "standard_gw": 2.046734094619751, "standard_gx": 2.177651971578598, "rowwise_fwd": 1.252591609954834, "rowwise_bwd": 1.0205842554569244, "global_fwd": 1.230098307132721, "global_bwd": 1.0132193565368652, "x_quantize_rowwise": 0.11823698878288269, "g_quantize_rowwise": 0.4639141261577606, "w_quantize_rowwise": 0.02602487802505493, "w_quantize_colwise_transpose": 0.17801672220230103, "w_quantize_global": 0.07301196455955505, "w_quantize_global_transpose": 0.07893890142440796, "cast_x": 0.21591037511825562, "cast_g": 0.843394547700882, "cast_w": 0.03460049629211426, "time_standard": 6.616871803998947, "time_rowwise": 5.106102675199509, "time_global": 
5.0241537392139435} +{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.205628901720047, "standard_gw": 1.9917488098144531, "standard_gx": 2.1518059074878693, "rowwise_fwd": 1.040138304233551, "rowwise_bwd": 1.2538731098175049, "global_fwd": 1.0131187736988068, "global_bwd": 1.2291893362998962, "x_quantize_rowwise": 0.46381354331970215, "g_quantize_rowwise": 0.11790916323661804, "w_quantize_rowwise": 0.027123838663101196, "w_quantize_colwise_transpose": 0.17021596431732178, "w_quantize_global": 0.0752471387386322, "w_quantize_global_transpose": 0.08159875869750977, "cast_x": 0.8433908224105835, "cast_g": 0.215873122215271, "cast_w": 0.03452599048614502, "time_standard": 6.349183619022369, "time_rowwise": 5.064822733402252, "time_global": 4.972625523805618} +{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.755370318889618, "standard_gw": 4.736289381980896, "standard_gx": 4.0378570556640625, "rowwise_fwd": 2.4783052504062653, "rowwise_bwd": 1.9634142518043518, "global_fwd": 2.435591071844101, "global_bwd": 1.9498206675052643, "x_quantize_rowwise": 0.22948533296585083, "g_quantize_rowwise": 0.9186491370201111, "w_quantize_rowwise": 0.028233975172042847, "w_quantize_colwise_transpose": 0.17858296632766724, "w_quantize_global": 0.07418543100357056, "w_quantize_global_transpose": 0.07958710193634033, "cast_x": 0.4257224500179291, "cast_g": 1.680031418800354, "cast_w": 0.03458559513092041, "time_standard": 13.529516756534576, "time_rowwise": 10.532960295677185, "time_global": 10.423608124256134} +{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.050172865390778, "standard_gw": 3.916766494512558, "standard_gx": 4.281226545572281, "rowwise_fwd": 1.9789263606071472, "rowwise_bwd": 2.477586269378662, "global_fwd": 1.9495487213134766, "global_bwd": 2.434592694044113, "x_quantize_rowwise": 
0.918261706829071, "g_quantize_rowwise": 0.22961944341659546, "w_quantize_rowwise": 0.025540590286254883, "w_quantize_colwise_transpose": 0.17032772302627563, "w_quantize_global": 0.07384642958641052, "w_quantize_global_transpose": 0.08105114102363586, "cast_x": 1.679886132478714, "cast_g": 0.42508915066719055, "cast_w": 0.03442913293838501, "time_standard": 12.248165905475616, "time_rowwise": 9.717028588056564, "time_global": 9.60368663072586} +{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.53347235918045, "standard_gw": 8.138865232467651, "standard_gx": 7.9666972160339355, "rowwise_fwd": 4.984956234693527, "rowwise_bwd": 3.850068897008896, "global_fwd": 4.9025751650333405, "global_bwd": 3.820303827524185, "x_quantize_rowwise": 0.45222043991088867, "g_quantize_rowwise": 1.8290691077709198, "w_quantize_rowwise": 0.026736408472061157, "w_quantize_colwise_transpose": 0.17832592129707336, "w_quantize_global": 0.07471069693565369, "w_quantize_global_transpose": 0.08177757263183594, "cast_x": 0.8435025811195374, "cast_g": 3.3529214560985565, "cast_w": 0.03475695848464966, "time_standard": 25.639034807682037, "time_rowwise": 19.460242241621017, "time_global": 19.299522042274475} +{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 7.996037602424622, "standard_gw": 8.2748644053936, "standard_gx": 8.523400872945786, "rowwise_fwd": 3.8556940853595734, "rowwise_bwd": 4.966288805007935, "global_fwd": 3.820043057203293, "global_bwd": 4.882067441940308, "x_quantize_rowwise": 1.8279887735843658, "g_quantize_rowwise": 0.4520900547504425, "w_quantize_rowwise": 0.02676248550415039, "w_quantize_colwise_transpose": 0.17083808779716492, "w_quantize_global": 0.07691606879234314, "w_quantize_global_transpose": 0.08223950862884521, "cast_x": 3.3530443906784058, "cast_g": 0.8434318006038666, "cast_w": 0.034671276807785034, "time_standard": 24.794302880764008, 
"time_rowwise": 19.574526697397232, "time_global": 19.416209310293198} +{"repeat": 64, "batch_size": 1024, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.09413063526153564, "standard_gw": 0.10038167238235474, "standard_gx": 0.09725615382194519, "rowwise_fwd": 0.05979463458061218, "rowwise_bwd": 0.0525452196598053, "global_fwd": 0.059057027101516724, "global_bwd": 0.05194917321205139, "x_quantize_rowwise": 0.02664700150489807, "g_quantize_rowwise": 0.02642720937728882, "w_quantize_rowwise": 0.030562281608581543, "w_quantize_colwise_transpose": 0.2400912344455719, "w_quantize_global": 0.09407848119735718, "w_quantize_global_transpose": 0.10256841778755188, "cast_x": 0.008724629878997803, "cast_g": 0.028502196073532104, "cast_w": 0.05552172660827637, "time_standard": 0.29176846146583557, "time_rowwise": 0.5364492535591125, "time_global": 0.4611089825630188} +{"repeat": 64, "batch_size": 1024, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.09753555059432983, "standard_gw": 0.10102242231369019, "standard_gx": 0.09121373295783997, "rowwise_fwd": 0.052150338888168335, "rowwise_bwd": 0.059779733419418335, "global_fwd": 0.05161017179489136, "global_bwd": 0.05943328142166138, "x_quantize_rowwise": 0.026702880859375, "g_quantize_rowwise": 0.02469494938850403, "w_quantize_rowwise": 0.03324449062347412, "w_quantize_colwise_transpose": 0.23468583822250366, "w_quantize_global": 0.09394437074661255, "w_quantize_global_transpose": 0.10142102837562561, "cast_x": 0.028360635042190552, "cast_g": 0.008717179298400879, "cast_w": 0.05577504634857178, "time_standard": 0.28977170586586, "time_rowwise": 0.5322806537151337, "time_global": 0.4588291049003601} +{"repeat": 64, "batch_size": 2048, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.18056854605674744, "standard_gw": 0.18374621868133545, "standard_gx": 0.19219890236854553, "rowwise_fwd": 0.1150965690612793, "rowwise_bwd": 0.0903494656085968, 
"global_fwd": 0.11263042688369751, "global_bwd": 0.08984282612800598, "x_quantize_rowwise": 0.027067959308624268, "g_quantize_rowwise": 0.040043145418167114, "w_quantize_rowwise": 0.03063306212425232, "w_quantize_colwise_transpose": 0.24128705263137817, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.1024976372718811, "cast_x": 0.01381710171699524, "cast_g": 0.06845593452453613, "cast_w": 0.05572289228439331, "time_standard": 0.5565136671066284, "time_rowwise": 0.7282234728336334, "time_global": 0.6494410336017609} +{"repeat": 64, "batch_size": 2048, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.16536936163902283, "standard_gw": 0.19479170441627502, "standard_gx": 0.18597766757011414, "rowwise_fwd": 0.09634345769882202, "rowwise_bwd": 0.11937320232391357, "global_fwd": 0.09264424443244934, "global_bwd": 0.11524930596351624, "x_quantize_rowwise": 0.04038214683532715, "g_quantize_rowwise": 0.025559216737747192, "w_quantize_rowwise": 0.03334507346153259, "w_quantize_colwise_transpose": 0.23956596851348877, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.1020580530166626, "cast_x": 0.06891414523124695, "cast_g": 0.013861805200576782, "cast_w": 0.05607306957244873, "time_standard": 0.546138733625412, "time_rowwise": 0.7493607699871063, "time_global": 0.6651394069194794} +{"repeat": 64, "batch_size": 4096, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.36064907908439636, "standard_gw": 0.3711991012096405, "standard_gx": 0.3863237798213959, "rowwise_fwd": 0.22270530462265015, "rowwise_bwd": 0.1760348677635193, "global_fwd": 0.21781772375106812, "global_bwd": 0.17484650015830994, "x_quantize_rowwise": 0.02625212073326111, "g_quantize_rowwise": 0.07131323218345642, "w_quantize_rowwise": 0.030372291803359985, "w_quantize_colwise_transpose": 0.23974105715751648, "w_quantize_global": 0.09407475590705872, "w_quantize_global_transpose": 0.1024492084980011, "cast_x": 
0.028584152460098267, "cast_g": 0.1303069293498993, "cast_w": 0.05582347512245178, "time_standard": 1.1181719601154327, "time_rowwise": 1.137617975473404, "time_global": 1.057952642440796} +{"repeat": 64, "batch_size": 4096, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.32703205943107605, "standard_gw": 0.3764517605304718, "standard_gx": 0.3938935697078705, "rowwise_fwd": 0.18771737813949585, "rowwise_bwd": 0.2374798059463501, "global_fwd": 0.1843757927417755, "global_bwd": 0.23005902767181396, "x_quantize_rowwise": 0.07155537605285645, "g_quantize_rowwise": 0.02625212073326111, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.23755058646202087, "w_quantize_global": 0.09388476610183716, "w_quantize_global_transpose": 0.10246038436889648, "cast_x": 0.13131648302078247, "cast_g": 0.028781592845916748, "cast_w": 0.05638599395751953, "time_standard": 1.0973773896694183, "time_rowwise": 1.1699534952640533, "time_global": 1.0850392282009125} +{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7961541414260864, "standard_gw": 0.7424280047416687, "standard_gx": 0.8688867092132568, "rowwise_fwd": 0.432576984167099, "rowwise_bwd": 0.34543126821517944, "global_fwd": 0.4248805344104767, "global_bwd": 0.3432855010032654, "x_quantize_rowwise": 0.03750622272491455, "g_quantize_rowwise": 0.13292208313941956, "w_quantize_rowwise": 0.030599534511566162, "w_quantize_colwise_transpose": 0.24292618036270142, "w_quantize_global": 0.09351596236228943, "w_quantize_global_transpose": 0.1026056706905365, "cast_x": 0.06843730807304382, "cast_g": 0.2539418637752533, "cast_w": 0.05568563938140869, "time_standard": 2.407468855381012, "time_rowwise": 1.9643902778625488, "time_global": 1.8771439790725708} +{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7150471210479736, "standard_gw": 0.7525831460952759, 
"standard_gx": 0.8075274527072906, "rowwise_fwd": 0.36595389246940613, "rowwise_bwd": 0.4404708743095398, "global_fwd": 0.3485158085823059, "global_bwd": 0.4275962710380554, "x_quantize_rowwise": 0.1329965889453888, "g_quantize_rowwise": 0.03767386078834534, "w_quantize_rowwise": 0.03295019268989563, "w_quantize_colwise_transpose": 0.23509934544563293, "w_quantize_global": 0.09398534893989563, "w_quantize_global_transpose": 0.10186433792114258, "cast_x": 0.2537667751312256, "cast_g": 0.06839632987976074, "cast_w": 0.05571544170379639, "time_standard": 2.27515771985054, "time_rowwise": 1.9977279007434845, "time_global": 1.8952153623104095} +{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6392990946769714, "standard_gw": 1.4941170811653137, "standard_gx": 1.4451220631599426, "rowwise_fwd": 0.8369758725166321, "rowwise_bwd": 0.6830468773841858, "global_fwd": 0.8197203278541565, "global_bwd": 0.6782263517379761, "x_quantize_rowwise": 0.06883591413497925, "g_quantize_rowwise": 0.2565309405326843, "w_quantize_rowwise": 0.03046169877052307, "w_quantize_colwise_transpose": 0.2430342137813568, "w_quantize_global": 0.09346380829811096, "w_quantize_global_transpose": 0.10301917791366577, "cast_x": 0.13044849038124084, "cast_g": 0.5010999739170074, "cast_w": 0.05590170621871948, "time_standard": 4.578538239002228, "time_rowwise": 3.613002598285675, "time_global": 3.5139136016368866} +{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4654621481895447, "standard_gw": 1.5012174844741821, "standard_gx": 1.5183314681053162, "rowwise_fwd": 0.7059797644615173, "rowwise_bwd": 0.8470229804515839, "global_fwd": 0.6788894534111023, "global_bwd": 0.8200779557228088, "x_quantize_rowwise": 0.2564750611782074, "g_quantize_rowwise": 0.06899237632751465, "w_quantize_rowwise": 0.03293529152870178, "w_quantize_colwise_transpose": 0.23559853434562683, "w_quantize_global": 
0.09375810623168945, "w_quantize_global_transpose": 0.10203942656517029, "cast_x": 0.5010105669498444, "cast_g": 0.13037025928497314, "cast_w": 0.05577504634857178, "time_standard": 4.485011100769043, "time_rowwise": 3.648221492767334, "time_global": 3.521449863910675} +{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.236088901758194, "standard_gw": 2.8601549565792084, "standard_gx": 2.8000958263874054, "rowwise_fwd": 1.6548968851566315, "rowwise_bwd": 1.3559646904468536, "global_fwd": 1.6249343752861023, "global_bwd": 1.3474412262439728, "x_quantize_rowwise": 0.13122707605361938, "g_quantize_rowwise": 0.5038455128669739, "w_quantize_rowwise": 0.03061816096305847, "w_quantize_colwise_transpose": 0.24301931262016296, "w_quantize_global": 0.09343400597572327, "w_quantize_global_transpose": 0.10178983211517334, "cast_x": 0.25383010506629944, "cast_g": 0.9955987334251404, "cast_w": 0.05569681525230408, "time_standard": 8.896339684724808, "time_rowwise": 6.779726594686508, "time_global": 6.662826985120773} +{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.8433389961719513, "standard_gw": 2.861086279153824, "standard_gx": 3.0227042734622955, "rowwise_fwd": 1.4057457447052002, "rowwise_bwd": 1.6565024852752686, "global_fwd": 1.3475008308887482, "global_bwd": 1.6247481107711792, "x_quantize_rowwise": 0.5038045346736908, "g_quantize_rowwise": 0.13130158185958862, "w_quantize_rowwise": 0.03298744559288025, "w_quantize_colwise_transpose": 0.23539364337921143, "w_quantize_global": 0.09393692016601562, "w_quantize_global_transpose": 0.10208785533905029, "cast_x": 0.9952597320079803, "cast_g": 0.25385990738868713, "cast_w": 0.05589798092842102, "time_standard": 8.72712954878807, "time_rowwise": 6.826821714639664, "time_global": 6.664466112852097} +{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, 
"standard_fwd": 6.449159234762192, "standard_gw": 6.384443491697311, "standard_gx": 5.543403327465057, "rowwise_fwd": 3.3065229654312134, "rowwise_bwd": 2.6249960064888, "global_fwd": 3.2497718930244446, "global_bwd": 2.6061534881591797, "x_quantize_rowwise": 0.25821104645729065, "g_quantize_rowwise": 0.9981803596019745, "w_quantize_rowwise": 0.030606985092163086, "w_quantize_colwise_transpose": 0.24094432592391968, "w_quantize_global": 0.09358301758766174, "w_quantize_global_transpose": 0.10264664888381958, "cast_x": 0.5018562078475952, "cast_g": 1.9840113818645477, "cast_w": 0.05584210157394409, "time_standard": 18.37700605392456, "time_rowwise": 13.843905180692673, "time_global": 13.692989945411682} +{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.508493632078171, "standard_gw": 5.689781159162521, "standard_gx": 6.020743399858475, "rowwise_fwd": 2.640843391418457, "rowwise_bwd": 3.3075474202632904, "global_fwd": 2.605751156806946, "global_bwd": 3.2674334943294525, "x_quantize_rowwise": 0.9983181953430176, "g_quantize_rowwise": 0.25597214698791504, "w_quantize_rowwise": 0.03277510404586792, "w_quantize_colwise_transpose": 0.23587048053741455, "w_quantize_global": 0.09367987513542175, "w_quantize_global_transpose": 0.10236725211143494, "cast_x": 1.9848868250846863, "cast_g": 0.5010329186916351, "cast_w": 0.055771321058273315, "time_standard": 17.219018191099167, "time_rowwise": 13.161107897758484, "time_global": 13.013303279876709} +{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 12.975204735994339, "standard_gw": 11.424731463193893, "standard_gx": 11.05477660894394, "rowwise_fwd": 6.623122841119766, "rowwise_bwd": 5.253363400697708, "global_fwd": 6.506938487291336, "global_bwd": 5.211424082517624, "x_quantize_rowwise": 0.5057789385318756, "g_quantize_rowwise": 1.9870363175868988, "w_quantize_rowwise": 0.030517578125, 
"w_quantize_colwise_transpose": 0.24361908435821533, "w_quantize_global": 0.09384006261825562, "w_quantize_global_transpose": 0.10285153985023499, "cast_x": 0.9967051446437836, "cast_g": 3.9620958268642426, "cast_w": 0.05599111318588257, "time_standard": 35.45471280813217, "time_rowwise": 26.068169623613358, "time_global": 25.83260089159012} +{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.05555146932602, "standard_gw": 11.32136583328247, "standard_gx": 12.035444378852844, "rowwise_fwd": 5.243867635726929, "rowwise_bwd": 6.622854620218277, "global_fwd": 5.209986120462418, "global_bwd": 6.507329642772675, "x_quantize_rowwise": 1.9862838089466095, "g_quantize_rowwise": 0.506080687046051, "w_quantize_rowwise": 0.03318488597869873, "w_quantize_colwise_transpose": 0.23682788014411926, "w_quantize_global": 0.09349361062049866, "w_quantize_global_transpose": 0.1023709774017334, "cast_x": 3.962486982345581, "cast_g": 0.9956248104572296, "cast_w": 0.05572289228439331, "time_standard": 34.412361681461334, "time_rowwise": 25.950465351343155, "time_global": 25.726910680532455} diff --git a/tests/triton_tests/info_mlp.jsonl b/tests/triton_tests/info_mlp.jsonl new file mode 100644 index 0000000..a2076ee --- /dev/null +++ b/tests/triton_tests/info_mlp.jsonl @@ -0,0 +1,20 @@ +{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 3.807276487350464, "my_standard": 4.196919500827789, "standard_compiled": 3.771558403968811, "sb": 3.5132691264152527} +{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 7.215872406959534, "my_standard": 7.991522550582886, "standard_compiled": 7.241688668727875, "sb": 6.581142544746399} +{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 14.26444947719574, "my_standard": 15.685759484767914, "standard_compiled": 14.251746237277985, "sb": 12.735314667224884} +{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 28.49559485912323, "my_standard": 
31.26966953277588, "standard_compiled": 28.414390981197357, "sb": 25.319166481494904} +{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 5.887262523174286, "my_standard": 6.132654845714569, "standard_compiled": 5.902409553527832, "sb": 4.947789013385773} +{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 11.14131510257721, "my_standard": 12.859955430030823, "standard_compiled": 11.133037507534027, "sb": 9.303092956542969} +{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 22.193141281604767, "my_standard": 25.66336840391159, "standard_compiled": 22.22583442926407, "sb": 18.285617232322693} +{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 44.23898458480835, "my_standard": 51.30268633365631, "standard_compiled": 44.08355802297592, "sb": 35.999126732349396} +{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 6.938718259334564, "my_standard": 7.269218564033508, "standard_compiled": 6.94604218006134, "sb": 5.764961242675781} +{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 13.04878294467926, "my_standard": 13.742901384830475, "standard_compiled": 13.011425733566284, "sb": 10.774023830890656} +{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 26.738539338111877, "my_standard": 27.739346027374268, "standard_compiled": 26.75659954547882, "sb": 21.882005035877228} +{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 51.905401051044464, "my_standard": 53.98637801408768, "standard_compiled": 51.8316924571991, "sb": 41.67725890874863} +{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 9.233824908733368, "my_standard": 9.619377553462982, "standard_compiled": 9.214423596858978, "sb": 7.557623088359833} +{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 17.324909567832947, "my_standard": 17.996780574321747, "standard_compiled": 17.29544997215271, "sb": 14.035224914550781} +{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 35.51657497882843, "my_standard": 
36.674730479717255, "standard_compiled": 35.43049842119217, "sb": 28.38330715894699} +{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 69.0087378025055, "my_standard": 71.56594842672348, "standard_compiled": 68.82885098457336, "sb": 54.01633679866791} +{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 12.590140104293823, "my_standard": 13.106442987918854, "standard_compiled": 12.606985867023468, "sb": 10.286301374435425} +{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 24.830535054206848, "my_standard": 25.563716888427734, "standard_compiled": 24.895809590816498, "sb": 19.559212028980255} +{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 49.55078661441803, "my_standard": 51.16480588912964, "standard_compiled": 49.739621579647064, "sb": 38.29141706228256} +{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 98.36294502019882, "my_standard": 102.69322991371155, "standard_compiled": 98.76712411642075, "sb": 75.88706165552139} diff --git a/tests/triton_tests/info_mlp_autocast.jsonl b/tests/triton_tests/info_mlp_autocast.jsonl new file mode 100644 index 0000000..f2098cc --- /dev/null +++ b/tests/triton_tests/info_mlp_autocast.jsonl @@ -0,0 +1,20 @@ +{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 4.91420179605484, "my_standard": 5.577877163887024, "standard_compiled": 4.810944199562073, "sb": 4.512995481491089} +{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 8.876129984855652, "my_standard": 10.154612362384796, "standard_compiled": 8.820965886116028, "sb": 8.367843925952911} +{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 17.47015118598938, "my_standard": 19.857674837112427, "standard_compiled": 17.338842153549194, "sb": 15.992552042007446} +{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 34.824438393116, "my_standard": 39.499424397945404, "standard_compiled": 34.56207364797592, "sb": 31.573951244354248} +{"repeat": 32, "batch_size": 16384, "dim": 1280, 
"standard": 7.342606782913208, "my_standard": 7.9323723912239075, "standard_compiled": 7.279552519321442, "sb": 6.395488977432251} +{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 13.69999349117279, "my_standard": 16.0503089427948, "standard_compiled": 13.603456318378448, "sb": 11.813104152679443} +{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 29.557034373283386, "my_standard": 34.2303067445755, "standard_compiled": 29.382556676864624, "sb": 22.882774472236633} +{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 53.629085421562195, "my_standard": 63.07622790336609, "standard_compiled": 53.33048850297928, "sb": 44.76426541805267} +{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 8.81417840719223, "my_standard": 9.477965533733368, "standard_compiled": 8.73943418264389, "sb": 7.479414343833923} +{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 16.242466866970062, "my_standard": 17.616644501686096, "standard_compiled": 16.14125818014145, "sb": 13.665586709976196} +{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 32.429613173007965, "my_standard": 34.80646014213562, "standard_compiled": 32.319076359272, "sb": 27.123987674713135} +{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 62.85770237445831, "my_standard": 67.55391508340836, "standard_compiled": 62.453076243400574, "sb": 51.53566598892212} +{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 11.585861444473267, "my_standard": 12.565858662128448, "standard_compiled": 11.504307389259338, "sb": 9.657211601734161} +{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 21.261662244796753, "my_standard": 22.771358489990234, "standard_compiled": 21.12410217523575, "sb": 17.64291524887085} +{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 42.85307973623276, "my_standard": 45.70870101451874, "standard_compiled": 42.57970303297043, "sb": 34.918561577796936} +{"repeat": 32, "batch_size": 131072, "dim": 1664, 
"standard": 83.56057852506638, "my_standard": 89.11971747875214, "standard_compiled": 83.05662125349045, "sb": 66.32210314273834} +{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 15.7279372215271, "my_standard": 16.854502260684967, "standard_compiled": 15.655294060707092, "sb": 13.228952884674072} +{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 30.42648732662201, "my_standard": 32.26502239704132, "standard_compiled": 30.239209532737732, "sb": 24.354808032512665} +{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 60.779355466365814, "my_standard": 64.11923468112946, "standard_compiled": 60.89268624782562, "sb": 46.91776633262634} +{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 119.93677169084549, "my_standard": 128.19699943065643, "standard_compiled": 120.20225822925568, "sb": 92.3452153801918} diff --git a/tests/triton_tests/info_mlp_autocast_ln.jsonl b/tests/triton_tests/info_mlp_autocast_ln.jsonl new file mode 100644 index 0000000..706f949 --- /dev/null +++ b/tests/triton_tests/info_mlp_autocast_ln.jsonl @@ -0,0 +1,23 @@ +{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 5.171686410903931, "my_standard": 5.839601159095764, "standard_compiled": 5.032263696193695, "sb": 4.89344447851181} +{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 9.605035185813904, "my_standard": 10.910414159297943, "standard_compiled": 9.230785071849823, "sb": 9.128175675868988} +{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 18.802084028720856, "my_standard": 21.311581134796143, "standard_compiled": 18.105976283550262, "sb": 17.489850521087646} +{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 37.49683499336243, "my_standard": 42.40527004003525, "standard_compiled": 36.13145649433136, "sb": 34.58733111619949} +{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 7.709823548793793, "my_standard": 8.290477097034454, "standard_compiled": 7.564418017864227, "sb": 6.8823546171188354} 
+{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 14.64156061410904, "my_standard": 16.996942460536957, "standard_compiled": 14.4081711769104, "sb": 12.761622667312622} +{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 31.40200674533844, "my_standard": 36.074504256248474, "standard_compiled": 30.981406569480896, "sb": 24.76389706134796} +{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 56.93405121564865, "my_standard": 66.35250151157379, "standard_compiled": 56.07586354017258, "sb": 48.49743843078613} +{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 9.188003838062286, "my_standard": 9.84550267457962, "standard_compiled": 9.006097912788391, "sb": 7.9473331570625305} +{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 17.268165946006775, "my_standard": 18.64910125732422, "standard_compiled": 16.983114182949066, "sb": 14.70106840133667} +{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 34.39047932624817, "my_standard": 36.69705241918564, "standard_compiled": 33.8401272892952, "sb": 29.188089072704315} +{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 66.70494377613068, "my_standard": 71.27603143453598, "standard_compiled": 65.56134670972824, "sb": 55.6538850069046} +{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 12.10707426071167, "my_standard": 12.931793928146362, "standard_compiled": 11.76995038986206, "sb": 10.228671133518219} +{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 22.5130096077919, "my_standard": 23.962542414665222, "standard_compiled": 21.997176110744476, "sb": 18.89890432357788} +{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 45.210108160972595, "my_standard": 47.94136434793472, "standard_compiled": 44.2262664437294, "sb": 37.37735003232956} +{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 88.1955549120903, "my_standard": 93.6831533908844, "standard_compiled": 86.33609116077423, "sb": 71.23208791017532} +{"repeat": 
32, "batch_size": 16384, "dim": 2048, "standard": 16.538940370082855, "my_standard": 17.607316374778748, "standard_compiled": 16.108587384223938, "sb": 14.030493795871735} +{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 31.795650720596313, "my_standard": 33.57230871915817, "standard_compiled": 31.04180097579956, "sb": 25.971196591854095} +{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 63.021354377269745, "my_standard": 66.8477788567543, "standard_compiled": 61.682507395744324, "sb": 50.138771533966064} +{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 125.17062574625015, "my_standard": 133.60925763845444, "standard_compiled": 122.21191823482513, "sb": 98.40084612369537} +{"repeat": 32, "batch_size": 16384, "dim": 4096, "standard": 57.31645971536636, "my_standard": 60.84543466567993, "standard_compiled": 55.78199774026871, "sb": 45.43223977088928} +{"repeat": 32, "batch_size": 32768, "dim": 4096, "standard": 111.80306226015091, "my_standard": 119.0284714102745, "standard_compiled": 108.91905426979065, "sb": 85.4572057723999} +{"repeat": 32, "batch_size": 65536, "dim": 4096, "standard": 220.4471081495285, "my_standard": 233.0927476286888, "standard_compiled": 214.26431089639664, "sb": 163.30372542142868} diff --git a/tests/triton_tests/make_plot_with_info.py b/tests/triton_tests/make_plot_with_info.py new file mode 100644 index 0000000..116d1d1 --- /dev/null +++ b/tests/triton_tests/make_plot_with_info.py @@ -0,0 +1,137 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +import os + +import matplotlib.gridspec as gridspec + +cmap=plt.get_cmap('cool') + +if __name__ == '__main__': + + fig = plt.figure(tight_layout=True, figsize=(12,3.5)) + gs = gridspec.GridSpec(1, 2) + + + ax = fig.add_subplot(gs[0, 0]) + + rdf = pd.read_json('tests/triton_tests/info.jsonl', lines=True) + df = rdf[rdf.batch_size == 32768] + + for k, marker, ls, color, name in [ + ('standard_gx+standard_gw+standard_fwd', 's', '-', 
'C2', 'Standard fp16 (sum of parts)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), + + ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), + ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), + ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), + + ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), + ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), + + #### time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] + + ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), + ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), + ('w_quantize_global', '.', '--', 'C4', 'Quatnize global W (switchback)'), + ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize gloabl and\ntranspose W (switchback)'), + #('standard_gw', '.', '--', 'C1', 'standard_gw'), + ]: + xs = [] + ys = [] + for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]: + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + + + ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) 
+ + + + + ax.set_xlabel('dim', fontsize=13) + ax.set_ylabel('time (ms)', fontsize=13) + # make a legend which is below the plot + + + + ax.grid() + + ax.set_xscale('log') + #ax.set_yscale('log') + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048, 4096]) + ax.set_xticklabels([1024, 2048, 4096]) + ax.set_xticks([], minor=True) + + leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) + leg.get_texts()[0].set_fontweight('bold') + leg.get_texts()[1].set_fontweight('bold') + plt.subplots_adjust(left=0.1) + ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) + + + ax = fig.add_subplot(gs[0, 1]) + + # now plot the % speedup for different batch sizes + for j, batch_size in enumerate([2**14, 2**15, 2**16, 2**17]): + all_xs, all_ys = [], [] + for k, marker, ls, color, name in [ + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ]: + + xs, ys = [], [] + df = rdf[rdf.batch_size == batch_size] + for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]: + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + all_xs.append(xs) + all_ys.append(ys) + + color = cmap(j * 0.25) + real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] + markers = ['^', 'v', 'P', 'o'] + ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) + + ax.legend() + 
ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% speedup', fontsize=13) + + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048, 4096]) + ax.set_xticklabels([1024, 2048, 4096]) + ax.set_xticks([], minor=True) + + ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) + + + + plt.savefig('tests/triton_tests/plot1.pdf', bbox_inches='tight') + diff --git a/tests/triton_tests/mlp.py b/tests/triton_tests/mlp.py new file mode 100644 index 0000000..1ec85b8 --- /dev/null +++ b/tests/triton_tests/mlp.py @@ -0,0 +1,64 @@ + +import time +import torch +import torch.nn as nn +import bitsandbytes.nn as bnn +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear + +def construct_model(dim, layers, module): + modules = [] + for _ in range(layers): + modules.append(module(dim, 4*dim)) + modules.append(module(4*dim, dim)) + return nn.Sequential(*modules).cuda().train() + +def get_time(model, x, name): + for _ in range(repeat // 2): + #with torch.cuda.amp.autocast(): + out = model(x) + #(2**16 * out.pow(2).mean()).backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + # with torch.cuda.amp.autocast(): + out = model(x) + #(2**16 * out.pow(2).mean()).backward() + + torch.cuda.synchronize() + end = time.time() + print(f"time {name}: {(end - start) / repeat * 1000:.3f} ms") + +if __name__ == '__main__': + torch.manual_seed(0) + + # hparams + repeat = 16 + dim=2048 + layers =4 + batch_size = 2 + sequence_length = 2**15 + + # construct models + standard = construct_model(dim, layers, nn.Linear).half() + my_standard = construct_model(dim, layers, MyLinear).half() + switchback = construct_model(dim, layers, SwitchBackLinear).half() + switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half() + #bnb_8bitmixed = construct_model(dim, layers, 
bnn.Linear8bitLt) + + # simulate forward pass + x = torch.randn(batch_size * sequence_length, dim, dtype=torch.float16).cuda() + + # get time for forward and backward + get_time(standard, x, "standard") + get_time(my_standard, x, "my_standard") + get_time(switchback, x, "switchback") + get_time(switchback_global, x, "switchback_global") + #get_time(bnb_8bitmixed, x, "bnb_8bitmixed") + + + + + + + \ No newline at end of file diff --git a/tests/triton_tests/mlp_decomp_autocast.py b/tests/triton_tests/mlp_decomp_autocast.py new file mode 100644 index 0000000..3a1fc9e --- /dev/null +++ b/tests/triton_tests/mlp_decomp_autocast.py @@ -0,0 +1,166 @@ + +import torch +import json +from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear +import time + +if __name__ == '__main__': + + print('Startin') + + + for dim in [1024, 1280, 1408, 1664, 2048]: + for batch in [2**14, 2**15, 2**16, 2**17]: + + if dim != 4096 or batch != 2**17: + continue + + + x1 = torch.randn(batch, dim).cuda().requires_grad_(True) + d = 2 + + standard = torch.nn.Sequential( + torch.nn.Linear(dim, 4 * dim), + torch.nn.GELU(), + torch.nn.Linear(4 * dim, dim), + ).cuda() + + my_standard = torch.nn.Sequential( + MyLinear(dim, 4 * dim), + torch.nn.GELU(), + MyLinear(4 * dim, dim), + ).cuda() + + fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda() + + sb = torch.nn.Sequential( + SwitchBackGlobalLinear(dim, 4 * dim), + torch.nn.GELU(), + SwitchBackGlobalLinear(4 * dim, dim), + ).cuda() + + standard_compiled = torch.compile(standard) + + print('Model part 2') + + repeat = 32 + + + info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim} + + # k = 'standard' + # for _ in range(repeat // 2): + # with torch.cuda.amp.autocast(): + # out_standard = standard(x1) + # ((2 ** 16) * out_standard).abs().mean().backward() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # with torch.cuda.amp.autocast(): + # out_standard = standard(x1) + 
# ((2 ** 16) * out_standard).abs().mean().backward() + + # torch.cuda.synchronize() + # end = time.time() + # ms = (end - start) / repeat * 1000 + # print(f"time {k}: {ms:.3f} ms") + # info[k] = ms + + + # x1.grad.zero_() + + # k = 'my_standard' + # for _ in range(repeat // 2): + # with torch.cuda.amp.autocast(): + # out_my_standard = my_standard(x1) + # ((2 ** 16) * out_my_standard).abs().mean().backward() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # with torch.cuda.amp.autocast(): + # out_my_standard = my_standard(x1) + # ((2 ** 16) * out_my_standard).abs().mean().backward() + + # torch.cuda.synchronize() + # end = time.time() + # ms = (end - start) / repeat * 1000 + # print(f"time {k}: {ms:.3f} ms") + # info[k] = ms + + # x1.grad.zero_() + + # k = 'standard_compiled' + # for _ in range(repeat // 2): + # with torch.cuda.amp.autocast(): + # out_standard_compiled = standard_compiled(x1) + # ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # with torch.cuda.amp.autocast(): + # out_standard_compiled = standard_compiled(x1) + # ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + # torch.cuda.synchronize() + # end = time.time() + # ms = (end - start) / repeat * 1000 + # print(f"time {k}: {ms:.3f} ms") + # info[k] = ms + + # x1.grad.zero_() + + k = 'sb' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + info_json = json.dumps(info) + + + with open("tests/triton_tests/info_mlp_autocast.jsonl", "a") as file: + file.write(info_json + 
"\n") + + + #exit() + + # err_fused = (out_standard - out_fused).abs().mean() + # err_sb = (out_standard - out_sb).abs().mean() + # print('OUT', err_fused, err_sb) + + # err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean() + # err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean() + + # print('GW2', err_fused, err_sb) + + # err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean() + # err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean() + + # print('GW1', err_fused, err_sb) + + # err_fused = (x1.grad - x2.grad).abs().mean() + # err_sb = (x1.grad - x3.grad).abs().mean() + + # print('GX1', err_fused, err_sb) + + # import pdb; pdb.set_trace() + + + # # NO GELU, ST GRADIENTS, EVERYTHING FINE. \ No newline at end of file diff --git a/tests/triton_tests/mlp_decomp_autocast_ln.py b/tests/triton_tests/mlp_decomp_autocast_ln.py new file mode 100644 index 0000000..2596278 --- /dev/null +++ b/tests/triton_tests/mlp_decomp_autocast_ln.py @@ -0,0 +1,165 @@ + +import torch +import json +from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear +import time + +if __name__ == '__main__': + + print('Startin') + + + for dim in [1024, 1280, 1408, 1664, 2048]: + for batch in [2**14, 2**15, 2**16, 2**17]: + + x1 = torch.randn(batch, dim).cuda().requires_grad_(True) + d = 2 + + standard = torch.nn.Sequential( + torch.nn.LayerNorm(dim), + torch.nn.Linear(dim, 4 * dim), + torch.nn.GELU(), + torch.nn.Linear(4 * dim, dim), + ).cuda() + + my_standard = torch.nn.Sequential( + torch.nn.LayerNorm(dim), + MyLinear(dim, 4 * dim), + torch.nn.GELU(), + MyLinear(4 * dim, dim), + ).cuda() + + fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda() + + sb = torch.nn.Sequential( + torch.nn.LayerNorm(dim), + SwitchBackGlobalLinear(dim, 4 * dim), + torch.nn.GELU(), + SwitchBackGlobalLinear(4 * dim, dim), + ).cuda() + + standard_compiled = torch.compile(standard) + + print('Model 
part 2') + + repeat = 32 + + + info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim} + + k = 'standard' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_standard = standard(x1) + ((2 ** 16) * out_standard).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_standard = standard(x1) + ((2 ** 16) * out_standard).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + x1.grad.zero_() + + k = 'my_standard' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_my_standard = my_standard(x1) + ((2 ** 16) * out_my_standard).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_my_standard = my_standard(x1) + ((2 ** 16) * out_my_standard).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'standard_compiled' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_standard_compiled = standard_compiled(x1) + ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_standard_compiled = standard_compiled(x1) + ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'sb' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) 
+ ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + info_json = json.dumps(info) + + + with open("tests/triton_tests/info_mlp_autocast_ln.jsonl", "a") as file: + file.write(info_json + "\n") + + + #exit() + + # err_fused = (out_standard - out_fused).abs().mean() + # err_sb = (out_standard - out_sb).abs().mean() + # print('OUT', err_fused, err_sb) + + # err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean() + # err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean() + + # print('GW2', err_fused, err_sb) + + # err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean() + # err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean() + + # print('GW1', err_fused, err_sb) + + # err_fused = (x1.grad - x2.grad).abs().mean() + # err_sb = (x1.grad - x3.grad).abs().mean() + + # print('GX1', err_fused, err_sb) + + # import pdb; pdb.set_trace() + + + # # NO GELU, ST GRADIENTS, EVERYTHING FINE. \ No newline at end of file diff --git a/tests/triton_tests/plot1.pdf b/tests/triton_tests/plot1.pdf new file mode 100644 index 0000000000000000000000000000000000000000..1fe71682174766b2d551d9aa055a72e6eb837737 GIT binary patch literal 34302 zcmb@t1z43$(*P_;cPl7(=!SC+-QAti(%mJkq;!WMAq~>q-5{liv}U&dl!4&dgFPi%T*?SlH31OBX;T&1eu18027VjmFOpVpaEa zGy}1U8MzwSI#_^Mm5nUSTtHBufhtH)5Y5cq6xNXY9~~qe>|H@@Hw{>|l#H#-Ok6=+ zzkW)1xJs$H8o8Q**nfRcHga_}bG8R@z<#2!s#zMDTG?BGIRC13b}&&ha|P)FV8tZ? zRLneFL98-%01Bdiuy-~Cu=|BRtE!ocgPXGnFdrEH0DJ{A zQ!67;2M+)u82ICcu(NPMxwtsEc(}prASgRK3pW6VRScLRFpLX`=NFzL_Vy0IF9@v9 zKdAz%{+o>QX7(1YmLSMqpe3zrfr*1yC2av}h?|)>n3}=xcX4$#GqOYTNdKUx@3Pn! 
z=iAW{z^cnpL-^j(XO$d#)%0Us@o0~!UQ+S!nn*GS8Fl>Es$)7w7I#`z`Ge zl$Sq^vwh@>!m@HCYhZ12ffC2iaaBEk)$09|)V^I}HKirvw5nk&{xljA08R#Kuhz87{}qc7eUsmr+!9kGY?ZziX|L5lB0VDv@D@6*S-FTRM_t5{SxyW|_~ z-`Z0mZ7DN6EmS-aq-WP4*EwIx$ranq&={853W!xq-u0XD{dC1EAOfdsEc}92n3?KD zlE`C{!37ezNB0~S zx$My*{wv}40z3MImADnwkq~>`g#79Lnh)EOaG0~I27qQ6PWF-|Ok37A*= zp2?9FTazFr)#QiefDbj*;)6<3s@sXuoV?f+R&nOODAFmlAIe7Y;qb07i3RBR?p6;g zAp;GBfd=qlV^MMvWLodmU!G9KP7!$3%1zcvqkSI&dLzb{jfqKH-^VW#E%?T=iAj}V zJ>WNe5?Ny2=UDz=}OJ=8v)jH zLU#!D)oH9iuN;jAK2g44)5=-4Q+`o2NN8Hg+wyq<`Jw6G=LAo{(SqqUy-xEFON_9Z8NWYGpa)?Y$!CjH3ZG5YFT4ByMH z@?g8MFQ_xtylebIZob80)R>+3vDDJ?WzF5Uw5!tu<68GAj|vqBXY$G{{aC5;$jL!l zX=GP-YeqVHLb%T2A7qFWC`q4Vl9M|7BbsMZQ*nnH?riFaQnbGOMuA@&^sQVfF3;p8 zZaj(syqN&rQC{zP1K}qjRvx19IP?D2>;~zu&syIJY+on!Gt+0m1zH+YyjqZ5O-@Xk zln{zhDNCEId`Hfb*L6=mbo#BOb{M^7Z_pFIdxgCvkDe#BH>4(g3tm`YZX#}*v@3VT zl7XDRwwUU*MagxqW#%AlhWxa>j7&?eQz09r<{-q^^TcoG<(PE})FbS-H9l3t_sqg! zh-x?2Bx#t+GvJBJ%a~3Z3qw7klNzFLGZz<=pWjH_O~P+xx~ks(y5FD&qRqcwtllV?{nh9>T9fV&FjKDW*^`8O6Wd#w~4k$3cprT8IZ z#njieLs?(p`um+w>Mp(qsz&FWX@2&~KTH=24VI%@E8w@}Qq%3W{H&JNuiTO}M?gYZ zO-K0gtyUgquoo*Y*`>rWGMQsXE4*uwRwo{zt~+ra5;m3jw}7U8OM#L*T@qg=*etX$ zG1>bz?M5n=z2(u%7-V*5Rvh|US()G?Ryh^in&*{dgNq%7CEtf>&ef-jN4t^1aj5-t zxu8c?j}1(ak*>6nWKWzD_#QepYoc0f-e*e_55|RHoy2{p{+=6IUx@DlnOi^CYFIOq zE@=(=eq5X4DPtX+LZk#mH_$+doZUueNjy9}=f32MgqLWALun!U`uFGj{e+P*T5~IfW9#gMucR1#=o}D>e0@*9#fEZL1u!M*Ytl9P{V z<7+ByePU4ko?t~Bc}~k?sJAz-E6}EP+9x=5GcHW;$Y%836hbp+@|x6x2-epHJ+5q> z9xc{G>92GSEtAq65}&a?7I8ChdZ{UVR$eje`*dVCICO1puXm?zG0f6?_-B9p_t&Ob zORZke^6mAyxFaI-1(X?AsXM4YsJ_hGslU;e-fg~aYi;7bn9$F-x>)JiZ~kcUbfd%d zdUny^^2g!I`ejZ_bEBvCncUgM*ygk0{OOHPsoD53y5jzal_IV$ys|9xt-J9NMBP_@ zKFNC=yG}|q-N*1Hcw_FJNr6~4Zb8-O?~2kK&5lb%3=&hh&-*$Q`K2B#9&gu|jA}V% zVv_lF%punHcYx7Ef2vJ#>b>!~cWuDbOu`cJK|;Z)T;xmo)6({tx*eaz6eb@0Dyz9u z7UXmz2Lx#YmoCauOf~Uhf0if$HXSpPt`Afw65u6OMB?J=RHKD{Py^QbL6`mqWqsD7 zpFNASqpxbP9h$Q~CG7kCSU+8LJKNjXyPAE`GRp~%HJHNGjMb%FEVXDS8YP(||2Z}u z+2B!zElN@Q8;!~1?qC8NKPQre5oij$k{bEympgkz490g3^x89A4!~i~ljcr-SG{fF 
zcTonE$zCnQ_3unh9z)ip*Etab;k1iBf8lXu-(905f9qm1dPcMpOQ9irq?F+pc)Y$T zd0KCh;v1`->o!KVGHFN?LMl(6mwH#fgte#cTfMX-e^44)nnxZ& zqYkaH$|8=)tG?4#_g8KDlHcEXJhVlkv#0je@ zaLB!FJ0`{HBwaxbIO(;9TyG-Uh)=K^J=xG`W`dTvr&`GzxH0Bp?(Cklv^>Ewi)Ow5 zl)hvB2(OzV6`U9^<&n#gX5=MA%xUPbN|oA~62CgpVAO0Pba~y=)FMN}Q_&%5AjI;* zXGzZ|^o74HZsehJxOkIKE%-i0Ck4{WKpNcV&Yll{yrj;{&3@K|dFAU4Umg~sWJ$A7 zVj5r|O9to5()kw5NMTUUw7zw@PEr?x^F#(>FE1F}YGNMxE*v4S;(#-T8%ME3@h?iqkX=@lf>O~sE%Co(z z+AGuss{Acq7Pmv5@Qba(Gt4@XMWo3Tma@G`JZhb_oB%mIVjI~@Tokn zywau9YoGgXxaH~fGYcQrV9293o*SvFLR-qn8lajb4BQvQaeA^2_z$=kL(X2&b(}Q)IARK~ zfF~a}$T%a&?NydmLaLgiQ?_*@3y3Rhr|~L`<+vik-?CE1E&2=@3J!@7_2P|I&F!t4 z^u&=`YBa%nB(=&FqokF&dFbZ&&5_{yKt9p{g8->ivC7KKCV!|&xdUbl8>j7|iV&_~ z!E?_hu4ztg&8;nN*(kgL$qkW*3F~Z>{80aBT^~KN$RiEMDLN%Q$=(9p$DfDjWPRa# zC0EHYJRd8rwS6K{d&kT|(y#%-j#Tz>l6{$%ptOdJni3&0&)XAXTCgT!`q>sQ*m0oshKw7ZfCosJv4^y|G8`+Ndg1m} zB!njEz0hn%=Sj0dJl^Z9Jmz<*`~FofXYi9g%|@U>L%x7RU<#iGcrHGor2g&cfVV*d z$JpGmbp11wF_391Vi#4}v0Qnv%_C#Hcw7v2`dW3ga;`jQQG9vldSu%ka+Y|R`^3u# z&odzObFEh5YI4*|@dS}-ODGI>+!MAN^VaoQ?tKbnqqT5PS8?v1-ZB2+mu+ zbtrxw>CTe>;gosb_|f-Rgxqhn9hscb(1{@XQC9Y*{jynTuKu#`+vpC#Ou=KSE5lm+ z7ETYF(epWuhdHDb54~LTeKOU)Wbk82A7P;^?|-lL4|#h3DV(L?i0}~$&jDkQ&!uD! 
z-+%)|fTIvNDqeqH)U3H`E%BpZRq~!_h^7NAtz^oJhkn&LCO!m1YVP;Msjbc-UEfNK5tv8!(m-P*huA4sRs*5T#6aacw^e^74t6_e|yyMv|z`^FdWJ%Jo5&JG>uW+PuU z6EK$uX))(wvS+yGo_~bHF;N%$GVeIYy|A46;S!Z=IefF)X4j$>eKXe3Dgdp;T497Q z#%SS4M62C3t&9x|rX|ZuH`>oVy7KJ%p?i{aZyjavNdMyGNT2Unh|{-MKk5&^8TF-& z-ObW?>ov3CqdB?g*V-<4aj^f%?YjFqh0CFZE9K)ou7;F-d;6G+f^G@Hsk?WD|~w_Mrlc*PbuN*p*+GBHaR$LCi(0 zP_`b^po!r2Kt>UwKLvgx`fMA~Qx67*$Pv;hDZwq%+TgPf+ApS5^BzYf3Y9wcREveZ zdxK-rpr-TXN=>v&unJd+F?lx*jWC#?q;XB#v^G16~ARmC&}8HmS=xR#P~Ib6dGV@JrgCf-%@8a*13< z;HT;a-eyiqbwg({K>?H@N=4XFmfv<;@v!@t89w+_7fjqcXF4p>KmO{IlRfSY4N9&0 zh*Y|ki&-L6^@GjRKzxH33;JE2XEvTMd%8?^>ClLYjx}p6(ydWA|1mQ&dpVg#+<_jN zR8XFvzp1Z#pch09q3DN6a^|b;@6*YN@1CZ!aQTc?S|}%OuAi*!UBkTTzn50Mf{8x zQ3_{;(uu54wZ?I;Bf~y@xI0Hqa`dh{OR6eA-i84+E2+EQLW568Nu4wL;d3m)@kE=0 zZ~Eu4x`if88;=+&#v=HyXrhm>^>SjdKc+L3_IO0OnU}N>GA;>)dtYLf&^Kd+IjVZ* z6~&Q@n@p09g?mA8Coxv7b2R<2Vyn`4Of?QL_2ehR0#jbMd@CW1t`Lx%pVuS3{)t$R{ z{RkdpqX|_*)Umi?s_u*$rD{1U?%eP1s|$GMobh0?(!fDanOg?UBxCSLy+`)qnGjEr zW!%46zy>oj{}*N1*#CgC9IAZw)DVoF4T4tTJ3AJ$8=hUp5nr$7sX3&nMeC)*-7mUD z`P!K(YbLkO>kIcJDx#hRzIiC5M5w|_V6lyG#J$ZcsAO);DpuapjGLh*q;yWY_Eb!M zEP;0vpVq>yi(ael6NwzHX#Gsn;H!g-DcOu~BJLnf^4X}ohTU}Kp%cw}6|q*x)ymTK)ibOPw%utSIbIA6&<9i&Pmt}Aq4sv%U{xr&PE-#~ z$zlQIBuZw!y>bPN%4C-_lTI=Iq4xb+hHSQO%?1 zBU-VeaMZ92ef)^0F68~!f=F4=`wQc%5#+O^dxli2g>{43txZ3}U+OyKsUv#0CRs!` z`BIrnr99M8%#t+~m|ktAYn#yH;Ofw~ETCI@+iKZByuHtQ$^R@%=o*gK;aTP_{DGPI z|C>J$&Oea+c)92`FoxI}PcX~G^OKs@WA$7}MB0VJ#$@WYgs{l;3j(&s^Odp}N3~1u zy|6A$8BKJjYC9f(Z_Z?0@0GsyiW!4<828QKDYP$31nQ@ALSB=zm*|| zd>IU^U~~~AKIJw&(yTZoI1yZAK=SSvQCOG2^;zoh`SdYM4*KryUN{oZ2j2ttq+{g~ z=gDz~O*<9G2G(z-o@2P$$WKywG~)_&)B0`rQqmgiH@=}7V$qWDyq!vyoRl=Nk*uK` zLL%|*bR_jrr}H|OxxTh!Oa5u>vxL5wmL<=!7w=fyozQ~qtV@uAF;~*~f!V@Mx8lOJ zC2-Tw(Vm{hl3^)DtmzBh`&ZkYDSNBbv-S)67ncuLn1i^M7#Y&7Pu-3%n-)q+&C?&r zMhHf?I&N${^2Bj{ehU*|&XxaS0+fyG_e|hWk+VmHpv-R&w09<))4F6n7I2$fdZsI; zZbij!7)kz}8{xA@F=bgS_4zxF3B-=TOWVs*0&bZ_ir+W*5Jhr%539H-TCJMWuCgtz-t}4Vc9MHbReJ&1 zzT+ABfwOJX-e5HT+o#^VT!ngBZOE=`AI5w`>tmY+)XZZn6NukVIEO4>XDsb_`~JK$ 
zaiXDd3)NsgsDDw7`wyrVBWHgX0*I~)nmz~+-LJ>;%B}#}9;lApNezvqJ}Pz{j+OQ- z9&t@b6I=3CevEwNip5&lZrr`gAE-cQN#XzloAtm=j)+s6)Qmef`(r=v~ zhAl)2rTuZPw%GfwPOB8d4n+#jpr^VpcXRy2^?nZv+~6HeD$4k9gw$#AEeTdofq!!E zEA5Ttiyw#;-HT4GAFT=RiudbHVeBDKy<1|tm_8~VII)Y52%gm;zPILUvrWXxYK6kWz_1bPB8_d)9Uu0tk zga3eRYBDggwTv@<^M{|dz7SADXj{Af7=MH@p&m{hHx)leEhhU#N`F|){N?5Lm#W^f zHSZ=J?|M}|Bb)GD`ZKe^s<0S`swoo=V*f(?PqEpNuO%1n!-WTg7x9V?ig`b2ACrq* zBsSV94l9xjE0xR_;NOJ=U@#bsPD&EIp=S8dakcRAli3@OCz}e{j3EIpd@5y2qnLtB z6uX|RMA49rJEFbi1=j=^984E@Ez%^Y4KGqsJz`Gus7vYLqx64f?QD^oSPl%!zU$v0 zSr@5yxG7s2E{mU`>Iv#Lz!AX~6m!NpM0muxGl)R)bDpftK(S>c_2C!+E; ziilYE%f#(&7io<=Vp6wCWj3$;yN6+a)G*bNrKW`*ycwI>2C?y_tKWc6|*8!c3CpK;~kb|R^ zHM9f$Z7g3yj~F_c09g_^a1L#D0*Pwx%?Fd04nkjdkFmW;ayD*Z>_1(+e>*eT{$MLt zh!MfNOOArD-x@%BC!leER^1wV{rY2M8i%)2^TT`+)~;@^Wyhyu4410Shit`Oj)g0p z3A1?={si@RwG)-9kwk42nqdw&OKm^yZc^< znZj`(>W+a7MeMEX3b-U1ftd($Q<8}^teye`!d`}+j}MtC;X9;=$pV{%L`4F z$D{V3TAC(;S*SB|_yk95r>)8$eZ;6K%qVhjLP8WI3#Mcy3v{<4qpCloDg1a4 zcL*vJD5Zmj$vMbRvEy%&@w1u=5TD1G#)~T$X6}MYrSB~n z-yG6(%?bU^{IyWjo&(&{Q(0t#5ZB|lvpX2tnVa~z1M)FtsNvvk1DXrCP%!6ZsH>hc zJOp!1vJ{7~Wh4wWjiFt*u2ax;+DRVOZ%MleB6Vnud!e!A%=_63*JkiA`*K_AYm0}q zbDrXS#3QQ0-r{Qlsp6z%LDXnhDP{GnAIFLB#A{12MNmhTAm`ndkxRvmH@%w0-{k zuX3a}44{1qWMG-6Vd+P`l{xn0Sw?JWTMX)p)d-kQ@ ziS1jo6L}*G^XClS^fFv|;*A7a&=R|6B*IT*%JfHKv3y%IaYnD=9k2~aA`sriSo?r< za{i6+JgK_oaQlvWmh?jl={d-^4kX65c8U;n!&|6xi$=@+2V0iHTM?{I5DG#kU&S&i z(_G!pLMyL{>n9~My#|KyW=8pG@K*9(lPemmD9QvwwGmnLQ=;q`qsM`ly=jyubX-Hd zN+os?NZF(d7pKKAaeD(-{-sqhba8IlA$DH6eLIlC= zMFTDeejWv!%vwzf6;1tWz6*H@$azIBq`!;zymssJ__?mlJ;k1_j$_*fOg}8!X?a5P z{CY>sktmd4rCqYKS+A%ft>(eBm*f`RCD$ZdNhOoEM6++?c<5=p10PTnIzWjv$&C%y zOA?uMg1xgyRonULb}K3pkL)o@NBxLovMCJ#Iwd%{zO-NNtHMhM*uc zxgO?E%&kS77WOUpp4!O#MQjoO5k)4o@AFY|IHLZvnOeMbE~!XwI-6PM0=M4Q_wXqi zg1VrP*G_ThZP{n`KR6ntKj7BzVwTzZ7j`o?$cTn>k+THHSX7@V&{y#$Eo+^<`Vz$~ z?Pj z*IEBQkS0i8xT^1zo*G@i^jwTtoz^E^El#V?8PB~ zJx>tvp5>4{+1PQeCQ-(P!!c+W&DK#xgWF*^k=?Ucu0Kvst=%>6O`g*Ct+)HhTGmD*^B7D1s6b@VLOPSJlGN+#)UaBSuy-P6KBw{_qhCSz 
zWBd!DqV?jWr_a0$c5J?2TIF+^&(eQl?Be#%AXD|0+mp1VI7y<&3KH?^te_2mOhH`Q zZSezB!nn4P?QN^>(0&EKB-vAgggHT^UNADag(2-tVkWaVDK{zQ4QDb{(c(RPK&YTD zn;pyK$c-!|iQYUUU($)qVsBZ}`w%UR(-IrJI<$pKnfk@l3liqy|1*1lpERv&-X~y# z={Xj6%sPT^K967)JkF?+llM0B-&ml*mm6V9=IN_MQXXCShjD|Em^ zb<5_*af|K@`GdWIjk8J_{Je)i47e|l5#PPVR0Et5|GL3I0HgbF*P}#?2$mQ*upZL` zC{c_jk7wBRXD%H02CqK_agRqTA{~zf=dV;w&qN&QN(&q|6sI2NRAQ$x&@G5ThZR^& zb_e^sN#j0wsxpH}^*Cys?1~fK#!nYyKv^`x{K^Nr)%)e~^z^jc4Q2X}@4sdWN+KJa z`%D^sGF`Ko!|MA}8w8cVMRIS`{D9!&-%dtBMZ_L~9HRv|BMhnW;K}!Atv`reT-(^0 zOKuO*(Db8PMMp1q%yu3{5r?tblyWn6P%vDvXC1~!eD6u_*$_VG=XM`xA?4hY?U)#` z@7u~X;G##@+@X7rjIJn^u(OidP_!0=$jE0Xrug+R59$?99P;#~U3(Ljs*Dg*bJakq zQu#2EYqf%e%%(H=3yu5edFhIA2HkEK6K)-iBw@N(x#-rciiZMSisrUF`5MNRPBm0- zODtEzutl(Q#>3%iv3-66+cR4JHrcsF;p6!e%^-3F_c?GQ+*cPt{d`YUD0-Fr^ci8! zAw2s2C=w$l!3ae1b@8f5?9)li1%tHC&(cfrYdZ?yS*s0Mjv=eU)%~Vqp^E|=8w-N{ zFR~`SgC@IAJi5N$!C4|B{x|1S20T?XAl^TRo>{YGL(Y@r~+cm#m!tyoUI&P9h^bXze3K{+>Bjs z0`p;j*w9!NjqHH%_WyHyU=Yz?aqY|yFc%oa%nn%LAv|2bKL{5G;HtWTqhRFfY;_a- z&H@I*c2MB&e@d{pcK}0WAST}h#0IMc0_G*`fyi}O#QuL?3jF;EjaA&r+}sR^w}%DT z>w;hl(#6rp1c?0y_F+d@Aik}cIS|i(Q*^enuyh5n^MF_#?9D(NfV{f^m~a4xs@Y8q zCr}avaRW0|1MvX1YaS_tRYq)s`@Q;@SH}wA7I`SW;{(~bol=}_& zZH+7dnn8b2`Tvy>AXX6<6W9|49sn{c3=HfCGZbcN6Ekv@{zVNyO0@ni0uLa7rx{iz zBK8)xW&l@M)m+W&G+_k=BafSDfgtPv$^Y|M<7SeuIq8Bp0n_aN4+-u6>lkc+O3V&$ znF9&|n9B{~n3qP*=`!Jaf3nZ5DpM5XXE66b@;y$+CS@H zHfUIa0r@*|0Am2az#uNb-o?cU1#z*l1C;mc5+~3X$N>_u0}y};PHwK74zNz_KxYmBGj^c! 
zO&xIJ@&JH2*#WG9G6#U!&0xSFupB_(7vEqc1dzBPEdZ1gK<B*-{OOI!x4a@u;3mX zm^9o70!$eGD?z~A9Kcur;J+l}@BDAc_&fh4TE8R%CKdlH0kCg=L!f9d3HdwS2nSHd z1?)#KLAWge3mq7dy3|RS}35ML?^)LbWJHdnl5Fj*I{-2wh7pMZ+z}R$p~QCuWu$GfQC1f_{Q7PQ0>t8` ztrhGXxGC8H67j2K2Lj~!=9@jhs$V4s5c>@rw40LwkmZ|tN6?L-1;FD!CkguvwEume z`~}t-cC-K;(SY6m=K2P7yg6Y00{rV>0_J}+xGU&Jr3Biz!H%MvZ%;wMX$&jlG(5s8{|d z`eL<5(g5|9|9;?f`Z-C}iqM>7UXqZyE(jl@LglK$MRF$1)rvOhmlyE(&6(CQ?o-B5 zU0(O9j@9Vi zHMDO6r3!+MWNJv#-qV{8+7=DiF=#^3xZXsDU#Ls!U>|XgL7ykA20?X{+)HV4E~jWy zV(;NQb+hpC5gFs7KX_65C`qvD(r0J%{FPN%eyo!S1$$rqCyUmB4Ug+PNlJP&|5lhd z|0!aB(eF>JA#By`8i41a_6}Z1@;chOWb%#YK`=CF1QCDae|a4qlicZu#MHb6G>p@K zAM1aC2GIE}=h!$n9{&PeP$2ogK$}@wVVQqKoi86eS`yU7P2-7a?m-)R9J0jQ@_xX-PbrPeC>VyY~J3a`498lP%J(w_pep z{{JG=FRk=9kxQ`9fxoWNhIzv0RR432e%F9X^vEP(j#`c4OMZQRvTV~(O0=6ERKafm zYGXZ)*)Y~e(qn%nQHoirCA|A{%uogeO7*Lx4cu`7`r0u=O zYy)boMf+Ol(|gyiHN zG_@5)WZKp#;U($8H#bQ0u&2>s7Wq3Swn}R?Pw-E^I|sZ0)>FG!UfGW1(`eT(7xPo$AK zstSf9g11r^$}(~!%;`1+INjWPvJ@)@dO}gt2NbNn$)vDLO@8dPn4CA1w~jO29N<}r zt;UmV();}3T~SZ~C<37!^YORTXC~w6Nrj6-gBK6p--6*S24XgtuK(LOy7<3BG6aKJ za5{g?W*+C!HNRYXHcv~v75;FRNWwa~SzcR>ES5bX=-JkRV^zixpY6E0`La$>bq(d!z2nFIf*(BNx9M8wk7hA{eei9yGSF4r8zraq_|(kKRa>uNav&Lqg!Zyo9P<-2j&}vXc0WX z=W=(yokHmj!Q{0oufWXp6Gqj-Hs+upI$V@gP+89XuJGmMHLc=R#dn0mh$g#ePYXR0 zqiDnG%+QZwWhjKuH0Esv`NS+Y);;Q?-^}g2=Rx$$AH1eh-^Yz!e459YfqiX4&vFaW zx7eeBN8G;`7>PIqXGC(8ptH;z)`K@|j|Il_RJG;zIxt6DsTg&F*YCcE$Co<56Z;yL zg#QTY!@3FeX@K$_@Xn?>DmB{c^_car8R;(bF^upTh&`=W70Sqdvroda#G=aKYG#_{ zm7NBJV1$l>+!O(@}6A?G}y3+M41xio$j&W|sd zl!B$Y%a&yb{8K&~xCe|p=rVtdq@qB)k=pMWta!70Fx3!S0H;RUxKxy_i*g|(A^ zuAx*Nl}vvB5%1eO283a?`cGQ{cy=W}TrNt{?Jnl<72UnJWxn8tZX}m>nynHh^l2dp zs5Hqzd}Ft5vuU$i>9PWC1+7EHQ|K&~!@9cYCa}IX?OLp8kX!067TTE|RxabwBi6IJTn zX!F~xl0%6n9h7it%3hm!KzY?yh8Yba*bJX(BB6 z3y(pfGFHZ;uvbzG!iBj2)g})&r^| zN=Guu5@z%LPvv*h=#}OD>e;9!aQ&q)Ac+;L)Mc;B{W~H=5A9cW0=q#aVyd7qFYHj) zuLYcO_`Ia)Y(hp{CefV|5Tx`i*8%}#8lH=8v>=xUW?wa>X~W)7L~pZ2(`PbxXL%4Y z?GAG4(#8v{fj+CMh7ILsB@PSO4JV@UnolAy;WaIr;wx~Ew8*GN5+_2p$qL*ld`tNU 
zb9pyS6~;;_Q)k~R&8MC}bZmGsXR4Sq;7a!PJiT@-kcY7OY&DD%Ss`j}^XnPD6$7hk z3#m?qpBrA%;{A>Lcqtaoh(DQLe@FiLLWklO3f`hBvH_R<-%}9qxbzAC^0>S^cs1^+ zs3La`ODA(mL@L|!_N%ahEiHM>W1Y#uj$m{i8 zS*o~6rH;^{84Th@qi-g&&qb4ikz*APveCPJd48_4KJWTiiP`vB04h>RwN5YEZ?8g} zh9`n+v=K*Gl&74f-Tg^rZ|4k(XzdW0k8bkt8)_}0sGW2#{I>8N@vn*MWwFr~RQMf) zix$QY4L2sM;mV2C#go0i4ulqu6-Bm+_X;X0{GdygXV&buS0jFMnHV|c><9~5uQP>#niJ7raBd0yKMx16-8am>~K zQPQ-deC>|k;EB>t*Lj^$w^nDA9N{ktA_py4<&sssy->l7D0;6Prb?2uR!O!lBOKi6 zv!21bBv;O@IP8Ts6t!9uuj-7&qrz=n1lka$<@ft|rhD+aZTB;#v#Z(I7{~+WnnSF2 zgxzb}j~(l?gxEH7>~7)dEov)Z8Tl*|4D&TNae zSNzcI^~((kHI;MKA{Jv&{m=*F*|-M+**19ERc(hnOu4Bp?Z}UU8xh#D;Y@sGpiM>V zN#yv77;X=-bn?0zZH#+XbSC+u7VIx1&CO_nFk?(b7w$g~1?|9B-)BnXOw)g)bG@^QNHFS1PT7f-!sh55W1#eNJ0aM&>c`6ar zg$s+FxdvP*XlQdvGzq|=9)$z-G2Fn~3qnM%_mq5um_%bhsD&(RvY;Fz-1ga-d1wI< z+mZI!+=P01^82S&aWWe83wd~+*H{)bu|=MfCK+ewCKr^qgnXY=0C*^Oyg& zL(S1&`g2O`aqSzJrmy=-{=?`ps-dmNh+2W^aAgxUCDq47C0`1$(t=BNXDW(~LS#yW0Ro^5FW(jpp_9pX3b5ur$2YCi@vNH*rtG%bNW)Hmdq?`8> zN68#0b2YZUKfJ?+1$OQXrPd{FpIo|=?5n>i z5fy&LD6FC_d^z~Ry=w_FB<)DFjuhSI$lsUhD97CoFC5;sa^LHKdS;#Vad~Lk0jgWG zDQjzY=G^**XadU9vO8D?rafgPxkbslTl@j1781^qky8=eyyWpZ?|1@M-n}N%S=HmJ zTCt<7YI^)>@6qRUR=v}mdvafcdT)`TTl8uOU{(FS5XBBSAp$W@`)}Z>!Ozi7N(YXy zcYaoce0wkAbzc<2SgE^G@3E7SYfnrAGaKeJDdN5UA5fFl4_YAF&AS+JVI`A!KivC> z-+Wv9U|gP?qsPiif%i6kwM3ZGhqrnlp}wB&%GKb}SMq)QYX*i5dm}iOH+#b0RUaKr z#muT~r)9@5Fxrq=&pi&($CaHIWGRFyF1Fp(`P`ONA~3fvMCQ(W^zX?6X5oK{BH#=7 z?OfTQ4ZtQ?$qt7XLPYa_#mClCrw4>uJbwwY!|4MNiC3kD5sQGuuRogz*Zn@mDvo7}`OlWlcUIumxDe>YhsNJG<{)s*GR#GM7 zXtLasXF`4pKbh=^Mv>W-gmkCL0+aR@u(v61e-w5UC+Ca+0{s7(0hG@#*Z31;&0ry; zm}91#?kPoAL9&zucuX4vo?zdhqfyiPt*1Pid=C%_B3-(m+-0errLyA@0gga zHNNLV>U%Jz&r@vblpE=$?#AFKKb~}=viRLyJJDAP%M5L%)a9Y=@~~aceOk#^WqVj> zVIU=1HL1Q`(Ch&+ZIo|P38Q0%?dhZ(7z^P_6#wN~gzyfAfB&f3w>7-4tE=uC{)P^o zxYDyYgK^~dNdw$7^sZN?zqegGh478n5tl#QUwQg++I`@OaXg_5T*Lx>(S{9+&PQV2 zr3AKtwStR~u_nWmmO%Oa??Th#KYrdNqc{I?3n_2W;vk%VAYBr%d_a^x5WT3FO?ggv zS8&`MO3!)bU@~K?+WB4!6xf0vFf|ttwi)oEo_Y~+zYM{`2UWNG&4U?jb 
zZmTVP=j6+2ZXWuXH?UCp(U0+~hLM%F+z4tzMvJ*H>Vg^%ER5?B8t>^XbG=2oVT1m` z%H~jMvFih4eC0sA2L}03kFJ8G;~Jj5;+rd$e&6~?!vLDE$Kp1`plx)NFXk3r;FD=y zaEWyiyZVxBmb_VQI@$X^QF554ENN95)0rbWXZxKrJQwuCGn50y=PI4zuk8=du$#5< zgh>c42{zM8lC_NCF(ZP@W(ij`7wNScOU~V?nvhYQA=;a2SQ6p%4V^y3OtW!1={wHa zXxH|>@FYgT++S-V#Fsrg@5)pCG<|M6E~xu;ds~9jJ%oE};(5gLSEE#LB5K?vf{jKO z)-CHBsdTj5Jku39kakSPbvf0g$4vv{uSJ&5+4wDd}A^P(`?=9 z`U@`X8sBTUC&b>Sx3J+pskPOXczu$4u{=Y**qd@vGf1!|R8pT!pdI>#trMIren_qnV^6W<@k3{0 zc=Lety&PwnBLjZ(>xCY+ADDKocQs#-Rq28z@(e%LgPaGj?_+eO;zv+m$SNRRmpjK~ z8ALn~>^T%8h>OQs=eX9oIRP>Z3Bvv+7;F-R}Qx{LcTKCP8fU#p4*57&xAL}9#B zhxvI+gTc=SwUo^m!lz-qPHC#q*&Q~H`Qxl#Xw_zJA<->*+n;Rv5|KBLBE-E53=#Zj zP6bsGyO%mJaj5MZX%_|bNKsSi?b&+7kTNltLiC;dza@jw^`w-eGrr~;eQ4iCh&m}y zJ>JJo%=}uq60!paFG^y53u?D0Yf!+=_uC^$L3zjy;NC?G@PxW^^I;1aas+?%)ibeS z)c6!jxs~jrMBU_JvolH_hoR;#VlNbvtd8s_Mk}g*=3>CRUn{HQlu=j7zSVnYU&SuN zI@`k?m10vDy;1BPq@C+=usZXe-hWE~+GSqT#se-97{iKHZ0xVte$SUgbvLXz_Ed9jOV*%P&)H#WsW6gx8cAFXncqe~LH&{3@5b!*Ih2{r(?xt10lt8f=`}0}Z=4(;}+}9q1)ww2! zSP*}ifrnu!wT~{@qlB3P(?mM_z|rSLxPAgQ35%D?(vaJML&QIS_x;(P z{l2>|oM9jM#&7(;BjEw(!M_wI$e;N4IATUHVe#DiX#wa66XVtVO5v|xxL(^(7b0T3 z1P{^95(UOZipd*(eQENUT3xA3kczUQS%wFvSeXk!$x?eIcclijK9Hyr)&x_ zBX)36+m2nFc$7aRx9}^W;ty7rgrX#3GpxZYPj`oMz9(<8KtfF=a{g61GfNLBU zS$Kb&xGsR00;Nh+ierNc1TQ~IXM`U#USs&Y=ozRD1_q>$yq)gG}hein7{aQ-wpYoH+WnWk&-t`m=7wT51<2A_M@ z|I^-;fK$1BePil`Zlsdvs3?UqALoRk5TQ&VA#;?OBuyw&rjiPED?%w1QG_BHQi?{+ zZc?Z;x|Ig~*M842oZ(*mzpwxE-S2xo&$GSnS??a#UVE?ouDyTjHK-NN=&o6{!Y|~J z#je+_>ln*DdbHcly}jD>iFCyOn9d>fu$Sl8PKjBtS?@&Ys?qPNx@DTKu8WdRPn4Fw z{_cy&w)rDup2R#YXW~+$^WKdJRr=`Mo?E=*ODFDG`_*NK^uLrQKA{Ra*|>Q{KC2%u z>X{lz?Mtd&`t6(8*sha&P-w6}$J4!ls~^65P0h=XvmrPP*laq#PCYcV*8K-Pd!I<2 z-hy#P4HqrT+%7wBTx4LFq&~)HlCgK0md1R`m~D}HtK(jz9!t+{v{W^1{_uF&iRTml zSie-`R%=hJqtFiX$feHUjX|^2>GQqS%HHOqr;?)0DB)h!xS&!gl89`=>6H^xXNfPq z-89SQE8(#F-B|^p#NVgRJC)&ictzRd(3dJ=^Y1>~RUW2jI`Vbm2-ikc`?YzeeixW* z5}Y2FU|S(ke^XIh!ZBk+zW<%XG<>1lw45=HVzM_2*6J&)DVg~0{PMcg^4m=t%{M-G z=$jf{%IXLne8hlpdBseZH?9kBP&XfWamq^gL$Qg$Ak)QIfoWAhoN1oW7 
zg;`xf)X8PzO)rOy^PJmNXpmXfyt-3$QHtJ$@hS3Vhh1YeXBwIqx~3SOQP?sq!?wsF zE!}EXk;Ri@uiEO$dv>=k=_UFnev9iGVfAgv3;Wc!rdv*?kiJZ`s?gk$VD#Bh|DheJ zyHAf;DEo1gpT4r_*0{pFd*_>+q8mTqS{U^oBupYTUh-)upY<{Y&czJdP&EtVz*@Zc z0|1F1qq9NGFXm0Z`!xQ?+}a>D5w#5l&EYHB51dT;qc>%|*QYd}na~^3K(!zXd!)J9%funT8RtJF* zkOPJ$74Tg1cFDC5M~}vz2h+IpIT;h8(^D0~vKz+dZz(;6@NBkEZh_D%1GES~yi;=*VRiDKaawvGS z4xPGk7qW$LPn)zu&E2;o`~h9!)Tv#TIY|Zahlt4^i@jwOR%9)ox+Z1bmYP|F zhfmAoFXh}Rf1DM6esi`UX%wH<@L8Br(5da=cEe~EOoU;4`5m_rAMX3$%{_QC{#GfG zcW<0#pOK2-a zE?MMMw`aEpG(^t3m~(h^eeapVlg}>6Co9;>Ov}`(BG=Y*PTGC#b%6&UaJ;ILL$~gm zeMha{R@6f75uCvP1$ev|IIg9{Tkew=7!i@)AI*~mX3&TL0` zNW^_Zapro}5c4TLl$j*y&iiuu4|OAyQqq{UWx+kIpPLI0_T5d;sGKvaYq{Mr*Bf)? zx>M^^$z5T#J;5LD)NdJ0eUtUAdI)jct#PVD%7W^`7%xpipCf}IjqN%e{!B==vG z9ku9-*0{=KC29H8ORbU;nVVNA97@{LH#q&DQu>$g`9C$-ac8}jXIl%yw^+;*O}QgU88S-m(&TLa%=V7te8 zyRY>mic}Ic%1|-t=-%vWBDzuuCkztlf#bgB5R!wVmnI$|y$VwES2yqCJCJ}Xt<}9RW+vGex#J!o+;<6@KYUOzZ1s>g2^l0yFf0>?d*B89*7Ck*p ztCe57`K)T;%<-@zn{t|J0b6=>0jXRM16;21^MNy$?w)Qqk7xd}%Q94KZ3SeNJ)WB> zrbSOyceM;UD5#@fm+Lg?=)9QOTV_iL?Ak8%hF_idtQHwRVLqUNX(OB{uD5>4ld)}; z_rw%&cPnd+{9Rig?R2zg(|I;$Y*KE-4BbD|9r5?N1g-Vk%#9j;n#o5pZiY46&Ggwh{;~f?fqnAb{A$Q&nMeAW?I`W!3-G26M>TtKKXkJ1s$_0h za(GqD>mMhl8JCX^6LL+6YHpT0nsG7OY46k88%kEIKaZ=`QPQ}S?j)I$eM5fXjxZO- zM0{{*Uc|e@(f>}Z|U_gH$b{Zo+kQ2<;A{36~CQ5khk(KHRTbt zC?%*=qH=wcWy_Dl$11M$ik@4m-p8+5{FX^S@dB%zpeYHAM`K{2=cegm^J60(Z_ZIE zj{n{>{Xy+@C%YX&B~=aWbMtNnHci*u@Y`O+b{jdn({was+j)XV&@UX zXO?rr($)8sJ#1-QMeW#KtQHhrdZvPKJf$kmBX-LLue4IZ3@yP0E!U5jFPNavd@ILt zX2C;|x!X#7$2@X99kJ((Vn-a_Wu~mE?$tbf>G3{IGqP!wU;%u*r~A<*%bdF16C*|l zS~DiMbhfk<=&QlX-B<5mbz6pHh-0Hh30qSVgUi^ zfGa|72NtCkwb%F+_7sh<`O+tNV4>A+UIxL&#^8JDAR-CiQ3Eiu`KIRPde%c=W{7zM z2s1nf!i+cM zJrJ;B2!cikUJG}@o`LRg?*X5Qf&X41aF1XEgdq12cfgggP!0rNFGzrnkGlt~arOo* z1)i`F;39y`0$35S8=zQ#_P`U&0oH^k*x(7aYz!O0g^gh&xFaMAo&X=%846(|cmskB zH-KOx1PleR;RzuKzeWO%6V~AgVVp=d09A3K6|jEo*LZ_+Goji>0Aho82r4Wns9@p1Fg6rAj)|aQuysrZ3A-|3u|-FSDnLWA;virc2}1%> 
z!1jio(csDeW}zz<5Znx*x!6p@n2pLt41_B%7Kh5DK^hKN7^)x%$qyl0(4J*nKvT73&5Bo|Oea5>YZ5suk26bt@8s{_g6(LTzDxgT*B*2P#Yu z&}^10fE*0uk7A)VVZsDCVSHr^2g)6Ouc6?6lWkHxj?;&t^?pTOoQ~#bqk?>ER|*iIkN6d@BsJ3)SM}#X5EXA6)^bcX}D8(6a%UQ4)w|4iw?DARa6C{sC2aO6yPb?o65DyUE*b{(3vn#<; z#DMu`(Lp!EsA2tB1Gnr}hrlhH(V!S%R>NYigIjjNVQ|YX6;t&j*snox2K5R`(=Ry& z2!qVZ!vIDjt15$0QJ(!#bxA`Z@B^Q@BP`H`NCa1BunYmg=Jway6itlq8@kLCW+FD5OaY+FW|Ju_<=En33JW@R z$Cv_gmMw>-$dLd$2Ms)!HCdU#Y|ri#Q<(7CvS14Cv0>I^7e|6G>Y;o1?j8(mm_^xi z3ZC3wB2LY(FK7uHV$R2+4y*a@G0-Fl*eyw^tL&=dH%HdV-j-( zt-e3$n)D~p^RWKf*`6ga^W@#b6FqaK*X0`XIiH1z(l28Kg*H4N*c~S1GELhu%3j^! zH&Nxl((UZtDt;r@!~0r=3Hd+1l0%2?e_!B-gsNT#Ad>=m*mpY zeWe0DlC$_6Ucywt|A~emMfUGCgsu*nO%3}$_zxWJu=NiK*7w6IAfN4;k~Mg+bRLKt zEndChRQ4ls%R%^>;MWpGk;Ihnt)C}071giW)YhoHW{cf{#ufV{ba%#U|M5zy<+D+T zCGG3&7PD)a>cWq=y`5l1)xK2juXZ{>R{6|J!*fsk%yQ?Q?VkDmXpDKR>=akI(2t`& zChsAcRJhKWZju|9DO9%0p?$WPf6?N6-boYF<^RtP#2CZ{Y&rAqvE%PnCjXh~8Cz%l zx=J#Cep0zXGiydGs7ud;cHsA2%!zmM@P}hTf&#_m{4_GQ2>fLg{hyrueyP^}U%U|H zojh4z6DJsZxgi%}U`jx*0I)b;;I{_sx7a&V|JQZ^=s!61E`ZNuBX?*s!K&=yD(!_Gkk zibvUd`T6RhPnFNqn?oj&DMSjHPNc(Ir-q#hav8?1S7@NS2M!sO7&vex#r^~P0S#;t z9{2$oT3(_*yf2#u_8jD|hJ!|CP|@2CHx1boxM;{S!b3xzdAQ>t_ca_e_<@WYTr}i9 zh>J#}B1bt~@yG!A=c3W5$VVDSJTilU{Pb|rm_+pA#zjMxA8s0zr!HhBou^zfQv+F2 zxbkCym5P%_A!2p_j(E_Q$RNN)qan)x7fk~hVK`}EOW^J!3fSPe`-4KFAX^MqTIB7D zlSTp-@X}~JeGWw+D+fn@Py`RJh%_osS|W|XBY#9X0`qg`N2JryN{yooBAtnx{c+O3 zWff0a@N>q^KcWWmKgN|747%KXL}Z{dbKLQ$eB!|X;f}||EIOR^V^WYEii-w};h`~* zL69RJ3~*!u9(UWo6CuyI1fOciYROp!i3zJiP8ylcKz1#T zcw|V*8;^+m{Bg$9V4!Ui4jOD@^Y9*=z46u$eAe;s1uWUza|DTsT=#L6K_eq$9cQ^T zK766kkkaRh$3!+dP8v)HJawT{dB!xjsRJL7+-Vue6&PoJAiCT>52F_Uqg~6*os4!!3$D>kt>qliELnvo{G+w<*2Dg?x<^1!q65n*oH0q{Lsn?v!xpOdLTnNq+!{(gMys{gRvD0 QoIj;8aY{=1OAK)T0@=r1H2?qr literal 0 HcmV?d00001 diff --git a/tests/triton_tests/plot1.png b/tests/triton_tests/plot1.png new file mode 100644 index 0000000000000000000000000000000000000000..794c86900835cfd60103bb4999b23ac33bf3d7b7 GIT binary patch literal 121873 
zcmY(rcU+Hc_�snDjVq#^C1rD>$4ttF+UB~hVGXb)*egOm!DCM6}b(=ytFw)U3R z@3^1m`F%d$&mYgrBfanUeP7pko#$~L$8p|)n(Ak%b}{WDAt9ksRywUsLbA<_goIRg z=XU($KI@@i{2}Rb#=u3#@v_S;GbbyOb7n5rt~k0}u{CGAY31Z>>v&aEKwLnWkIlx# z<(jjUprFJ5{DOd^leJ*iuATn42>CT7LuV3_J-GTFGtvyXbXyWq5)$Rp3c9Wd6N$GB z==8h%r>(|R826?b+N2odiu2`?sisuLw5!%Js~)Rr56!Ji`kNWZrRd8Mv+n{QsmE(( z`bxipr_)reM^-JmAEqhO@{N>bzi*G1bZ!*&O^ol8KEtUTabKQo@6K&oeE$2Rzc@wz z$M*l{2l#pE?w!<%|L3=!(CuXp`9B|2;9)zx@BhEZKmT%9TRuSOf3Ew_kJ9e!HT&PM zB7USodg$Z-{j#VnC+}1LpC26|Td40OA+9Zb`xq0|@>I{Kh6c)LmGeJ_@p}~%{v662 z92}aN+MEjupFLEr_uzT%uJQc7aq6G%pJBQ--7AxjkWkl2iHo@YEcRLz3kwUga!`yv zoXw0o)D+s~(ev*Em2QuVk0+fw^Rc_TRCD}!dV0Rc^3`6~0V+yL$|QfcyKQ@~n~*2% zsFck4H8619$w}K=(Dc^l8>ePc4mI*t%s=7FX8z~)@3xU-@fH^syFI-6{M@;7!s6nL z;xcz>P9^By9Z^tD&}UyzR5TXfCdl-wlty|_-LH~ETTIA^Te)=i-aV~M-Op83F@-$E zd6&O1@$vK5cZNvghgUlC(!N!DwJg`a414*KS2b3E-(K`i?5eR%PswqMst_J~(QUq` zaoHRxs}<(t#j;;jq0AY&M;j9)b-0S?=w{`(|J@_Y!{v_ypPze7Lrv}D>w9L5iBmI9 z*kb#Q-`{u&XMemh{PgI!y^4dhLhWoyfZPKXnG}oaFx?lr2511x(9~e_W2@`FFDwe1wl3OEs?yvYN_Hj)~EGJ|Gvou&|IP_OzYPFz^2T z``fIf4juC0RE>#xdgg8{y^M@Z$&K-*Ox&H%c=NreS zu!AQg+~zpEy}kdet~&o+7!xTl)}#p-SFCWQ&t~==IwQwLR@n682ZN5z*O4n<^0i)m z*L`I`?X}K$Vjv(eko@M&n?!(U+fZ|fibh67(aln`&Qd39Z(3fyEc>8OLE%u%lwS4? 
zp7qt~^4Yn$6wj48mx=a=0RcM`7xl7?d}Cu7tE#G8-QDMUZuO1vB}-gqpPHJAd-O=V z(C++m^~4v>r(P^>7!=ykpMRlS(cb=6e#zwB&CshoC1#F}8Y^={vx|#Xc&NHL!~Dxb zf2%Uo6Q$G=r3&Sj?0ar+-MZEG`wMz}!K({z1T-GAv$KzW%cSMv;?m7Bik92>^X$W_ z`Ikq>8xkZV^R1gwv@l?GD*G1Vh$ghn)c}E z=-9Poe`$Gk_Qqcc&p$$07v5BT`$osh%PZwFDzB!dmeduozE3u|z1_f%iut*#SAp$U zKRm_J`ogY~o0CSy7cb_%e*HuvSyu9N?#BA6PPVb+fddB&@+Q%Q$SQW5=R&}q zZ;XCLH*9=FKu=GvbaP|%m`yVYBO_yHkwg6sR?iO=_jW{!;Rp?V)sU-fYh%Py4mZT# zd;FOB_WGiJ()!iDvV#;9!JeCa+pHeEy-K%Ua!8p*^m4<8j*k5}2S?Ap_*~6;`xRdk z|L5@V$FsWBVCNuJW}x3NbOU)aTECFMNOR zajVyBxu<)#)^5F51`25^(ako;q~_n@9E2Y_b;p`jKU}7y?9VTXUXN*+hA&_C3!86q z=q&j8?15~n?Q(y`R!Ssc$kMrd2Gb6P12!!b!qaM!oBT?{MG_nosSSZzc`T4oJy1vFNZ`)bG)6&u+wWl3VHiVB4 z`LlX|`0^!m>yF(*HqC5EMk*tXi9}U&e80%R>iy@?Va1UC&w;d_bG6(}(q$byqq&cy z-9#uD1wP|rwUOLk-#D6LS%S*SWUU($o;q|DQSKMHZ&e@jG$%(xQ!|`>%SaLb*6oxZ ze*B0*GNi9*GZvW=S^rlT$5r3?pX;o(@Q#k9sqAMLAk&KXepWd48 zZFnsHK(Ks`=Fatjs_TDd%n+AB@88ROeSU7>Tc%EsW})p@E+U5nP#@kW>-o0vkFN|P zNnNa9^0XM$KA}5!gs1Xdw^8-6E>Y7Ptv1I15RR-KJ~7^wWBn~d>y>GFl6aqIzWl&0 zlLM&U!-j34$YObxwKQ9{Y?)bF;t~#^3fpmoGFKt~83(ag?H1$jAM z%z>%z_U8M%M!C(kd(WO7-m&{Y|IE)%ln14vkYk(ktW-~)Jh_vS(sHooVS!_x3D1`c zQ5L_3hpUKfLE#$v{*sYIOH1oYUs+<$yIYJltL}@xD>pZm|B8@(ioq?>pMAR ze}i($Hy$BHI?oTEdzF<%zGn}S!b4bO9+`deyZ(2Z*39l0Q?L6t|HR}ZIXU?u z0fFofl6jXK;wrF2&kl_DzW1=OwEX;caa>tVE%odJ7SHv?Hnju^>LY({T%8-NC2j&| zQ!$jO5_yQo=p^f_%ju|Oo}25-8Q<*yKO(JLyS_XY7j<8-i4nM{^6utjdyxb4y&Ze1 zT3V>(;gtu`Wbn-a201c!3#@9!@;YW0Yu~yIOXb$Bqsanqz#~SRvWjt*B(^2bSiYj^`{R_g0 ze{5__+04wWcu6aKAS$uYzD*^R>39Uv3Om()k<)hgM8VC;PUEht-EXx-#KmKQLY?_k zm6ac=ojaF?xEpO?m3DpoE#(Xa4?jN;nQesbGexSzAzE(rZ<1|JT~mNiN&*4`buJ1I zzM}#}4KWHBhm(x`stm?zX!n(QrCt8=Xs9I7&7mGxAg#=6b+{yqV{Aj?ptP$d@r~F3 zmrmW6dguO@Z?4K&T>!ehu{3g@oHi8^pp1*m&dw6+vjAYkTMwMUiGpvtFxpZa_A>3r z$X%>?G_U^LdUyG4xxPg-AXv}C@o{lVcLt}Y1v$C7M>=k9tgS>uM@OrdXBibnC_Z4} z_nQAA9^vcjOLySFE53KJ*EC9IX70L(&AnRaVeM14KHQ|j$YK&c#Ywm3t`*l;+eUi5 z$5qAfqBsM;5l8X6cdFR+Vu2m02*L4&{MWC;^>R$IMAd)vl~OT_Dc!;Gl(p&wsNvkR zXOERT%0WbMaBx(7d}hP^%uKjZfz4AjYfDQd)QU*4EA8Rw>tr%qfVyQ!M%;;$*+zvg 
zQd?|XI;~`K^1fzW5Vx%fKUBA{x!$*_9DVeBYlq*2rCc@^T^p^-jZO}3ZdE`M?Ms)Q zh%{AKpM4~3;n(PH*mPU|k%(20NMc;v5zF5^zxpNj{on40*q<^x@&NDhH)p!Z|W8wH0^&?J4tm zx9v#z_~px&TLJK$dSWIvol%;i{r0eCe|?T4r;R!gJ!U#3YHZ*XEi7}iwapKCAD?$8g=)M%RgoK4j2gNynm&v?ea=Mlg5I4){xK(FG zg-=Gt5fT0=(UiOL1oNDS8}I#1VV=^EOMt*mzOvxsl}C z!!^t^C>?w0=_`>2=>_ixeBRmHJ;!AmuQ4y~`9^iBP1B*KXsYYlk!j5amk$P z>IBAFTVJP0oIgD7n1t#}QfT+RH9mG-{6kk4GmurHnBk|_-|UV9ph}I?FUC=_${sK^ zHJ$2xuZn|$u7wKUb2O^Wd2?f(*t$Ni<%8I2bSz=AUVo2tX@(E%l6*j5OEeh*a_bAU zNb|!uQVDXmWzV{5-OQVscN(fQyLvTy;^phtLPN`{sz3VXyWe6*&`$!YDaY$Kk6SJc zy^2JC&+pj zJ6d){tvu|zUKa`^G^X}AK7J@J_C^P4<<8x^FC(3xuFcN4p?F!hcdJ3u%%^VKsQ zPVSnmOPYR?LAPXlCxZO!3xiyCdS1QHU%%3#D?tvdz}7@RzIEr$9TN0qRmgwpsN+rx zfU_6ToK|Yct)Iq@qtpDHS~a>;$WXan?J(Ez;b>~JlxNRuzhzLf-d+<%v|xXeh1@OR$H#Al=cwAn-b75AiS2PvH7!ehM(!KMknj$Z1D!_i#G0mP>a3UIhbaQXL*QF3aO@&{O zx368lZicL1SNd4oF{b8^{~p%M$Z~aO<5@y7!Za1UL%i+SzWuV15SqD@RpBzupdIV+d?T2xeYT2YZ*Ib!Za`1gAm zsbSF%9_&GnXMw<%jQO6Sb-2xehY}HfSVV-z%WM6BvM+@Sa;&hBP(S)G6@hzi-%5b| zQLB&PkG@%)$a{u!ysRPH`M7oD?fF^IV?s@v5mF|Jfy$(}i|2 zN~-s0tLv9a(1hGAFPFpid;a<94|>LFs#_8m&_63HOa0^kuYUF@hXMb&AP+Sn##_@= zi#?VFkVFuFnVJL4%*^vYE;#01?U5qdAAkQHcu<`f#zAj3L4d|H2!A95PGCpDrRgw$ zzKEBK^*H$R7pv>*gS*~1(o$1D-PzWW6UnKdY&=r%_T9UIsi{U}uCcB+g20lE=kJ3| z`h`A1Y9(SftE?JI&#vQ1Sy}vO)-18#1vV{r$!WPtnb<>~4BBEzR+p#K&%e|wE~Dq= zLD_uTtvS40q#D5?yS|`~Ej%X1@7=o-Zyfu$`1oF(HcqX71XMUv zzPbJYi76r>q2Yti*64;MLFP`1od8zx%oci-K$rxc9+};MxOBW^6Dt_I)NA#csPm9-2&$cblfxT z%s8nVcG-?81i@3v**KLr%eP8VPoxwqp-pe8BxsxgEXx|cBj|loCu4SY_I-bU-a|^L zr3H51gJSqE%q?{}=DW_CqkcA~sm3Djq|tc4cIXnzHZ502bmivdnOBD$v^mBkEiEl1 zDvFEEh8#S>fL0pGHZVAN0MCnyenp`pnis@aoRHZTB!Gdz+DJ+7try;0{el+7;PoX^ z8IL8)w(N@&-Nia1Kg&77zNacjW}JV?!0I_e`b;tOV{>!(lP7Fw;LiB&q6e`02$eLwbA_eE5DM~AT2Ul-7xF8A%g zne9FxA)hSkB_tw}a?*DfS?7*otpxi?bog&?Z_0vQh!QkSLQ5`W_Hi4~-1BZzqO0}U zx+CvOhX~VgyVI7I9BguTF7Nqy?ba>sBS!)d4bO6NngTKp=j~ZBdP?QMJML7TQ+MOj z$L`QzM)99hQxA60lloZQ>M=m=j*pK&X7_C;ushLQf_R}53_&+N^F_R`IaQffL{v0h z;<|ZKtwyqVr3@F^V`6z<85c*GdatA@iGi?5M*A3v8ywurN8v6P&PXgQxTqB4VyDWv 
zVUm54P8`9Ef|q}N4kUVJQR~JdZKb6qVnz0CyKmjPwSD)28hul*GoGHMVg2W|2XA~4 zXt*IDDH)IGZ_c;Yc%hs55gAPaOp9ar`pxV~ZnZdh0AM2LW61~5`I;lzBwy$cRikUp z0L#?d+lxTW16#pZw<>tTS3qlBA9%gV~jqkwF>L^11j*A-k%%$l(Gyzo>{fT z=ZETQf9#(xx}qbF9stt4-m4B$Hzxv;lUb4Yue84=4S#6^ZnG5x5Sr17^ki=%JF6ak zetv=;0~1(*JO}#42UqByoTU1|XzEGehWHb=1W-i?h8&xAw3s|0F|p(A^}`p6t~^CQ zyEgM_5ANWsHKm}SVA<-_iL&*j3k1QFYfEl`hMX9pe zue-16OUs#`{;a?H`=5Bul?cwv z%=9PU-T;Y0#TQrdq(E%Q_t3v80Ref0I_#AGeto z#^UZv#v!@6^l}^RQB=%pwxbtrwo$37s-i1VhCnhiJ4=WeYm05BK{G`*E$mQSKq~vP zn@64h8}N&T7SNK|7SLT9L_Zm0t139%S1#8l+gV_Hudwhq^7g`bn})Ta!t{>06vw`@ z(eZH*c70o7Pc#LCwkEoW*C5b5Jw5-Rrqe5M=8Omc&x<+ssyR7{V0kUuUz@2EmPR~( z+a_6P^yS&v2d>lamyt%&!OgiqB+}ob6JGyN@Iq0DF9qYT$;mH7ubYeXcT(mtGZT{t z`WIZ?rX!CN6&rwYOO9!0MyfR#=~f;d9wLt~cAGMwum3VPR8JZuH6ebrn+@^~2#||` zJ>R&5JbHVN6l+B!33Wc6?4Glv`$l_tX`(~La|x+)m%}G$RVd!9*IZn5X()$dc22$N z^L}Pp?hR?i+`{4>Y8T)&p4u6Gh-Wnx^e5ol+H5te+tC+-%zyv>#V^i%+9N0AG*}~2 z$jQ%71qgK}&=kkt=M#0omF~9>0Wv|Y=%i`~t<2_autXmRt6^YZ0A@*l^tBICG}|sA5-6WL*P(3?8f0&C)ky}5A?G~C~HKHjhu%SNXznkCQ zStXj2J9g&lO6i%+58_)h&E~JXS^Bl69zzI|-S0gZ?;y|5VWW-|^A)6hfB*h{U})$N z0GKtlTxz+j_~D~RFKAg9SXm!LM$+zKmE~)_VF7v`OvXmqyQ{B^GP#bbcxyN~@C*0% zP&%mwawyXf3_iF>G68zNn&jOFr8Eg*$XI~QbdUS;)R&d{5y|VnPRA$cV>3Q{{Thi~ zXON)wFpI;{#jzL!*yZef!*)Ap-h-Z%RReWymsih`!fUDML(R3FOE`@78_K+~v_)@! 
zXCf);{2d9W!GH*dzVF}92kL|DDWB`#Pe*qe)p8VA2fYs`L<7(P$OHuV(Rr?>^^vdO zWZk>zGF7l92D#>B2o^vU( z_h;7Pii3kK06xmQjyXwTN>Fex1x}PXplSEp>-Ru6{QULnBht|d;vQtiAf&nPnXkhc zJ10p30s=Ibd^wcsydOcTLoOhyX+wkB-{tA12u@YH?>E6O5VR^mvVFVJR5wulzzy9a zaTzE`L-m5?y1hbN92|r812_~Cx26st&~R9kXj3@ZsxW|tP%tuhy64%(U z*qYxiWKoDp%WMFITUUkbuRACArci|~(M=E?2@dJ(?ak%e zD$iE|YJ}{6@IjIZ2o0@lZ;u0CF^{auyUIT+WB|6tCtv;}ZcFj}Pnv;#+&+gQ# zkXI@#DXnGzbJE^xaRlbUlcR6$$hY2OS19Gbtck)$$Q&s6_4@ZwY_4WIzCfccjW0;{3N5%?bAb~La0fjk$QD?)X-z3J_~G!buc zJKHed54Ut3*!GD_(@R7v4Gj%}jv(bi&mJB8$#e9?%?XGpR1mWSOiF9eenT$^mWZtj zPf5@}^QA)fD7)I#w5^5GO|3qi66HPbZVil&2j`$cSXw$}|3l4I2kkS#kpVIai;C*7 zFnsy)Mfu#hp(eRaLZ}54Vv+T%K`*-UD_BqotkTPzvex~AY-}XpR*&jt$OGpAEq|)4 zJae5MJGKuZz?Q*Xf*m~>g@k;7qw#BS5Ix&=6o$3AdO_n-S1lga=R@H2Asy|6?pYbS z^4PIVtIEa2We}-!03AzHw<+>PKRP2K8=X4!$@qLP_DZ|V(+!FxgoF_!qiV!VWiaDb zx|(Qrl&sTd&u+c8x#0!k4K!Ce$}68qsf208D2U{sEJMega2!7EH!D>UAQB|5-^DQ)Usy(gF~59S5v11)5;-R)r~f%4 z($^@b5cQZKR+l0tBbpI;rSpv^f_7Ti*#*6MBaW27x?>Lu9eYT>Z&hPs*e?2`ugxoW zy+Whzf4t&lJt%`l39cZnRsa^2)zusGx5O1)dCQQ+-M|bU ztJC#kf<*S&Q8Fx^l7{}%^lY_}M#*UhDk>^MT*2|*Jc|vh|In&Wf&E9{7AI;Qrf%rS zk2T$1H(c7f>lC`UZQHhuMlA{K-dvmC1E|g=CB=jVBaS`9L5hps#+sf%L9g4wULw^ca;;yUJRCA+ zTE4fD(Ic0ice5R!7lxD6wzO!)x4k&eOCAmZHuc1&@zPVAJm~9VQlXo^i;I4zr|4^W2kVA71$1^WM-Ze3H%M*NAR?uJ z;!BFZrmLax!1VSy;eNRX!r`}NViUiq3X)-z`)`7&mv)Q5$F@4{3uSO%p+zJDp-SBN z9i=V@37a6q?)&-<&0H$JHJzm1sz*yps~8ggatqv3NcnG_M|QJ$m;q9vgVsRI?WCZn zzvV&k5ot5i^~J*2SAW1}5G+;LDlm%tnygsxx@Dw5VXHd2-7L~d=pu*(hJZV`%VTJv3|7u#Ur1eOdTd9|Lp+z41l%k9mLMh>2br1QsNfQC zXOOCu)@$^1r*?E)ECQr7Ff^nFKq__~A(71Xz|$O%bh;0%cFgj2(`5I-<91OIrkrmQ zP8!ZGFI%H+`LnX31(gOzN&)RF__hG`6O$K2CxvMV>)jKMaE`$Ev(`?+f7waGh7KAw zJCWMS9DmesSO0|D`JE7nWn?ZEi`RxJKKlu&v%0J6gofZgQVvd^BJ=C(+pkR>;NXxS z^m5fEHz7QJ_ss)^_C&JMMAMzWa@bruNUtoL4T9i3u%%C0YzUhco?r+Vove;V&aqdT z^V{djO5Y%QUZ;`9gU}Nm+bie43fx+lZP}VX&oX{(wD}O|fdCHW_V#uH#nRBy7Q1Sd zwbYo5eT1V+=6&qTnrOX@Ep$H>nX9%qWk-}Y%eHfKas6&!g(n8fd<8HAI@$HPL1j+z z9zfq4b?>cu{8z}gb>`-kklBq`eTkiC2m8nPpUVMN;3JA15L#qROpJYD)JNSx!udfI 
zq>73Pm@)Fue1aeOjEk*+n2Nq>v|P$4Q~Z=LIhqJ~ldt&K zYSC^6WMaz`g{=rh$Y?(y9(gW)*Z=e9PcyWdtG%VQV5kUijOvfu_X7}40%i9aU2C4$ zrtk2aRXE4%H#_X@uU&Ix|9JWzlV%iOrMr91C2i6ZI`6E#Uxmy+$bBPHbW~IB=JT7n zLOiao=8+a8kd8{9y1S#3|1~wWyT+=yz{WY)a#?BxyGdGXJOncYg#2q^#!O_K_Ua+X znOm>H&IH&jmv|HAgSva`2sMeD6Z_B@nk21X&f#Us>|Jcm2rwGtFG ze)|DfAnJ4j+w`q)=dbWXGPF;9jQB2UgIYjB(DR_?wjj{&u`eH_rG3Wnpc80NSU4}- z@TDFau~C&PUrfvbIP zZr%$32s$WE#zPD>^fEv$h{u~#-O8{})^)BiK2n`qRhO0-DyQeIyM36wzKX?1e&)9} zpGMqWKQdq0T|Oi|S_o2EOUrj&4EbLg=Fy+PBSAf!oMHfcu|Qmxu%A(y9>6p7TK4cGLa;TqmV*mn z8#qCNfEaE{CYn!zY=UNK4QPUOKH^140DTSX$WtVi*8!W(kUupn^OQteC?O?f?%v6*n_-H06OGrqMJgf5bza}d17m1(?@i*`00Fy zb@D85<5yfBo#7`moG&@3zjBdo?{OA~Z!JDY=JG$bWjg;Z@;C9RhLU}-9FGwPg5(w8l;GGp7POcD?20$Q0HN84zF|=vv=Rk(VLq7$0W|x%Quc4s< z477SObe^F3(JfN~x&Hd~%N2eK=O|vjqc~%iKinbx{Q2{x%a^OrU&o6(Cgn*>Nj(A> z9veIe3BDfyLuy4QhR^UL>}nbC=MlE8-rl4KvVQ}RxOs-Mor%K>H1{wmsWJQFyQr0m zB`#0lbc5Bw8Tty`qeR)kArFeJ23Jj|&=29+bGhf1grsD2jdyJ@)TTLfo;qnN^x4M6 zci}6;5)I&Xg8t~!9o2iKBrku5#^~GI>lQdtfxz!aql1v6exa2z$|4N95Q6AS%&QO_ znV{ui_c$QuZ9L1ygoYKgue9f%aQKICp&Inve}*4BM%_R-DFY_S%F3pQL=agz=$y7Y z&9i%ow6RD1soywVsXTi_?x@AfZ8^KzvQH{*H!!OEsYJ3{L)t7Qg-!Hu(|OpnpUlv4 zcsMHS9xndZl5#Sv>HU3=GgCU6?v^?1x-oH6)a=hxU6QmK^R5&?qaB$3RR_dZSkPh!586IgKAWGX~5yu>}r8kpYjol8pcEFlm zU!kghFjGh(>eh#+rKfDExLtksLuEr=umChMes}X+x>3}ViDLJ~qljp3=q-{*xyYTY z(ceh9y~yLr*}Ze;r9U%2)sm!Pee2x2Z=Vt}n}qAkrQI4*gsq-_Wp?esL|HNO9jEO> z20t$EzwCFp$v~Cu$iDdfPx#q>M^`EvrbWzu?0$R5pF=XB8p8sbC!-$iJiIrG4u;Pa zJx!V@ek?{);g(1?$KP!^$<}Y%fvgbj4`@z~jf_--H?Ib_s>1iP2>Vr%x$N`gNcUUP%uLPf0}_{mwQpSuk3D<(WKW)uIzACmYFrpfW(A}y z4txU%I;xN-=QupQ@(~<|PQK?dF83zv7hm7qa!g!-t;P&?C|{KVK2!uj1l!e_4xgs* z@Ry3-1|z48P814Uc(yI)0Ha{-Z5e5K!LG=MN{zgGcP8!lcK|@73O0#CqjS(cK}cl( zQg#`O@U)JT3CxEz903X%G`1v*u>3P|Jm(aVvN=XRMov=Te6{*o@=vQP(f#L7waT~C zQY&7d=i%)fEM1C7x5*CFi12Y%eLMD zaBp zuHODTAyO9rVi7&Y4y(Tq+&_H&d>?)@A-B0f^F^6kQ^uDY6XGp4z~c)aKhCe6?6EY# z6LW5h!qtFzmGDHb+ST^t5>^t6XyJ>;FLsn$-+ub{NS#}BRFv-MjctPBu>y6sdwQ&L zIZr^ft^NJ()*<$j81Sfgsc6oAr8jkW_>21I>ddiOhA6?!wQBl*=2J@L7p8g$Sm6Xi 
zP~`xpfIt{vspG|#gc@`f2WD6zvp4dOm6bq~~YYzVnojkE|t}rff0j;vf zY6CQj5QZV_zL)aehq!HKX&Gfv=HsLAkkvIfKU+SW5q+3VF@r~F=GpY4UoC!zK8;p= z$-a0#YF;J5KZkKo%HZC^g-qQqMtT|*I7aAqny9pX{aJA@{cim_cXERIgQFxI%9TCu z(kveTGmO(u#ttqCfr^Zq3C{FxY-w@IyNVnzo56QVmDuLN0FB_0f!nfN zJg3S|ObvDhD)io~`*Ea#H?-&|1{*HL%_sWa6_a$dOr4N33l6c0hsX9Z$Ld@O z42#*c&VSZ==RHsReuY-5*JXTWanIQsPqG(UjIG_Hpne{Zk+imj=-lBKr+=A)14;xq z{1+w}t)Bq1feA?Wvan$GW!(AVwAU1W|{O*eBy&K;wOkFm)_XYy3z0fISHh z-wPd9Gv+x3o+R*LX6@kDt;tTjKN(1pek3(q5z{Gdqrs8y^b~3;8G4-`=^LBMp%k#g zIe5hJu{>K%VwB}u)->Lt(`RaH^9wUXKS+?h(yrj6eYN<+m}?*By`wB5=04ot8wWpl zuNYYz+{zHEVtvA7jOIHM1c0kQ4!35uu_R%WN81l3p&va$Q8u~|@%4%O-OuW@=Hqo9-2vg_;CMG6SYDySmLHOrlB+R;qS<2-IAuEvu6TQ1j z#|;Zcxui{Ml0ua)AN43Qjmw;f<@?^kHF)EupVLrMPj7{SH(N*m$ph;EF?^z5DdK8M z_v)=f!`nH|+);>_XfaRN@;tLYev4w}lUUUNt(yy9D%x2>CO*9V6-nZ7Rto7l0TE&K+}7T2+hl z^H6$QN^aO6yzofOK?RKoYDpSY-QpMxxK3XFD*g)VJvPCqrwDDUGmLS7Rb=u|Hi18N zVs0;WSYsC~Y;6O5d`JP4&Ecv6m4BK3#Pwg>2w+}SwH==N7kXKBaD>ctrxVQyjuW~b zG$W7gANTiPvf8xyWU-x!DRm+e<^u+pH9%Gpm^v^d%9OX@YL5|5rUyLtiC&&?&V2mz zi8M~qnG42n!jf=@uNke_35~~`%0%WR0ST0770keo_JkH+f+jNK;yW=~E;Y@VcO@m5 z-Et`v&AS@Rkd=j>UBcZGjUQqwAB=;*#gfy2a6J0i2VmqKg3#~Q`NE2D2IQEsq89=e zdKLh+k;&I4O+@9_(^_&H1YdSmyC$vZ~r6{3{JjwdTuFVRzQ@n~$@ zCnJ=%gM#(7{f@@19~=&Rd$fn@hy47(<}1A4Ov$HAKMdUZnXD+(M)OzZ?sf4?U(QJ! 
zq_|M>3FJ%rwrqW!+{_ETOR*ej!{pAhH8OX9m_(e)=8z{#zt#77`l4XzBPr=v;p0q_ z?e=J<3AzLw2fVNt71HUnVKt_up&^t$!ej`$2MI>!cHSX)2~FrEOvh}(Ee;?>6Qn(q zF_M|NIgFJiAg7be-P?c}j#N@N9X-48d-rIFTFqr%9%;$2nLv*HNBn@+gR(|U)*wZW zlz5>NByf6swqDkwIWJ^yqPGPhFCroWK5~F?IU)dp44emRsMprk=mfWe+h;#~SQ4oQ zL984$OR%N*6`YtM_+wHKGH#Q9r6rvPccSw~9e^9V2L8X5!AK2*H&?+L{yz%PDbCU= zkZUlE5{4_pOgsir(A5LS61F0b>Gx+5;ByD&amJ-xXGlmOS6FSJ>_!i0gQLUXu70*L zhD*XQm-3>yYx@ig4mwXyl3|AG?afIh!0n2UA5Uh!y|yr_g?WdR1dC!}$b8S$qhFqq zIB^2JTNQkfqrLA>pfj~X-9iBjKvM^LHUpXl5{XXZo3+uNcj6%OgS@6o=XVoC3FJGX z`vWKQYjiXK&>OPg5JVWl2?i66?Apw!#IhJc(=el~=SabEadCrov^D45|I@0TxC&QR zRpC4CVla^KY5j!><`9y+pd8en&YBpVU7$wM4x7YOp;>hdADX*@Z4VN^{3j( zaV(?xREp0cZ&aK(vaf=@U~NZWq2RaVH7X+yYX9F12A%OSf8G*iHQ1~Mp_H7LW`ik? zsAfc-DJA`nGl2+pGE0w^nVFd9<<>~z0DVJkf&A-N=G}Pl2l!B^Ui6M}Fl?(+--?Kc z_-7aiV0cWjhZRV~nuz-%OXwOfK94zP2OE~`GD78jFxIYRXTeH}c^ zVIuD#Z{w0JUn3I}k)|u5HWRaRxY{$&2}DCvP;iV89LdQmdU_6G^c2*N)0F{iAi=UC zkgb4ERetq+B(Mi`%S@zBVu}qsyPmWZVI@MfB$Ds$-Mb}EKEh}f5kYvgLEHi|@wb=y zdkT#bE%^8mourne4Rv)#tJhFYcubTtGc%zY+Cn=wLc*t5snK z>p^BhN&F1w6h>abs%gl}lc3nOD5t(^v4QwV_!J2WbFUfdB;izt`yC2Tir&ll0_Ar( zICXU}ERqI96}$*m{TYq9SPh3fw0b!xVyU-59<+;M+%G7sfT3y?jTHg*4r|) z?h#-EH*KAr-eR-wxE&p0%n~Ap=gLqlY@?BYbh9mIS+y`PPy&f*yabj961;0*tYI3( z6lO9>DAkonu;>9newe=KTX1q#+HgWN434*{|2z{1_=wB3{|s8;)@gtG)|2yk&| zm+i^&9CTVh$=xp&yr8ghz&)SJ-| zAB8bXd!Cq`2%V~d_HX9Q50f4#yW3sPU9BaUwB~1L_n@*6YYmqH=4Zn&<9Ect=*`uv z03o|VCNQRG2W#(=QMBh-1xLiD9bs29d)>fZZN2#Rp_A5Yc-}%Vpaa>Ou!4{O;3wv% zF^2vD@j%#fA~?(}EKb3Mt)!-=>-5mZxsT2d^$mHB@bV)Ucx9U$AUl-f;K=}VWr6=l zR2rjU9y7msRo&B58hiRh=YGq7>QjzkJ{OD!m(chSlfh7joiN-=T*WPPNov$oJ5 z{7}WTFn9ax1vV8U7dNdVG4Mc`ytX@n1HKZ(w_Wk~Nw<%5P@A3;eqEwZN=jmZZyBa{ zf)pmavp0>d4*_pn=_(3Ey(Qj^0O8>?Kl8)g>b$g%DH-+zAm zMA0B2awrAEF2>jA*q|OZE30hir{8p{#J{Dh^`p@^DedOxk)J$Obm1$ za)FDY8()rdD5u@+jr2-ZRgZ(Ob`~0`Hb5Rset|(jztCKeEQ~K7Te~G*Y zbW$D0@frhA$e?d8x+(o|ziUaaiB#$PkHDoR`Ipe-K!ZTzfF(H;c^v*D!g^J$vy*^_ z#9K^oW0*P;BXFng2OeeV5*{}kKxL#*QerX*sJ$Ap6rp3o6$Rx6?-)^$Fu*?(sGxj(i5We`06Mv#H&ydb 
z;2|M&9=zit8n|2wN=R#KYt)(ovV?%%CZL4HKvgJgTySXkgIv)jYk*msaI4`3HNbMH z{(2D?80nC-A(7PLIA!SPuwW{a@PNUU04gs4_5rY-Dvhx?VHQy0Ugex)>Oqrh?vVbE zOwGr_C1>fK3ftX(7_Y^xin_! z`JeaWXPP^M=U5^+(o`b)-kCI=xkI|*a|6plPy?Vr0fs0EbPAmd(5c1t;R#rU1zj2b%{{}5PwlKGuL31eJ-u!4f~BQ20QVr&Px-95bjmb2UF;N;Swx} zjzgD$joa?Ex-J{GJVEFLOA@N?%`EwYfZN$1qS!OFME9>G#EC_&k z8~~@8eg#6+!Yfv?j0(-`qIjf~1M@KiLp&iYE64G|ky!SRgw-DC>m6Pe7!7M4}t} zk>z^h>&BmfJ;SgmL-K%7dK3gp8juJHjCpY|K;>ib8NCfL_JYq!S*0ddqAis~M@w`# zpr`btu^(d}ffsgCQ4tP}y4=Sy9!2bfXU?3t4r&F~t1-ARbBq3w%>{Rl#gBQ?tdTEA z4j#C63n{sni-R5g9L8-3T$ePd@_Tvo#$gms=&+DWcO^P=Lo-3fZN{)r zT6c-OrtxRD#lJ$UeI48U$hKUvHb491HXFU#QR&#+AZHW$ky~@;2hYE|YFtb?-gYm3 zf5<_K7wm(5#szz6Xw7OgU9}&}YX+4DjC{`zKjw5qQ$<%`OTCwyw<2l6#QDD^R*$G0 z4)y3R*#!QO*WJ~9dmFhMN7MJV*i$okLHtWMHM2cyuR7k`b+qJLk>}A8iP*eRZJ4OE zXYaR|+ghtcvV2!@R`MYU=M{brXCnGlA3_~{xh3MsAXqj__)~!zc{qDO=*!Q{<60WD>% z5Goc`Kd!BJypz^PfBtwzwC~y7WmL+_PV|h7Z7Ivoz)#Ufn2@JZDup(hTGpz0-;{G# zEDqo`Jk4cMGYH(ep^%qJ&(6MFfFhtmWvPyOl?|r@YmS>mB{$D z)5f)`mOQ6FUdXBr(z{f9G;ef7=lJNX#ap-WzvM4?!$+dmxk&`=q?=94R-oMcl?M-HeF!&*QE;&aeM=UAr}PBtl6)K!EeX|Goo`^W-H1w#PYFw|_8j zXb>s4GGM+Cn|nTIH$2ss@D2#rs%S83fWe6|H0|(6RD&V1U4|@9xatU^{pQWfz<4fermTe4ZGSHwG_p##uv?zlz}pbW*s%2khCoQcDiIN z?D(*eVtNc08>@%tF7o9E3J9EZ;kkYvhmi33$+I0{pbn(=%-f_D;gD>p zV4GHO$?QxE$#P%ZCv@pfBSgR~VmKWyaqE6}>jAh*3>cU;!{Ouo=jRR*w3Y;Q2Tt2V z=FC2aBif`d*)KlYAzV#LZmwSQRhA-2fU(w1HP++8&)-f-uJ5I_Ocpr!>@sHzcXCpX zWz^mu9ZfJ-y>ex0a&f^)Im&*=WvlQb%Z58!hAL10&#MFPvPVWd{5blg<#c*Yq0wVK zySg}ci@F~fZ96fIO9G6fbnYA{#v3tQIa)w?V~D8*pyjtEB{4qHR;LyYcSI_rMMTaq z-2XE#P37ED!y;_h!Zi{avQ2sG3r4oo!Skmhxq=0$y9R9^KM#M%PEGsB@5Z^UrLy|; z%6Hrn;&iUK&9F_M{{I+z?{KdF_WvK*6j`N=LK&4+DI&^Vk(9lPjL57I*-7?Z**ir@ z5g947j5L&;9hFr?^u0gd@9T5@zTe~X&+j^pYaG{gT*d44d_ErMdEW2$+gW}2vgE4- zT2{FY|C5SzJrffn6f)lbe1`ZaMbJ}-Elq0(D|Lu*|JN{iR$|fvb179%PbpkK2sngd z=G9q+cRCF#pj<3nJL(^D1HMMX!qMR3;G3{@Ydr2JiM#y(SC6=UM$@#px; zFG@<6?12Y(5@eFt=4%BH1Wqry@F*nhNCHqq^z!SS7nc@;4)79bIg~}ym^T=t#%uu; z3=$vSwRP8T>!`IcVhKQyeRO(S|H~z!n~<9qpI4t2v9A%Xzc0~q%)%hmN`I$;^rb_m 
zP|Q1zqVpsnB7ShUZ}}ew!kL?J9LaSG#_qs4E%&nO_3hjui3!LD;6i<_{mZt0$Mxqi zlNSqjbvF&l#4ayef8;qseJ*JGY&Us4x3evAHtamPbOnn&eEJEhMgrVr=E4$y&@qT}=`9aJL zKqM5@_CdHO6YmVs>r26DT-z8l4*3=;0+HsIf?D~BeW-NpAu-7$7oPKiV=D>CWv^@wfyYeu3fXjZ~&TXzDoaUHqmJj@HcqN zJ0a-k|HH?GRpcwj^BZxm_;(!WHIw+ubfcK8*^Bz_E0oRde^$kafds;J7}lcGNWb_K z)!+Z+(>RoULeAMiXxq$;(!lfn?Y{jte!0jBO(f4IU3Iv#cU$iH`ruPlGAX)88-2L; z_ntgcU}BiUpdG}kT;e*`{A$W#^Ge95q_xIoQMg(TPw)Qp)smLuYa?QFQIS(`Qh%y= z4kU+WwIwColryyRh>3mEW8)I3u2bAADi+ez$H&DS>mC+9CMI!4?qMhSp1o%`jcZI( zj0}#8pLI5Kc*VIvJ8-t&NozQTKK?szZODGsmayRCmruSF8{7UNg|W*me0~woSj4E=r1fCv0j&Bm~i(7oQE|hFJ(j zVYca8V86z3QPOa}n>*-HNz@P&!>K5Os<0d_FvwLHQH`0@qNi8e<7&74_%)qKx~VLkeOm`>o_btfZy9yE z9zt5kx9isBj&GGhT;Hfx-m&S8q zlnJC1gmMYRbDHZ(DAkWhlW!~5<C)^b+_*W)eYWPWGLy> zi)2-~@g!e0dgi=5)T6@=tB5C5t1^UHt%lHfp&ZCZLoelUU9zueA!$B*k<&*3QTg^VQs|i*3DjwpPnc;i0)9|$mksE1cLq0JBh~rvHd3|g6@g1%xr%GVlr@e zN5C?d#t11hWfw$}$HXQ7<{bscbe+O|jvn{%KSCXy=Y;|ANrjYh_`tULM>Co+!;ovU* zV)wsWSRGg1p9lLS7sKTRHS@m6KT1gFv;KE;44c{zuMQ}D@94ED=wOu zSa-7Ho!yz4nXI!#g=F;Pc4T&{+XQvhnd*ejk49Q-7dWrAQ#D5ax*hw+v670pcjCKF z9J#>tI)k6!RbdQ4@GE^YZmoqtL7JszB#Gpfy^Itkoc@Gx4uyjtaUzETVoxL#jUF_i zFdKwEefl(_<>miBz+jz%>m9)MhzkiLCFS~C8>^VBNxDwb+x_efbSHd3unV`lX=^iL zSWI|kF%*S|fxMOHHMo6?9JBE40p2ClHYm5v+pye{1_TzuufopWhDw#xVD3Jo861Cv z3$$Gd;U0JlGVpw1)ZK|8t_SSA1iTYL7~Rhs@EnEhJ{AG-Bm(|-j4W(EJYWY_OF5^9 z2t5$U8+e;hG*sV~hOSKz42aP$)IGvd56B1cXse|{2`YCUQ-a|XiBe&wVG{i5S}29bO@7}l@k5}R9P6ru)@zVQRfIXMlH3W zL9P~YLLVl) ze`Xi8dQ|LmdU!nZOunVMo~6Ye{_gktd#|yExGgD=m5@^TZP7Yet(f?p#PHuzO=NmQx~?oof_tkbKrBqJ_I*2WnI9P$X#jhYFjBFzvlCWm z62f2$s> zul+L!UciKK31ZjG9Z?v!flK2g!O=?~UGW4eLw@;Gb3+9AtS^|F=j-IDkq{<2oOLsO z(&!i`YM}KXw+&L0oN^#kYQi&dHsS~BR=Rh`-Nxe1ss4noBLjsr;nI3z269<{UnP_^H}oOE(^WD8;Az{`eF z2fwK3Bl~a>MPW12a7ufJ#{3r*{lkU~HVx0jlB(O5mg%!jAGiaq@6Bt9JCWyBdmBK- z$i7!0DVwpHSy+TpkW<5?;ig2Al#EGw485P9(-`IS%&GNXP7`;57O}0W+S&bZy(3GF zm@f2?WJ$9cZzI|M>GL${wdhv~339G`L?{}$Ru3GG<*d2hgZUM`n&5;I9F?mtmoQc&>r(&oZ%K5M?&*BQND+cw|-`Muse{F){?r($m7 
z^N*!dX$-UuIT*0R=x>42MkPi8pL(J8-YJnqd|gQ>l;9|V@q#g2X_J9SsG~%BeS^>| zIGtbMTes#!6lL6A&=%P(8hZMd@7~2hg852KPlDn1E8ukCvH}1^6SgfJpYk|0FkFX` zie}rkX1Iz`Hcv^OadB~p&?e@YWv~{)u%!lv7mAa^pgT>^PI=o|UA`=C-N`}Zj|eKA zq2Vs{Ad)e4wqt3Pa8C=Y|3IOJ0z#07rq#LU6Mha$JoNC`go8l{SB=Zdqcv4k0%RH% zYn1Sl3c8%$kCqvEuKn8vBpVjbb&zUjT{I2#(5);q#Diga1Z+s0FSz)Fm2w*2zmG=` zb@!-zs|iKTX;!NIoYh* z%Zfy*>U!oA2Zyv<4w|$-C~~ZTy7XF(<|zFG_VwiF=HX#(glP$s$-jJR!a?@<@nh>> z88TcNKl#!Ru1i0Bl*4MzH)Cs1lbukNg3|#Qb0V!sC|sZ@%%g)5>;P5N7Dk&>r%oYN z-dpS83eX@-rYQ250Fl6JMby*B1Rda7jeB|3kf_}0=q{m4LW32|aB}q@`YrmTDR$nK z_j+dVNs1HTS-omyHQIhBYda@OzzM+;%RA%Tx z>e*1a;h&8JBj?AJ2DY9&G0)Dl=-m-~?c{6czx?i-Ua%Ar|Ln5=NMwK8lS`RrW}JTB z|1;Qo;O`TX8)S0rB%xB~(cjXF^B8tsh-$LRjyx4Ksg-!%&G+wStc(RY1@9rs)7EnL z?p2c{HjXR`TZUh*_2K@oAr$NpFT${~-m$)>sEE;-I<3>a0#jjS-*D5 zeYA>Acys~xpoZ&>z_Jh*08Zbrk@@8v9bmtM`-MPxT;K~5Xz_e7#o}u?jB?bgihlwr zyGzVsgLcxyL0pR=vm8IEw|PS2k*(tZyzy}3{A%4xIxbbi+f29^BVPcAoEcd#O-Z-(VygPM&J_h3rL{ z6QD<-vm}54Alg^Kq8V{)D>OW`qmSS^hD!^wQ1goy*($$aZQ+FXkC-^Ez=FLxX1 z(A8D<+uJzXfeV@c`IEw&FbxmYCU56en#a^s(1Ad^10w+p$Gf(cxA4={t$jTJ1!(aN z)s}K%+-szoX&T`AYDEGcgH05Li->jNOZ5k&Vb7F>q|*oq;n!j zdHLe7&-@gdWMX6}U2X5wtG)uvL=Ct|gWXe9R#>72pEsr#_^Ydr(<+sIW-ZWF*ePGR zm(-DxF)c>MU;UT6#}HXq!d$gnd3Hg8GOSue<3l=oj1K4};0V(}@^y1_0s=N#~Hy_+rnW1s|q4m3F$+wmwb@6+-Q5!bW8=S@_m?Ls z*+1Zj(N)YYTFTOLyGiHE z$wQFAq9_uj5=L~++!6M0W*BT&oN~1tYwzvOPVn7r5E@;dQkeTdRdsO>d=$`&fl7X# zNUG=eI)Bg}yLCIYILr8aVZGWh5!Hw%S4JjxJ+h<5yu9=8A?;tawtp#)I2h^}E!ah| zt@HiuVE6cJmHDUbS;o;_ebrme+>0tmJrZ5KXX4ty^cIhO3W+iDWG_5~syN6?Cf{~4 z54!Yp=2@9@#l5g~g&Td-a5`qS0IXT?X_ z-B!sQP)im6@GM#5;V3@pn>;VvWpaexK7Sasone23dH>oc?^G$jWGcSE#sKre0h9Dn z%|8cXmQGG=a#j6_zjvkO<@Q=@c}cxNiAe6I+drLFgbUx#f9Th~^DZg~QqMkRN+&iG ztwiNdpLdefNpC%5clVyYUS{Zjf>MYdZ->7EwIl)cwC&FHE1$#XY6+z>^9kn5!P z*4kL3X#Gu2K<)Xwt{d2saOT#k_OCfnfwUuuB78cO#%2p{!-mIh#r%Dx@zliav9Uj@ zl5M+Zdw1MZ7v&;K17#tt4n=PAZM=Gi+wanJ2Nt9YhexTa&@C$ey9CL99?A_AgXl}Q zaAu(mB%tCro^GRH1@o|xu(lfSs;hxzsB(0+*UWYGy9Ehy%lcoRvM_J#$W#cCo2M6+ 
zDY~os+;daag!g)TUQzg5L!Ecy`qQ5QSB6jh@lXjNxkM#;IKXAXvMXZ!?8vl!t%JLP zYGhEi0pm@z^y^$CB>}Fq#y5U6to@9N-TX=Z?SytpuTIDwLkG%mIiHn$YN_Akq4R(|_ag&i$kN^2kE)=*>WrHLnFilTg zAMAS-Ipsqx_fV(e(fiXviKWKyv5-@Yy^Yo8>+(7seT3cVX274pB8F*J9;$+~6ptPp zpxb)(Z)>$B@r--2^6;2yZQs8jWSm4M7@cW-aLT95A-Tdtf9Ict6PA{(>krIi1soG9 zOiUFD8V9tGPQ8&@Sn;v@YS1`AempUFI{VNLlNKi9py`+cVyzR|mDwGF<54*e`m&6w z3)!^gro>MDM~^lA*WLtB7Yr|MlVI5W!~aGv{ciyEMIa4M?~9g}x+(dzm}Toe9O7R#Pu&`pQf?`_Da>TK-CLJE-D3#%8T`BQOW+dWGJ$;2f)XD7 zk*85r$Ml`N zwfDl-#w$WQEtl-0`Bta6kGpB7ml*AMRx8Lqc1WNo`Oze$@txoYF1~fNL#6u_5^JM; zQ>A~pNRk-0<)u+Qo5(gB{`@rW-&`$Yn1SCySgrMbX>Z4k>t_i$M2FIt*j{I zfW|7>z^a^>l(i!<$6yTc1p^TFJvytLC*BqxnW-1C+Dn9zJ%#=x&1PUth>OKp$19He zk0!-kxphC$@5?>CFUDRTJ(B@#Khj`9Qb!=dRm66G7z0r=Pm?e6x4Zm2*hH%~nY=r{1H~c*Qymg63nC4 zXsP#o{sB`@vZU+fh~F>+H@|yq<1xE3_vxH*1-U=Rou&KspQ396#)?`<1xkgqrouYs z)^D1z8=d#M>tYJFgAA2Ac~|P{@oI4~axVNQ)t>r#S6ib@18hjriy0v}!t(5308L zoRlnKrMGZD-^VQId{QbrJ%Fqm4eF%9=67Xn3D}nc=>-i5nriVezh41+r6Sw;OS7JR1 zHbuNdSjE7*CtE{aHiz}Nm3 zHt*$S)6!7C%vI!#kOfjo6qc;mvIa;9%m^P!6SwqlkoE9 z7&HIcwR0!SZ@;ITOh#88>+cZpCP|+$VL{ z?C_RV8VtTRYq>Yy%cGhsMn2C!oH=9LuxZfX%EXLX^Lz-^y9%M?Z|+#; z{as>u4-FMz+i-UFIC}1wYK-oP{wk{A=%}a$d<_{D=l%7fp;$Aru;eZoKnWm`{WHIX zcoHBYy&4=$BS5XzgQy$;whzt9sv5xE^kPBu_g*3Zm9M%bv*4zd=gnnd&@=pB#K!T3 za+8>P1oFH}_YVxsw5E=Du6EhD=E>qPRin$NS-qo@j-F)qQkQjoo#(bxPEH6n3{VGW z+=`({qGdIo?)Ik>UC%M^7BlcYODe-@5&vU$U&BkXe!q~`p@Ru>C(PLx0}nRnk_?e0 z@Fd1YP+lj=Al2MX_PRr%=&ooCUm3)hk`@!|PeoQLI=`~@_h~)%{E2`rk;H7T#jtU7 zZcby+zA*&dYBr{HSR)4BTN7H~Hq1BTCSWThC0xQVRw7lfwKbg!2uCP2l70RC2$8no z*g=%S92YeXh_UJER&;Rjb&!FvuTd&XiO_WNbYSERCv)hB?y~dnWqJa(gY_GjWl2<3 zRgD16k0BJ`tATGnbaiPd+hZ2M19^qmStqrGO(SsNWCM8Y5wV+*(YL<7zO$BKoT0HJ zVD{nRTM67Xkv>Wi3qXz}nP|)6)q>RWbedED*Ve&U7Dsrd5rGNDZwD*N-Jl=};?>+c zYODDSj?g7c-_9buNG#NQY;$(oO_=S~!DI(NJqP4bfQ2KGKu&YGP^3$wTF_7u%L0&= zpndlg2|D;oiKJ3>J-L4^F3(&vsZ!hz6HSnKD1se%&tDM&s|xc1GY7|O$Ww#}gp=fu z{450`bSj>gr83BUUTpCFzUIo=FXc5fza;+T>-&9zf+0xy0i)K60^?p0P5;jj#PJMClJOll~y)tYUwfvi6ZAM#FqK@o4tYomXDs^ucjaZ9Z 
zmrBVkT?dvzmyn#}H@4l{INupMHb*ai+t-SeJWjdDWoCvwi!pYaaQ?ie$?7Aiq}neu z5kI#0kZ^CgMsj6D)10)NN`rEQdYNjIftn)QH@v|nQ+WRFk8Hh^i^04yDxM7i<8!vY zZq5QzzVT(OHI8$WFLLNejQ>Zai4FcS!aQ(rP*i)j(}_`uX1h!FnhtahW6}wy`@hdpS;lP=KmHExJY9aMCXIWWUxP;Xe z6v&a2!;#*Di7Rl@O$66j@Efs^>}@Vhv&XxY(1?irKqBIp7FZ&-P9n+=)?X$JZrHHq z1fw%1j{w1gF&d3glw#o9$wvfwAGgQ1<{q9|&S|fJf!Ie90-*27a-Hb?15lqe)qEUR^ zor#YJd*nn)oA?-6QX!(p#KpY?z$giS3;uzzn50>mE?P_%Lh878SBkaSrdl6;JHY?jOvry_hxX7R>l3Q}yvpW3Vx{|z+Q~(7 zW&RDzU{Uq2=N;|85|u=1$GYr#n^f3*^{&#&u)?+jZztVoI`YDIYsrTi@+5}ddfv-W z|BS^b@b|{z4=t6k@ZNvkf4h&sfXnWNzEfdeR6xqcu$t#(0)IzIEM)#mzG$7O3 zn4NLI`<>Y>u??&7aE>@Hh17q6G6CLu>)scZS|w|E4lJ%*QHLvym}7moz~+kyk{gqy zKgQu^bsS)W;}a*=Ai&q|wORe=I`~_v#Ew3M=4QsN1h)}DsDT0FmIT<$3ACrx8cs9< z2#gsR2D&Q1abVO3TSOAVPG{?Hc>`&y@~>yx0$=n$YOa{$;4|=-b?+~~s~6?r@rMzE zm{ROhfT(QP?t(D8&bf20a}|(K1cpsaOxS@f!x&MDCrl}akf4aYz4NeNmBZ6cwBZnz z2tL+t#EfnGFZ`n0ARePT#t`Js*5Ue#$E*${UiFI?c_u0>ov;*$4e~UV05oPm0f;Fx zm<-W7l(g+}?57ncjzAq>+8~@c{?=u%Op`T{k&6Cfn${6Y(iON=Y^q5B%`rbx0Sf{z z>sD;Kv2u8OK-qAY4d;jFi|z%*ZzwJelZL4uw*Yy}Akf2|aiF9oRBzDA!u7y24P{xc+D-CiqAHxLGy8GitV&4)u6&0^46Rd664=d8O@VNR~xo11ES{8InVyH;KVJA5ciDYR)^U8bpRC3I0DuYt;wlDLk#mWyWf?(2mbf4> zbH#&^<7$lMds&Y=z7C%2t~9K_dKJ^X4D%AdUhe*figlw9oEBmG)OADsE>U zCno7QN6JwRluJy+PkcS!+DdZZz@rXDbJbe1p1)+CBkxMbTaK+HJXg@+nfR&>)E|$A z0^hrN;ol4C$Jl9{7CQLX>7C}bo>v~{x)qncEmh*=PX*7MEz%v5dc2JpHzwZp3hw$e zW)*c{`8xSe_04UEnU4owt)8CK^0G@$2)*iePkUwG`>*8tywPklG)`a4PK6)6E${pyl^d3t)%B^+^J6m`W*>7!CdDW4N6~~0?;JtA# z7==nC%0EQT0M;mvam+?xh-)LO$42z%(f8ZA%5zvtp%<`p?ZN&4G)9D6g#L08GQwrd zov_eZz{9QF6;XeJua2IZAWE>Yu@S{XYfa4X$%E9e$mHXEAs!)A5jdYZ@W2r$>4-7I z662n`Rf@{WL@IDE5(^kw1@Q?9-Ysjq)YjcMWL zVjhpvS0lGSO`kS<^04u-r8wzyrjDP3QPuM^B{8~XHPW);+2bvqL(+AmJM!fgIVJtd z*2yt)zS*DK^2lP#Ks9f$`A>Jc+KjzYgKGYz!i~TAkH&i+G_R#Po}4g!Ic}G2Ylt&u zJ$v^E8|}U{&eUV5mDt>HV0?~OV1HiriTo>XL%cYOcpFGmR1{7|@R^)DzK|vGA0ZB- z1bhVY5fgx_@fZKcQCAzE{TEfUQb| zQJ%u?2l!|Z-t`exLsUumVM4EDrTwQER1~PR9t0kmPz2LR^w=@}BS&)?OGoBJAj zInl{ovaxAJV?gRrZ#8VK9Kq2D_A0UC62XSIe# 
zOB(Jpl*|w~=&%}vZaiE!s;=@M%2JH=zakF#Tk+nVyN@J*>`OZTaJs{Nc)`m4eX$lQ z$|`7$2B{0}G0CK#jmihq&hZb%Ev1kgXAm$R?y(hJke$A+vbj60t44pF|9nige8cyU zvFk6@_rH*qrY;yxq}a)Ovv5gsh)hG{$e-1mgAF@(gXB0N(qL(4zM;AEx-px%#yio^ zTNB=SRFfz1jngU~_IuJAADXg9PyBPl!Gk7AVa9`noTq)aPwUtN#%W$C z!8h`TB>P;p>%j`BrP>>R=VIIA{|XdMTLXl73z-sUVy&0;evvJ zB?uWi=-F5m!P1tF`e6)EF$koqh?yco>QeIl#$prVi(n&)VHYa)qOl*bVMLFyS4t{{ zh=R2|f!B!zex+CE6xAWd5UEK-1}t&m{*EYu;Dk2ixZ5-%nh+`g%mZ<)E;R3UG`F@^ zMk$T#cf+izZwYKJ7>XVP@EvNRIW~|0UV+n<2e=tfIgkJ^xg$Z5hxZvbZyc~L=GdB% zmc|JT^#>X{*X!4tVZb0($3pybMw|gPS18krQ3qzk9%JFhXsv4g8O?X?R*gy8_+h_! zbCr8faIlDdCF2&E=sPPOhr#d#!gvOj(lb>RnU+<~@@M>wK;)bFGU0MXW{;odVF9yB zC-1HrfvE%0zRW$Mf07@qb5=c5B5&FJko!J?O4mHZpHU(`nYmp)wY%*uYHCyygsBpi1p*HZ zCP>Mc+I^r+aMchiy3sJ^;CqvdS!ZJbIfNjOp@RXZkpZFR*Xc5;Fru*D&-lCX!w36g z%CMOV*Z@_8!12Hk+hzf;{H?yvR_0wjcN(?Rf@5z83ar-_|B0~tf zRd;W16HqZ2yb#nZL`UJ2v%q@>gE);ZE?~UW8hD40O6NoL0jUCUtgAV>HqDC~iMYMd zQ(2+U3;MiCYypOw1z!h|nU2*uxmWmmk$fhsJGkLTG48R3)&~kkBp`Wj(Gy-=*-Bj3 zXdbffnEtDy3n1cF{`>wP|EB**>4Y96&uc`k)yd5*7rzYg+rbks=}uFqQ?!brtE2Tr zVrr<5}24vh86ivM5gZ;>_EmG{4~zbOun6mtJ;p3dSaVm$@@ zu*^Mr)y^%4OJ7Mp)BUE?d_l(NU(HYa1VvH8YIb7bEo@DVp=1-e{Ze6inESr3a+YxN6^wrh0lj)e7;As`I7IrfT#V9r;LU_~hW9zP!eq9H&j! z)|7pxhdELl-aOxEV}9gaV|?M6;nC^$yTig<9B=MD^Xpc1JfGIn@}S7UhG&8PTxwry zpM_tY$X2g@Toh3<{pGH}j*a`s2W}lbx=-K;s7)AXI>A$2E|K@TtfGRDj;$PS|MNoA z{wE8h=*vDXGhNAOZpue{PC-ul`1gT$oVUEKrhoVT+^g_Cu6~Ea4*G^B-87UcLt&z{J4@lChbL38Cl4w7+QO; zeQikHbP*q#AuVG9U$y#zi3Rg zFuoRG!W1_Pqbe&0rQ{{>_TwsP{&`3RaxHVAD?E+~F15|e4=XyBZyzqS@8S=B=h1FX zuC!y5`bx)#-f)&Z_9?6I!5n`PA}BUy8-K{JJp$>Wy4p{wKlXLnu$)?=;`Y$bK|gNx z%Sh7aB~W;cZf!c#@SJ6Jl2)Ttv9SB%$1`je%krcjcZf(jz80iTWhh!9Wj($szbK

E>=p${Lt@f3{Wt13LCnf6df z$MfL8O(vbOKUYVWB6F{dl=6R3U1>MBpw772$Zu_-%(rSNf2C&NW2M7Zy+rxXTSd!a z4~*sV1(d0d+6&l<$BwSiR;}dz3Ub~zze;~_N#zKAR0>)B2Ia@nj`ev>pTb1-;}Jrzn=hWs z_LGSIPq>Kn;oY$gg}DtrkhxCBjC<|PJpERhra=E@bYEd^xkNMo8w=RtBN)#&jNl6l zg2Wq*WAmVsECG}zcy2vC(HKjgEjYt@3vSo)e0R{_kD}CVnR=TyAZ}J`te|*GBaDM- z?1IUSPif(qs(y>#%?BUNdvkjUB<4I33uI#*_qg^jGo#{9nHp&rSuL|4k=Kv+&5X5+%t_x#3&YbfUprFEYopD3ZA1NBM`R z4m1o<{dsRk>GoCOaAit}+|6;9o^U@czxpVOq{W8a4fdAK``P*!ycR4uQQNlf8m0M9 zSZ_YY_5Y=k4|H^Aza9G$?R8OB*J;6tT+>tcfJ{KT6V_&8al;D~b&@dE;FZUJT7nC% z2Wbb880ZG4;Cm!&9=K7!E!dR3ImOBP>V5O5_002}Jv@qywEoYx{NlP95|Z6)^RDk^ zY1s;2N|!z}b9hLKehyD&U(MME0@-9d{wyp?!|u?8mh|-&UF!&j~2Q>{_tr`Q~s-_3Yo7-33Wdd}`H!e; zHQ7oj^o>TNZNhT2pA$x;XTnwDuw(Plm>$xgiJP0X+;KPg#njwrAJj0^iHvtVc_vyW zUX{es=vOPRJQAZoW@u=vU@9TSe0XTw)t0hq?^PFyeY?eEiD)!NLuz7LU(!Vv#uaj1 zjxpY@qR)yqIlv0_OPV2KZ;N4u_??@{`u0zNZ1h7yo)|p9NgVv-i3E&qmKgaWy29?d z+y=+?^o4h{XwF)ms~y9@g{8FcANC+WyXD!=qa^)`dU*SgU$tT#;+<24ZegD(EB!)4ORPjP0#1$39vEJhtx ze~P`6pwMHUa5%EuP`kZ$KB&Aa^caU~3gzDV)2$+V)TuPjIdL(tX$}pqOltN$-SM`! 
z@dMl1GR5AX)l{YQ3*Z!quZiX`3gHMk0h$B z=G$n31Wv#A1(7VciD)Ut6+18;MX9fd=_peOIgmmyGx`DA4)eQ>-TguqU%5T2e9>)| zucc*T*Lzk-9~~>VKHIBUI58#sI4(MU|HOF7P4aZwH{SB>x(>$DAs*Ojqk*})Ek=;L zcD=e0!il+lczAdW$rv1wfQiEBMPw_GASMDt7b}VcT9^IHDz-d->h|!_mc`s!sSn!u z9z!kR>Oo`d6RyuDb#m`p8C310P|oOhm6S1S`H)06KhHQ&keQT$7s6}0QJ#Ru2KD!Y zF?v#DkLF3H$}(+ZZptRpccuqjQl98IUwr!GRmZvb_7SqZ ztDCz;cFSSk*03kr8g$L;G@6s+Sa#Xkl^>NAvnD%5bLyY#Dr!qthO9Mm+QecF3=1s9 z*>~;Qg^AqW*fzqwfyFD<@CBkP62){7hB9^RB{=W!3?CeMpPMVQGIz(1%;lp$4SKz{ z-Ztbe!rYFA>CtRs^+>%*ZA0(j?%O?%y!&QiW5v1e-A;d>xm3D z2ST~nf*vS1-E{O5@(qNFg^emf*>7OVLk0m|m;$Do;JFf8(hzmTDu*VH7VK!>FDYqI z#<6=hT`RW&oEDR9C9l6NeHJ-(mFoT$yW^4tu_;0f&tvl&V$*i{2D&_ds~W31G2yb< z;Ad^er~cmgUE)~-k?;WxuB!~kE*+uPjCr&i5tkg_8^74K^I5>NEXvQ$Okp{DmkoAJ z*t^_~Kd`ghje0{pJnErctIj@~o$XG4Bsg8Q*i~A89=~|WVqIeE#Y+>-HPv~&PtyL} z%STQ?L4QLoZMMv%GA|`7&N=>VQ6~Seqn&3o2h%Mknz8rqe~XOeUWjhf2W@UP%)Lpb zsE9asF)__5DoXwsX}G5~!QablV&%jcp4f?$0~&-EK#!n{zsvK> zV17p1cvu-U8f$#9Brqt1qZdYssp(({3pZp8RQF2Jz0i7b8Kwemj2jmSO*U|AIn_}dAUt$+pf3L;Rm)n z+4Y!D$s~h~=A)CU_n)DYUGFvf@`hQ86b|*DeAZ%@?cJNMeMQB_V=GZ!#Rqlx+E-m-IShu!tll6S8EtN?n4v9A7-N`&9@&Kd8{P|NUCbJ$JQr;7PGc~E+jT#FXs8S2mTfUi z{5YNcBZEhh%=7Qv^QZs)rj7dP_KZLhX19I*V@6w4)uV@PX4$Xw+)6rN#W}vCZok!X z;m1#YD&hU_uKdW?8wKLGZqbuc9-T3&kYXKr zrg5h%+UW|Jew{Bl^N!DzB1wI}B*X;+f9Q=Z{IFArpSDqWY z&a4(LjcnGqDxNkpWWVqE*m`2iweQUxA8U52v!7GFU3rEzQ!NWCc3zT(zJI5mldq(v zRFYVlr*m1AFG;V6?0npcQQwl<9$)`A#Tt=~wK{V=uc0DJdaa5=Tt)?5*1BSi2`q7} zVlmCLXNxcGp#}W`2=TYjR{gjCJTKViOsIE+jt-@mpl!|0cEFT}xtBNzzEvWfK^T{F zlqX6$kd`8pU`(7QJ!3J*H^f3G{p{K^fPos+Tz?3|rJ|CywsyLV1NdryvmYV&);z_y zlnc&ysY&N}D>EnVqLS;NJ;eS+O5xQ7)0g_I|5(RvhL_J$Qeg zFEWbR`T5OY>>)s#@b($lZ$&+gWQfO@I{6Btt-!o@=d?Fwwy=gFSUhS4`WXK8K z^n<5bgjV_`i&%usZZn@;{Hjc%q{O&`zl)AH+V{1_i#xL^xovW0oOMOy-yXh;VmW(D zQ1H>#BudLkk*gFBVpIxCO16Ei&Zu~iqDR))KW=S0af6F()}&Hne-j8{xQzk2TlG~b zBs&Ks1bB>G#~}715pz&f zZ6ON&85tSbKk!s9EH0i#%oh0u7jVMz%CN5VCCDssC%Iuy6jUjUrS6Byr|74<-^#8* zm#bNtJTpdhAKMXw`Lq0AG&Rpm4Tm3#C*L`>W;G*KcOit;oCL2$3n-Zt)S`RpB{;<{ 
zDkNM%WM*(mz~NIN9>B`^MT)s%;#Ep~PdzOm&YW4Oe@C^!dXayp}rD zpZ5AV_ADJgi^pv7em}wU4GPR$b+J0eLQuT&$w5HrRW!UgX zP@FxG`?h+(I9XVVgzheik^^#hlgy+69DpJWP+g+PZYogH z`FsiK4AzaypT6U=AcDbIeKd*`L`n+rm?#Rx*E}~R!1oAe0{NQHH%IK@gyn8w7s}E#`cDh^!m$pSfF6{ahcAXH}$qgRZ5RUK(P|1FZ8cI@TNSGKOD1nF&70BjbX7CBMt#r5(Vgdm zKc_b28-35Kie;L3(b4?3Eoi5MgU!b?!gX;X410C@%(h2J_%6LarN?mjLcs5%4cT`d zzhTT2=_aqD-$~lxn<{5LelbaIzDV1A`y__!Fe=KI{XykPENhYD;{92VomJTJXepk| zLr*zQMkI{pS5~OTRUHT@r| zZvPRJ4{TFbLMRcECgGJMSc9ku$;@p9|4t%9+|u4oixqB)X84y6@I2vDCZri?YDDD$ zjd~PH4t!HRZ{I>j*<;oc#fkh(6OPvz9Bk-@S~`i6No`$3ujF3#XJ0w(cN!^2 zRA&UUHTLJA=UM*xJ`^r3uB;z;-F7pDy%tpmDh>OsJX&>`2=X;tMp_ZGe}7ML!@ZL@<$So5*rTe2+dbYC zrvof-&|3+kfDH@~h#$TD{5b_@_F0fRytpHUs(F3tl;efivkr?>|2{u-xtH|_{o={X zI#L{+?$_T=Y;-C}zHQ)mX1#3YLS=fHB~knVy`WO2a8Naq@h^B)o3We?bcB6cY{?o* z$UR|Kv#)gG3mi1zAZd``aM4rS(~nO-FLk-r+}2hVkgvXH*A9p$hheu21h;6b8uYw} zv0ahk~%E54NTUvo-IB7myizJ2>|soVG5S5D{RaV7%D zK=F*u%oMdYXSCsnA{G>bJJMQyMeviM zFoL+cfEaH%q$|*J3r`<_DX&GPyLeB2T1vh?*SRb7;SRfsyR)+=J``Es`?JBNQzUIa z*r${DtI0&HLRiLB7Nu1#3|++57dVXsOU$sy`3>%OLdJvvp<}6vo2mR3w$MvQ8tJI2 zd&e^5t2KX9TBf*DPh}^iU8lU+{L^FGg01!$+Dc*_n zjQjxmNHGJq$WhAgrMGv~En02iFBE589*$L8dPP^gqCUwYb}U%w;T{$4{euDh)#NPF zcm6u9pJZX5-z`ef8n(rT@&V<5>ohN;mk^77AT}FeccXUMy_rK~nORwWz`lql)0>$M z+pnA9KEh&;ci{RQ>)+EgzWHlwEmlMM&2eC_0}sn2DC4cyQ*p{$SX&dp2Vw_@PEj1b zVt|7dhZcr4+j*l^${jjbu8UKoK3)1UUfEfNeTk5pU|~K`Yk_--;J)nIWgz=GE{+98 z2rWQ6kd!b7gOacX;NqLAVgP{>GVzk7p2hsH>JF9^@h+47fCQ^8Y3>nFOPBWWkByDJ z=m8fM677%h&B=X#zVbvb6XQ!3$}+*SFe$OSfk%GJ2Vej7ZNB9DL&KcAcb44*3vS4c z$krdr>VI-x({|(hR<~t!J*G}^tM)z7Dnw7Ovno|v{p-NF{4kvw1)j6C&)FK9l%rzE zZr?VKjea`ee0xXOHU91+ZR<(m>~r4xef}nE-Zs}EdHilXOIlu}X~Jf8jAprrmeTe&=@&)!hziONOEZN8LQJ=T$q+^%H zO%Hsw>;5hi;vma%!h7`z)4_TwS=pFSPKgqy9>%aQxEUGWbv-8ttkpO_)2pK9JYPnDPh$G zq>TsIt-u_E%y$fCH`! 
zO{R4CSwxmEJkW1M?~XI(1y=6Cmyw;Fy$ai3e8f-a90v``1A&$hvfsqQ27&>M@0PHr ze)`^ojS_gf5GB)uoi_f%x1r~$aJ_efrG0&_uMvVci-#8+&MnBx5YrH(lSk1+dV!om zQ2p;}^WD68lNfa1;ljcUkMzz7%(=QheCT9i+6Nx~lP6E^$=pnU2;u~?JS>l`u%EKb z^78PAinc17WBSQB-teq=_V~u!5t97R`44^>hLE_laj~n`X(Jyma;E6IF zVRbPwvi4wSXD$D0B)pTIeWFVC=+UDDf*J|WAg}&J|V*4WNU$O zCc)9@$kC$&tRA1+Hbg>*?+UWPxW_%LRZ2i_xcKE29bQX6oST>lPQd^TQUsCtN2N=A zB~bbZkNMp?JU9{RVk7_TZ z-lIq=R9OAJ=fS3_aD&Re`6@EwLaKyq^O}l%pGfLTDP-EMMH*Q4i;rk!3W;)7lk_qk zXRw#279~4%M#a*Re_xwlZZ02K)PVs3@)4?r6x<>>>xf|}?stsY3+!A$5Xgn9N-QnJ z3wJ)1SKzYlPw%*U_g({1|LLhCH=Yd8(^AkH6DY_4T&%={F#00s)#Y=hcf!M8VVySq zAsssWe<*wNa4g%lUHmpnq(l=kgbuzkq%yr&U&9%Y zs>D&lSxj|QEFzk2d!CK<3s~}r4QW-xRc?-vr}j~gZrE(^Tza!z1wyfH+d18dEYw%pqw990M*{+8k^krN9 z4+kALs^?Zb%d%Hgoh6z&pKgq%c(bE-glEkiGvA&5$=3Y;n2}%6;HDS`qg3g&7!Y3jO8KHLr?iv5P#e7Kc^(kM<;+yfBe^ z^TzUu;lkbxdb{}yX-f5vF-p=O1y-mraegngm&L2VSnC|Gs#9X!9+ckA4r&!H9Vux` z&KI`P%+U<{Gi-J&^WUvBJS3~eyCD28x9dg01mqO>^N{Q-5xNi+Oib>|KJR%-lPsux zsgw?CU&WVr&?%AryF$>f&-U5$pYLlrcP-9mMFy?`x?2BXtSe7+9F8Znq5l9hdoSvi z$qPp9yFB8XXAl@MKqYeb4a32#rJM7-TUSD<+*Vx_6XHwDUlW$c#L#~AiGjySug8@E zopI{wSI(aF*nQrcY1RBlb8l})s$h2^i-Tx(x6k@btGU{Z%T>&&>?!reA60o57?Bi! 
zi_?n&KzI}~RdUE+^H@1JQxIce>z>!(YozN_{hnxrw= zvgxV+g_}z^H^wqGBsC9gv{C6#u+v-BWB@c7M z#Kd$|UtL-VjW@8|FUfGD&F(9c7z0=60GI!#G}V8z0;n*`>-4E(kK>lm&Ax~dJi^G# zEU;$}F&mO;04Rr|*_I~$q3$T|LmxE2LsRs-dl|(xuqeZI>DF)RF^Oa^<|q z6@?u|^cgv;wr=f6rpnIywB|dbi*syP*R357woHrq=z}%d`Y*N5(2=FXgDa;2Y!{Pb z3p{NYy3~&H5E^M>lY~0sI?^y*pirk!Q&WHLE!#mt-hg+GVYUKjj7=CJg|gw8T{(*O zA=}1!L9^h;)JmH_O6Sj<2+X`P!TQ?6YgSOrw^~Cpsxmz^++XN>P)~9~SJv=iq^XrT z@{~4+8P~#=5Za!sr((vX;4SLwyooVb#L}2U%VFGEA%}UTjZyWB15K`Qt&NTN-@JI* zFZR6p*S}mq(D2&$NcTsY9Cm#h!n8vMs-M2PJ-M~TxSXE2y_$7;1| z)7{AilUMxiXQo^TLuSd9Uk--xDc7LMMbjp@e}7d+LyWq>x4uV3>J6Br24Z3w=nM7t z?XqxaMyd+{Y7y`I1+yjoV^{Fp{3U~@59#aQ#y^5I_HPG9i6WMgltD2@%IH-XMw0Xq zNe40%5Qn~fJc`5f0Qm94R#zTC((;m7=;N0Gv#Z&gWjI@2)OV@|Zs15L?Y`($-apoo zC`%od_EhkrP|xoX+D(*w`<{vw0W4;u8ryFj#4oR^e1-b+z5tKsOSc}$J_nt|@j17n zAWiEq$E80?6_jm}N?m6pz1)ASmZW^A@boiE#=P3WSD-|ywhPn8Ldl$a+`oc4-Y!pq z`-@Mb?CCkAqr(a&i%==hKDcsipFM00;v3+SNC7$@EMc19Ls41r*-KnH$H@F6eD5rQ z8xTV1J-y_J)`vj}Dq~E!yQ^1&>>^+TK)Leto`Nz4PSxs`1lGt3I^*|Oc-ij&sRby= zP-*c>xE1ZL6xjAb`ul1^+V@@n`a}}cICyg#jJd))B-04&)zjT{xRX_pKS2j6db&kON{k>q*=rxL!u|@FHksgZS`jPlmk05A9l;H z<1VN?cXoAa{MF3kKYG`(GE)^+ckQGT-B=_-UAJAk_^-5L1$TYU$$7ipzRi8PXX>l| zcwPAY)1FvW&^vtY&RGxk>@-jsGMf{5$|JpR5i^Vih0x%tAk+IWgo#f;AU#`KLeFRE zhZ>ofiB^}T=!Z{OIhfWSE}0?4D> zXnX7+BPDgLSFt4Zuz>-)d}Km`77RO9@_V*TE8tmocXunikFD@F|Fep=;)h6E-kqQ9 zhm$sGTU!ZG+?u*imG*ts^IpRrFf8{(D&m^4hud}*@!bkF=U=}57*+anmxp(sRXXo% zE^pF7U?sC3)c7M|XTCk3(c!j~T5Z!&YmtBN0hFZpltR;(^R{vga{wyWJ@FNaDulY1+2$t5dx=^abkg z@->5p1SW?tjs>{+0L5G3C4`LF&uuRI^yE_Xfyjp?C1eaUp>ds^$#G6OfBrl?mC5lut+xnpdnh(6MfiU!0`b*eZt-o)*&e)~~O7e9(fL6Sj>?vs9GRX<)PeX^V_ zc07)VVMLlj)j(#WBJMI3P(3({t9S2iA_=6>R3C&#;C;3CKJ<+tu7hJ^ne0y8Pys;X z_s~6iwpRfZ4DPVdzYe(MT(( zj$xeS2==3LIhKXr1-L3LE9?4ybUmu;jw2mI5hRJUtmiEIe9^dQ1{YfK3f{entcySkg zpf#+-pvR0}TZl$KgbRkSH$R{YZHf_ZM)AmD^{x-eS~8d|40z``n*1Ey6l#dHNgx-v zEYa43%5AV0cYwj+-~0#vMWxU3J!n&kG0qTx%?q&jOuNP~$m;IC%T6(Kcyp;)u{C+G z19}D4S~p}#7eSBaxqjY7W>kpslI`>f!Lp%uk)57PMM?H+z4aN=XTGwzaf_TOvvA%1 
zjsNm9gR`bSV}{S3UT*UE1d`39(4Nqm*6C*Fo7uZ$o=1v{i5;@*-1nh^#$iHtF2(Ep zoNL5L*AICS?lm@YXL3@-emw>MrK~(7s?N6R9ebMgM6K0`O;(rR%&WYu%VB1Bb5@@d zS9GPNrV8YJRvE4##3!Vhqiw1~vu*XsPQ9++HJ-CqH>^ko-DtQ#IiwK2ZTh0Hjp!zk zU%Pj`RaUdO{o*nEP}=S2o{7z8Q=Icp@_IgU?0X5OT+Vkz0VjKXUtc_s&Mt)efy7jg zleWR=!__zv46jv^fKpIdWcD_WYBdNNeODH@uB# z$j2QgpA0<_fH}4}hfo*-1>0P`^7||XqfNZY+X%1rC)82MS;ry2fr52|Nl|oUq5!GP zalJ`mN`m6wC0Aw^7OZI&&=`b^BFo!n`8W0bns4Dn2nZn=oZa09^vs)*FPd*bV0vbz z5V>VRs)AU9Wpgtk@`+jms{=qcq$O~ouVLXF<67{9H=O)}jB@oK0$n7{4`wq!3#bo5`<%171iXxhhL8i}3w}rK$dO5j zCqBzp@F*K_blblz2L$;@i8x`zw^PXLT z^76!O$_AhSf6dO%_0Oln1>?k1^hq4UrEB4b46#0k^H6E!B)-H+xb@zlmjeJz!z(0o z*#3_eYuuyp?x$p=9K_cmF^GjPE-T9bU`BwtcM|*H79ykQkcz6JU-Hw)Y26)W#WIgy z4OIsoyJ(oN(oj2j{Z~PcetmwE=ap!!biK_R=NV>yzSLKZk4{@xP@Nl6cuK0lsPdrl zm|ZMq)YL*K%OoJ)R#Rp`_^mkI-%djMdA3rp(5QNV{`+OSi^n zxrVxu6V`O^pbDQ)a`Ug5uPd>2p&BdMSzj{HP#dc>_QqVolE#vW{ia0)^rCB-#GZZF zyBHSAa5cI&;cIlpo8&Jq)V9wyT;5pneS4%*EzdmkU!rA?xQuT${VH%jFCTDw-;_u} zUSD0}mHS+)t&5;K!TmsA+t?Tkl65OmoDrpw@mP5`E=neCd-$<1ekd7ZqVdimku-ys zh?sw#{-+0kZ_=}9J9(a!l-b=zcAQd1{)>vm4_V_-i zRSWU~a1j(qJQw8dg`z-(n2p{NM=j5{!6>8SMdlCyvp_kGxpsQ?_aOpC?D2EBY_SA} z#>dt0`v5GHF+&TBi$uMdl*EII?2zd`+sN;@9`5YO;_BvPV%Zr=RxBZ8P}(an+!5JZT8 zWyBdV<7H;k+;n!j0@2oGRa0`a&DaJIdGSe01G||NuG}sC% ze4P+c`M9n7;M#*=d$H;gxfh@qaUFWG>Y`i8@yw1lZ^R?MB0dRWs*<~QrQ;ri^A!p7 z5rWDWPraB=I^dg#>xGM(`*lyx%jV`#h*I@67W&H0ZbYMKcuf}pMo~wMy*3RmH{2Gw z{ddySXg>4mX$B6{=mX~abv||c<~(%0*>|6S1KYhhm(-x^+QmV!vGf>#<-%e*kEaFB zsva>?V{MMYf4EB5D|fNyGfhjl$iT*U6$;?xi=E3NH>iTPZ9GNULR~{m=NTrxmXYZ; zb*SX;m%Z{qjP`c9r%meC#%To&gzs~fR z?Og9JAt3=Y;T4Wg93lGd9`5e$(+Cd%r}+YRbANny8wRwk3yN&SB<%%^fRKmg`#LKR zNr@((fK0y{8VZPu+l+*E+`J{{2UW-nO%SddCCmlTkKogKPZyurZf&jRGxn|zIbdOE zV;`WcM`{7WB5WZb1aCUh7>=4_7D^3B_}-sucQi1#qvsyQds%vTd51wZMOQeu>?!0v zsIt$yxv3f%#iAfCfz})62}g9tnYM&rtZQfiv!MNkOlk(5X)&Ns>M0A*=HeRm8U&Uz zh*Ai?bnt6XSmEy9`>QmCbg0u7=Evg&w!f+uaP($D!>PO|D#+WDMvEeYoz>`JTJ|&3 zBilMJKEAxsOxovo^E~x0YHOZ%v4eD755+cle>i%{L@w|{->B5gwWl)v6WPgZp-Q~3 
zqOE>#MrStO2-VcvU`erF%X2?$*Y}Scxd$Ob%Ndb=wZ06Rlce01?VqNYHp@$)oWio( zw98Jn{3{A4ZD2a2UBLlQ5_X_k$#6vA`oniI7GmP!9IG6ZC%rGAIa9dF7p?hYZY~pT z04cUm<3PwAo1x{5AJEa+IfNingOPny9-G=jH%k07G5nl}?1zsZk&55aIf=k}NO|S2 z{1isTqdx-Su>eUMLrV~kLi{x-Q#@LYqesUOUz7?_1Tv6qpBA-nNjfJv17<*XzgjTp z4+^J{`PZ*s19H|mc4^_~MAtvr=@4xZ)3utK+T=&iVb!sbk!yv8lb+uf#^RkxQ-mG> z!SV;1e*=*dWn*34uY*7t^hH;uc1VzW24?tK^nN84a_F#ySo*JWM1sagEJw+DnSbNch95&LMaSsmT3yVIqp31n{aI21&_9vV2^R-Rh zF@K9;w(>nZnzB=L=aG{M>C`e59jG+7xhS96gRB8YP5$~B-<-mn!|ge48X@gp>{HHW z2R}K}7|hpJ+318C1UDOUw>xw6w__+VnM1`fs15~f3hrAImpdab5E+F&(l$n~B#w+g z0<}q)XvB6cVB+j!q=~0$>SQ{SF|Md&>OXu?I1+xVlE|J(L5uqoM>Rj#JOk@*e|5dWrzHmCwF9ZMUUa2abapm5TS65L80w7!K+0v)MXSZaBYhmo3C2yK-dT#Cv2x>A>g-qx}yb75Qg@bUmZWwA&lJ^jFcy6OO7?KQZ z&YzPXgz>j>a2$b2N96{f#Lg8^vMW94n@)n zwTp8)aDdJlUtcH`Nuvm4LSOUzH~iZxSDtkGAXKWPv???xh#u$eC^c3}y$LBOB;X1>z=@}zqzT|+3CT?Jtj9w za!Q!QoA#|L7WroCW~kwp^RXv#c3@FpT#g_6r56{gSK%hhg_ssH8OjcM2f z*2h7GK#PbQc7`T6lb6!S$V4hsVyC#`Bt1D=10b>E+c=VxIqFvQMNLoHOiqU&q#v@=Nx{t(XH^$KD#GdHfZXdn>+Za&)zH5!1~6MJYY5 z8cv<0XK$Qqh1ZBxR&kz!o=Mlw~>pTQVcGO3xcoco84)ZSD#Rp5p z2NMz!f<<2ZNUC++^|XwPdL*_eK5pvFYC^>#2SGKwyJ2V%fJ|FDIpKD5BYCn|bg6I% zLBv452!vzgM1$aK3_?fGziS8=8K@<6dTb)@!>yRr<#UJU#2lV5?$(%CsEX3CN%*tx zM%sY^esRN<-J~$q=em{W@=>BNmMx1k*tiBwpCnQ{Ahw4RV{90T=;8IBzd(n z@*L1m8})VLyn38!M&ZDX>zq=_JN?6X>{a>&EV(+E80kc5#{GhCaP9>q%@=S5>zkJW zh)N6Y>LVDOLSmS}I)V0d0dpb1whCY+u5vOs9)kGrs;a77?G22yqE=n4+FagZHEG>% zyKTuK=nx&=!Yra;&o#jID(!1zNVzCW0fb4K!J>(}_uPMeUdOU!7qJ(TY(T?w)F2}e zelDR3#9(uyz~i|w4OEuX)ah*5zYXhNHvJ5+cc9upC+vJuuY%5N&4*XUEA+-|*m7T)b?C`&#PP>;Q9)IC`qJXSq3x2QO6<;N zq6-SKHGki+%{2MBDLGZ%eZLh-yVHFdA}r+I(;wyMytI!ox(T{g>Re}_&0UgMfr25kGiTdN-oBl5DAMNIukJK0)iqtUKznA%PeR{rowfQD;{$M`P@NaaR za6Zf^OhPVE{NRDE7@m3p?u3W9ElCBA_kczMZ$z;r=>QJ}y#?x?I(!r48|n75W->DN zivHz1M&B#j_*YKveY|dZ;T{*A!dw2!wY&X?sYW)oHXa&$i?Y}Nn5NR_!3XxJ!*BFGmWz{qG#?t_G_(nteJEGS>`h3|W-HsD2@OhB?%n2_q=#vUGY5L= zg&DP}z1{4Nm7o8nFSzQ}Y}LiXJ8K_P$NEJw{&Um_zc*s~H&uXEFoNZ+?5Ft`tdxV| z2YcKLlr8ivGWJy60fFebW2 
zu@lUP6mO1q+>%AFLalW2AWM;isD6cV8=p+Tk_;xNt>KbWf``@~cq_EGs03vkdv^gn z@96auZc zPsa7t^ht0%+CHTCersDDI(u+K5Z_;V| zraduH`bxdrOp`TF{gu5rV*gWkXGLpVkKK?A;3B!_=pk{yAAXZ5`~x)+=7u#vnZZ=W ze-k|+c0O?HC+ZCmpp5}NQs{U#&H292iAQf3fmwpp<xH=@g7S7`)#|ZJG{&PQm zPUz#ZzX`@2O_}I}Rx#K`%K97^}S=_Sm zQ~FUi>)z-GI^Jz=1FmN6^Xj?7f#GY;I=mUm%n+oos2tqv^L1PFbEC>AK;e>a9@bvtV6{F zp6UT2GXg`!YSG|j!+1>w6++$`GKZR78ixRSoxs9Er}95Y*x*a@o->KOL)%|s%L+U_ zWHd_j5C=UOWeO^R%op0k8M^uc4lQ?{|Dk$UHQM**lDpCT@q%L&ytkVkUzEGM`p>|% z#lU{uZNtH@mQHM=-5-T$tg+ z1oMN^L!G(Nx!OC5t!0ZK57&$Eh@OzTYZP+wnup825dEi>z6qyo z32;opi{W?$cfX0uAjM0lm@6wTK7|>H-~$X{Wb505a|4-vA`V@lcOZblNMRMcGVs$` zC~MG{xkB1Qcv8|gLnedaTUHp2Ow0>A8NIx{ad_4vui9P<5E1Bzh&0M=REC(!sq((q>fnbpJ+}y4B#+bLE+#W1OCd96KuLvV=*neeYRc*OJ%0!(3umgG z-3KN)A8xeHS$kT{*RZHCJoC)T*wAbB8#>or_I5l7qyHi}J0NLl6&ZWKKc?7>b-|Fo zPGdODXODR{yY*2fTKUK%kLUpp-w&4ex*@8KCZPV^SfJm=*tYbkgKsomvZ z*9be-9fQY*38)v&pH~8|h2|Gcz?S;o_?ZvzN>^6?IHRV6;}vRxe>E-yjVA>jGIPK5 z6f}da1w6n{y1&|5KfGxKbHuNff^|ZypBoO}y1u?Y%iE=|f+vK-4w1?m$v|_6vk}Go z8UP4UP@|7rqeLhPV5)3p1dfujqJ=31Ij%d(JUu){AGf3VOerYPhJ1wG5r_{+GFBv~ z5YMiQ3+5jHypaS)Twjj9Oc{3JN*6snJpg>;3qOWU+6d}QM&yEB4{=RwgpChKG7+-? 
z5(AQx=H5ih8?EV@U>gS{{55Kzf4#sXKz9sIV33$KdQOghRWy1nY?mLyhlejt_(Ih2 zc;>qsT<4+g!GMndr1zy3z@UcGNv1(DKo%_Xr@=vg=v#;9IuQ8A^WDEi5_PlqR zD2@#BpFRoiYClS2ndT*zMfi33MSYF zlr$th`{>c5pocd2flUV_pEbg_l#-s2@jA~)f|#U$h&*Kciv(}EefvsKjb_Y2)IpLp zQAp4DD$()&;<&c5&G+wP>bv)p=q{|{qKayMXu!?8ey#8K6PD8kswjo_mNp@M;C2!^Tj&j&o`&XG}e z*rpvF9iK=nA;F(9(gQE&2lNQ&0}VgT&X-r=B7&o{IvEY)d#BqXC-Aum_lfg}<0}>- zrmfKd8$Vupz-s|kw+J&;64gMc*Nk%LZP}0<42H4-o;f0sxO4vJ^t3LUFQ^Jqp_ZOR zY&n*6hHG3SP!-b80XfxAaP>pWPu>OUs+TxcU;!ddbX<2kevj>G?=OD}KMI*`3j+Ez zbcJ}I1;);ZRYh${hKk~1IrrglB;a_F9v%hv2y~Ssi@ngiDg;Czh`#2zN=#BEH{y9n z3c-z}*cBlzAJBsEI>8-B_7dj_nkiC{0J;(le)Z+c7lf=7yizDdci|w<=5`rB#VsEq>EaP06E{=PoOw z?;8J4U1_yG&X0~N4iD9RGJEvncPf)1r>W~DorqLzet*9Mlec5|Tg#^WRBuX3z8U=1 z&v4h|tA3Gwo6{Yhfd4&-d!B=kxFKnd|IZ}uhPQ%*#I?5jKaseapYNufJ7y-nP6Ne z2QKb)Xd2L3%l2J!MgBr|ULMQ?*%#YTvTEYxl1L5?M>tJ4u5EElwCVgW+c9OVJ*gdf z8AzD!TYZ3jNVvQVB{4DrN%si6oC^71aKj(FCkx`G#f6J#E-+xS4)}ZMeV~rW(E_?iRDem@$zhD)+#yj)xbvM9GDHUQTxLqs=|uSazh) z`bu4?FwNnNLv?m*6kfSYP~}pz4Q#^!GDy*FIF&0~QkQaO@ReQK{NO zS4-JNT7S!(9=+?Tr+e~XMW6Na!sw0;QTLyB^78UJAHvK6$8NXvIGNp*~h zph@$}lRI5cd8U&F^i@q4MKen`>(Ov8>$$}Kj=X?y0~8uZfwMj2m;?kf`|JF-?`tPN z0L~|za@*{Js`BT7h?On7^7&e5LBo43F5Ubky?IF{_1lv>TDGNP0@U;UCt|3K46Q-5 z32$aMe$ye_^YUT9sX_^EZ*5{D+PPePlXfzJ<@vK7-iGQgTU^%P6nA?qyVG6h@8$iv zsHF7lC;R2lhowS6qm>}H4ssj9C4rVEq1haKq7{SxE8$xH9!)NpVgY>z7u3%LF-Gti zUNy+~xXG7OTQ9#vz7x1$uG_Y9~na zC}3d(LDJ(2Ad(PVO|Bn{z=-Hga*#Ss#N`nld?&6cA+5W~a+6xl@mh8Dm_m)N8Z=5H zu0Qs2FJXB>7N7W~mXb&Es5;TDrvU*YQ$rwamqrhtu{h0liBNk8^-wkLdkTjfRbag2 z^l5NN30T6Tk5>p{5;TS~qf_lahMvDQMW@WKny}bLKaNODUoe;}swxzyeY{6PC(HRP z6V>5BOnK7F(aYd^o0pRVA;O*E-?svP3b@}xK3*G3gwkjeZ{s-kD}PD#SY?)N=cT>= z&I$U9!_}hI?f*pdp(=F1%prnOMZ2^Qn;AsS$gcHch6pY)lfyZn{{tyQo4!-INcgvO z`o?AFHLvgUwsRSTC<20)hGgZd4^R36Ji``v0e9Mk%a^k-8Xj@=B8~uiqwC9;tnBS` z1b8#DvM>p(m36WxGxIPK3y6nw6$^{^)Erna1oByEFjFlqI>A~%l4`IqE5@!;LD!7& z?vj1O#EytIaR;8`@sN_Xmes>J*-^*hHSapx69&6K#JBm3@0-HEwN>Qhovk*zF)&J$ zWHB(LGS%YkJU|*R?pw5z44CYT{E7tIeGp_f;C^1j8DfmFJTbojqLCo}5U#xe3FwyS 
zAZKmtIg8U#ji@)#)#cX0j-|_Pv~#B3mXrc!NayFihp7A3jnLn^T!VG8L0HOZGf zb*c_;lVo=T$Yti-qcu3#g{E8~cr(yDNN4wwzfa?E_=6v!Nz4Am)0Vn#}Ow`|~LiCOe$rB)47py24 z=&LwhLCdtwUs!`_K~yyK9Uvt!&<5ei zN;AGgn|&iRGz~X^=6D^nY0$N)ihTWW{=$V6NH#SGcjK4_sKSA}0jW|!I4$(v+rXrZ zm5i(Xd2KBvB2*Hgo>Q$1G6>B=TTX;h{MH!e76**hW%;)o7ML0?^F^rVT$j$d64Vdt zfQa3Z*58H|e-!f@C?19dsNIOT8aggHRDv4WMm@ddL zCel)2DN!&y&T63Ad#ZKEwg~+kJwYUbq0J(T8@3$;-M7Kx1$R9dLIym1!*~&BlV1JL zKDx(wZONA-8m-L~*;>F9vQviz7Epn-f;)AU6;>$R-JhVn#rlIZgk*d{xY!Be9{@${ zXi8&4P>-Xj9>1WSNo`43<%_nq^QXxk*$i6K#DfR0|~C87E*1n~>1Lc_PKYJK~7mpt)Lb@Xu!)1RtusIf36H5<}&I zc;kPl&@jD#aaKO)i_pQo#s`~c%Al^huE#cR2&04dnxBIn&s-^2X!X&8JU%bc2)VylyVe0h0!Ku~yV^%L4JZ17@>X&$$> zEMKDfNv{>@`TR~4`yRo2Q43G`gqO32E~gnki^TlAbbM$K;7K^PR<&QHPKIB4I%M6* zO;zr z#Wk-}w$W5KyB}a8lr^;t31g4;`gx(dzShXXJHuhUCYdxwdVszcACgp3RsBDmw*LRS z;(nS1$2{rryu!f>?W0`j-~qPNqZ~8;al~c_NWI)19vS)C)3x7QOF&Jf!=b`9CWz|8 zADw=4q?e5zR0>Ait&3x6x4_7LS~CV_>?!am(|c5|!Alvy$_p2V;E@=3;|{>Zk@)oz zEV&ndysfHv=XQA`<(|Vm60eY9-ndqdlzW#OxK#x;wyoyw^&JkWxal7vzTS%w{kA1S z3JF?Gff$JlKSF_Ic+TYY1y@n^+avuZ8mHF&-JvV^*D&(GDKx>qyG24+Hb0OYUY4zG zn3_keI}<{zwEmxHHpx11p-WH0Awle7$L=4&%W^7IH9K9mNEb31$8u)KY9du7bVR)? zBjG*gi(kV!?2dcdgi05m&_AIUe11?ZLI%eMX|PCUdFWfhs7?&j9Dw>o=F1lzi=9a? zY@Nl`1?^M=4pB$npH+-8p!c=ClQC(vjVCL8m?43#G;;e02OiJm|F+E%Z8KO5GFvTo z56UWbwsou4h5h0@a&H%lUwfX{QKPIv(Q>9J^DLKhUE(U#Y3z=>70%gpHm`DXi;Sq= z^#D3U0Fo*w`S0$sk4pVb=0F)k4Ug=qhW`G<_a#S;4vjy1mQ8=iBFf)@+Ch3KOyJ#n z|3shuh9~TF?QF^x+$Unhy-Y&`5LQvA)F^sR0zunML$PMyN)`Q3%8^k~4l*z!^F>A_ zJ~JLDJ|uLQTK8@4&h=l8>J}TWg^RQ%#vJhp4i0PzibY{+15GM^BC4eR!ILhwZQlww z+7F;UkUQddz|GJJ;S&h7d(N{;^2TaDBFxBwH9ZfnA{^j&PZ9xgPuocDjHmNy)iJI z`T3(YcBx`Sx+rlwB{{7~V_eFq+wQpL!$>A}2zWbyL36#mJkc3=|0Txib#%kSJ4;_3g>?jZO(_{oUqgCA*HD{H+zI^|EYONABr3RoukE7RR6Ve{b$M zyy1yx$&m;m$Z^#k+PYywP>HKzRQer$tJC?ERijP*7ZlmPtgIK+aWQRYZu-hm?iSe2 zVtIgz3A(To1+5f=vE!K`veTFgjTwEBg%w_|oaZkeKktT007eJpmxWX~y`#9g?w zFnkN@Q*+fG4%Q;8eE%hnI)RMmDUyqSc4oX-ckbBBfkY)6!!7U{Ybp%At)l&eS{tkx zfM&G7*}Vx?r?Brp@&%!m&FQ@(2DOPUh8W>4J&Kpf8c! 
zJ3LdBG>_b^U=Wt5N8S zAaPfm!n&f%Zq9q6`DCIv>}pQSa7N-n>aU9QsFA1_lNc zCkOG9bD~lX;o^#i(CHY)agnGZ0>OHt`vih*5ZxH4=_382!GyBd%rLy3>cW4eV|7-r zE=B(E$rWY6V$BG!+IVUxHA+21N3}s|{Xv=g`~w515$~c1M34@&-E1&|68APhRd}2w zW|nJ?T#_xsErM}V%)?6y4l#N}XoO-e^xP7fn-ZWnB+?hY9ggn+l=y{&=z;llcR$n# zE}90VW@b!T zUf;1~$Ur_435$0H;i42qpCkJ9#z?{51awvZ5KPyuk;;weo4`S?uBZ3^q|YZJ9!xUm zHKVr2+1Y)JcwKfluO)H3u9(Z*_v%|I5sIB%&TA{B-J)-Hrb)-s1~x!|%svGfa|>2h zbdLpmrVu*{mw9XPw_IglvzB&iX|XXrF{Z71r_qvxEVPX(Rk44L45bq3pgON#AcCFF zuA9%VWoE#(@67#e&<{Kb+Oj;hTUd{83_lXb%351nTf&9lIXrgzx&eZ%5j2+x`!27M zk$Kbn6(ymu35WB#r#R4d?KX_{(={^Mc=BUvsB01iP{IIYjrd^fqID5QqSL>AjY6R( z&m@TeIn;XQ6O9I^W<*vwg0Dc9UJ$4UB3?F(*SCc0z8%FHlQy0_Ivub^P=Zi{Spn(w z62dK14swcNyB>Ku@uo-PA4pakhZURE?h5`SdJ}RaJjO9%rZX_l<)`Bw}OqtI&?? zBD@K>Qv>`f9&_I|cV*Qq=HeVNdA1LgGjSSQd0}*_ z#DvF2CNM$i17%VB4+S!Y5Ef@3UKkLH+^Gjh&%{;-CP3USXy}H=#$KX{Vqb>yzYdA7 z9L@jISLesuxRth+r(swWzKi_%PZZ!epa*EqgKfe`(a)^Z(4NBt4k(>1C^P5N9~~?z zw5SQkkrfma1Z@BzC?GXRH|D{(XP8J|zj*^ka17eJA6OV%w@ge;<3X7egW)4%T}hfQ z(sV$b8sr;GHTXbPBLskM4R`)R`$EG&z8@aT1R;=L>ZIgZw4cjWB zjw_OBXIrvJPj&7*yLi*u%1krnf<=RX?}L+|NQpm+nMY9}v>_0)P|#MjH5QS1W<}5595Z6V4{HYOgq0-3XC|gfy2jUzk?Hw(h@&$_-17lc>Of zF#u+gmv#D(44V<=m&8H;K|?4}lAes&)BlVyDaS{Mx_3_#SE(HW^?+}fVPW+-%Jji< zR#H{~88Few;$OL7RA8o7TAiY!563c?{msJJ;3DxEzlj>)Nk1*pY8#fHkwyd zJw!Z_xR$+UTV*TeNRlmbbv6040H2OxwGI+55N;5#!Zf(D>*bjk8Fg>yOl_S9roza? 
z)HL_l^&Ts}ZyF>a*sa7Jw9}%x1puJ>kt6y*e|`bRW+W5v;Ftx98BKf854aq;biu{t zE_}NMb&WC5zn3_bDsZS1J`9jPNm4Cd!JiQrhRg$il81zXbm!b2G9Dm$&ot+OYJ9A# z03bwqzXG$v;ma|kXMY9)Z6bj*xt1G1VnH)nF+CmnCd@!-h5e)g=1_P}^?Hd44&_rs z!qVVCFvLYo?>kC>;OP}tK4?@)&&a-3g9HaLv=ao1&^J`vc#>*;n0)=@ZgxR?Z(Rell9?Kqf0(ZC;n~ z!9yqK?8h~8-Cuv6su09obPMUS4sX7d_#0xPF3h!rS&WnB+2Fe$Tuwb+k|(feav?=K znZZSj#>4Q#3txoi;mbaN`v-mlxcCvrE%A0b?a{$>)bI6(YbRD_peQgQ6J!i?rr~|0 zAKAJN#mO)@{t*~|+Ae&MI{mT({8#-9zJ8p72hi)ERF+Vv05N|E$B@s$AOqGT5sXWi zS1KJmcoj}!B0Pk35>N}_-w14%wY`Y*^VnDr{*-7U@TUY~xjboxGYoSn27x-@s8-~e4CWL?V0Dj?%7hTra$4i~22QWJbgB7v5h_oNR zL=u2ORoN7`!o2`^23Z<+5)G=v72g0}M4=`PKVP6xG(H7RRrKT{&DAdVVBedjL6RDv zAbq|=Ri-aHp_@X1u+`1YE!QAN1-e*F2S0^lAjA<8-~8!AE|)HiL8q91qAUc@l1fG5c)ON@ z{JI>iNjK>iHjD+|PH;TsF5+jC`#H>#`g}blB3r|*vuo8xo%s2covSpr1A1*h;sKHK z_YxyF&`df%T=)1@Aq?ku)`O$=khrbkKKJeUpRL3@VIWJLrT2-eoJpPf3r(xxIYRF2 zN1AI+e+&Kfdoh}hy;o*Vv1pY|>@Fze|5LN;q|p4e$j1Grtfh7JV>_JguKxYfMY*oS z%;(@qB&g+cMV}1(6*jw){@FWm=ZOot%fLx=oK^@JxR|^}58zVF`FqBnTErsP1s_<*qo11;GB_mtz$Dvwtuz)Al=M9z^TtHDxM!{LvyT_i{!UOjmx9r+6=We_EP?`#{H zYr10<1rLIRXlwSN7?y0F>09MTY`nOwA46}2_6Xt@5Iu)fRH(09xk56HA*mR1_={YZaU@0Bg3Tn2RUO_`co!vmf{9-1pT{t~BDjzvwi)gn%Og25B# zHH7_`#tC>MXXww$=~NXuMy4pf+Zqw-*6-eV-R-z56u6u+Gk@}k|3mh@r62{IQV2SzCa+J<9Zc1bW+bw21@X z1K5JwLxtC^?5pr7OTL=tqSX@b{Abt?fIduuXu4E59t}&R6!IX%zGL!_p?lG2^rPqR z0lcOPk43}4Vf1Y0)^tTmtx7BPt+!jb$r>lhHt(`3%BDd#?}r=^@h^p3U9Br zx$K`G7Yd+`O7bwg%gPL#Kl;TV6mO>-9a|AF6Y{Ue8c6+k8v7~5uhgFNuiTg9==SZ? 
zR{T11YUYrVK&2LyzP8eZ*n?vYzWQL)R50inXSji5;8#%KVqv*S9eA0vC4hGjk6R!RPL!p%FhpYZD=3iK z6!t{q@UBJFD~eST{h@yJ=<%8KM&?!wh>5n+!ZO+*VXk%Y;!n?dMs&)+oHp;fB!-89 zgG(GDyl7OSnB#ueQC-K0B3doTqDBgW)8=_)%l%v$3O5`~2~# zEMa-@mP#=>2$EI2dyJ?k5;;iupX$yT7vDaySMJ`=16r$y)v@&j1 zUfDGo2tbJl4naOHq6%S8(#UK$qNNq4_v%FEbuFhiA4wn%!dTdwaRQO_elw$*ED z8^{N=Y_E~Z5`Eloy_XsLXl1Y0SsB@)0a_)CgGoufOV_RzA559tCfli>LHo_8wxq9} z!;!<`?(0flo6?Jh`jHd*M91EB`hg@6sj|+paTH2+?;mCi^`_!k0xhfVm@|r!l=M^d z?!pi)M#Q?P>CZmb)17Q(H2q?*M)5&t(+!3m>&m3VQpfj_{6V_<#EGYmw?dLl5MqN)NuTKt4lZld+Qi}sA&hDt6ot^`7e-ws+5FI{ z;_nV#km(2agaxY#`aN8>7=;}9ai~#{2~dmTyaEzf@E;5#hB-h*YSGIte#t9>02O3@ z1nI3SCfcoWL=b8X^+{R^U(y&BQW`khB>-3<#RK(75r86wbj<+KC>|Vq0NB7=9@E!9 z4fPGyvxveiY>vYi#|kB#uE43epFh*Fc1d7zr&>_3DTNMs}cHUT&2y@EzFi`}q*_C1mu2*yUrT=b5bHZ3xEnX+5(_g{O* z9=>$(_Jib2s-^|ym%69E-MDv;(O7ur%s(+3z3<)$E1bmPYaR!K#SYLvI2OIB^oM(7Tr;s0(a`5_HM7AMhdIZuwywTSHPZhRb=G^VTjxeAw zo&$+DZZHy2CELL=QkIgGq>NA_QT}Kbw_&OTy2l^r3b^QK zQNo*7dujY#Std?lqD=nUqdn+9_~lCo+H3;eL4l1gDSA9#9SgC9TLRJ_2yvw#l?{xD z@(@GQ+qPS*5x@-U1v2PGb+ZGk_C&#pQW^z zAG$PK2f79wjU+k`(BZ%*8!*}L$V%e8l+Thow&e>f1)xDzQ2hagr1!&OCf=bqlnt2cb$wE!>5y4zWME@EH>owfhU}Cq~Fkl4<|DO)wBA z6V1#>FOaC>CB>Pj^+Nw+R1)@A;LV#ZYb#OzW7U%F*RJ?S>vc5%1Ckk%m2e#GvC?@h zN&TZom1B<4Szha7RB+rM>iAbeWk{}ch5pCu8859b_Z)UjS5VN(JO)M6Er`O zn9l>f>MdklawDuZzs6(sOpcDR#a-)|3@0n94T5Izn$*%)Uu9BE^nb1fw=$`2&7+$p z|DTzO5oZPOb=G355S0lwP8uXexqH-fsfnxvIyHV_VRb%*RWohagv5dj97WfC8wWis zyU0`w!u>BLwGP#Fy^E7Y99!^6)nJhGERGW*Gr-NG3_#^hPsP)xX-I@7oPUka@?CjX zp}#x*VEugdU)O|NA zswkM~bN%-ZVuusoQEDAH&eZ+?QT86-T=)IoFiD9zLqn#0#-|_yu->>!N ztlP#j``LcI6xoy)K0e+l+L5*@dN3@%<-=M`!UZOa^UNcz^TtrXOmrk~Lp*~sF$=GV4K%cy$ST;|&a0Zv8J;$2-wnyTN_y>hv5`~sZ>MeUbI!>5gc zX#Cgd{JG+0f4d1yYu%?$=NUc)PrPBLw-}zBYzB@^2!60hu!VHZ|N3zm#G+=#I+g1o zQNE(@KBVw45}yDQ`v*70sPh`=9k{TH2XA5_T3jseZUy)^e>S{-uPgUHTp-Z9g?hWZ`?aGUYfDc#ne{{FJUED}bNx#4jZbXMk01ql z+Q0tPrt$Co8}5+Eo2@!A>g=DkaPj-sRsdw)TQq>Pj9jTigo-Td_g-_S@=*Ijai~k! 
z;$hFu#iI+l42PssI8(2?DYFOV$n*6-J(5?GY+`xjh}xe(pcd}y>N)aOec-SJ^ zrdga@+BDsSZEpvm%+$jlLmN@D^4hEc{=xm*Ab)hw5Z71I zn)PQ|`;Q(ju7BHm;HqNl!t25pSe$pmFXL?w#ia?Gx26`V9}RWY%f{={8(-B{8Vh|rAM_q2nW18cI0*>OB z-wF=(kM8+=@WMK`WJ4{y^t|Jto8`$Wj+!w{K^%&)C)FR+rqv2h*77iRGE!W!8vJKw z@)Iss%wILJ;wdT3%s*qFQl9NWCibhYr(a*kuQw>08{V+`&yUw9-%+%1Dz|F3QXEn~ z{N;V&x!6sTS^_mj*E`^nkSmq|3yC@#_WdU40b|FODrXyVu`-mU~M>yGNCjFh0k9L*4;fHj+tDDYbAIrbl*dr$u6nr&5 zMF3dmQEIGKMerw4ra<#pcj?K=*@^$YVZ?_*z3v&U!>)EgO`$Nd6z1l@I6y)F658%J zmn2J1^t;G-2_7^HG4FNt6bdrid-!>Wk&S-W1287KkzbG&RF|BI(ZiL%LCZWDyX`;Unhgxo8MGpoKVvAbG zPkl!%$WgE7`ma7~3Y8*NeIybaKTK+Q`)GT7O>q1bMwhklq&L|*)H`oxqZcV<;Q43z z%mDV`Ry)iNo$9olDB^z?*yESm(bL_q^KwyAG1K`&6gC;A5oYzFzT7EFS{<5eMU%Z4 zLP74RA+;8dgw@|aE108Ix-ar7DBVA83IO2K+?^(-3uJK#2#@BgC6r~W~FFEt=jq*cN;MfOm& zY3?*cZt{f-C(-Q5m1*2g*na$f^m^F(+j<78eG$XacU@@o)kd40j1^TGGT|GBb)RKG z--FmFM$K)@^RtoIH78u_Nb|krxmC;7bES43|86_^m?h>kGsq>pPq<~fpw_vhOkI14 zaE|@^VTj@BbO)+Qok?Y}j{J^@Z3f8iuRO=*}j`gD^} zy9}&;UXG*r*>F(0QJ)itd}cMqlB7de!Nmz zYgdpeQO`cWbU%;Kd0tr2K%zlof!p=qxU9xa6h@p=^iCe#O=hGJ> zs%Ot${IfiV6_`HQk%2DzD-aoI1?peEOnCcK6u5EfD0+3xN*70BwLlAL|Lr+kNdoSa zfyBfccm9~YkxiGqJPfK}( zCOzrt7Djqj)^wX~>-X*xd&8fpt%Rkem|Fp|@n$P4C^Ys|dDCdEBguZ4OMt|KLDjZ9 zo1f=6r3-C?R?@Kh;3%P2(10T9;`{2Z@!K`Wm$vSBjkr+EwD}+~+-AnziSCimO#rlQ zd-jUidlOF*8bHmNV<*sf>ck+1qf?e!PNd)Ih@bWRxr8=m@dfAb_5ceL+(l(9V9i@o z9lI}l^Ip-eG5_0{H-o=vQS_J?jpL15xH?;XN$JSMh5p;z?Q!beiG1PNe4&dHexS&4 zA>SVAxzImht&%Oxti`6KBB9dDS}N4?*g;}EL)(k;F_eE0GMM6nVXU_QQojZ9>0a7= zmY7N9syzRn%Xx2N68ayO^B#p1$hqGNrIU1RFk%^mydU%rpB*_r8m{`MW10Ee8PV(nYUq+WkvO91M3A8FME`A2p~Yxx9MJJ|49H`N70?o*AA`+}*sy`oMuA(viCeZ*>9KkO1b`VR zCZ8)kFB+r}ql7`3)?J$GHVCey%*-Y{h}w4kK*wP|efRjQbgTn&MvfDLo#qL zR#fMa>a9AFf8KR>r^1*dcgzDZ@aXu5kx@uqA>8G8`~VV{iP~uZ zM`F%uPxd)^u#o|^s-+o0l_^7T*n>mVvAaiZ$p56zr!%4V3s5p^pKn=@eP%Acz)hgV%NDhK|KkS|`G%y!`?<{g;o!XmfQV=bB>C z$AOT{^fa9F4cZN1(kE6=IWA1%EZgMC7P#3mU+&(6*`A>krGWA1bU-ASX%7T^g?X$I zLuYtDQ2KjS9&=$wJ)wL0uU)lu1=SyYmF|TKjm*4YGUiK{*++RReD2fJ$4e6sUdeLs 
z9#fkG&;#{*fThWu8WKT_z9yr8GX;ryCW0y$;&T>$q9cZnVlSM?$b1O@=D9M@g|06Z zUQVKx$AW3#Fon0G36qLp6aJz5uv+0PZ0;yCLyn08X9%N1+^4i&*~kA;mvBg`{w;zy zM+s&}IlkHl+Yhh9bVcec6#e8;g;R+D%s@dP5xG%U0w`U%_nx2MCc;DtAZcd@ut{u~ zh8Elc8I?+A^|70S?qTmHIfp?`!g(CTQPqmIVxFF!c*T=5N)X5!J4)B{?wwZt-|u=C zGe4qw&OzE?YHf{g$$*HC@n36w;xl=Cfw`2`k)N`7611t;iyzBN)VRCzp`Z=&G#=DP zn9cZ8AN7+gacdH~@J%E6@DgL*g{v6~8td~DGT18zSN@ioeqpB1OT0KZUbYp=h}I}< zFm-B+W!!1_%^B;HEs|OGXTP! zOCYHOWD?mqIM%5fL~q}@lh_M@+4)69MTMPPg|~u-i56TfJKGv0Dzi=8@f3Vc(cH$l zGN(^lcRytY2}$U{Y7cklFZjPp!$yE_DsN)47Zt4vWD%0p0@xP7{t$-&F;A0i5@?p7 z2_+G`gxCdk<1cwkS~?zkqTiwcE}lm+0J4+r#KrBw=N`sqU3V)ne?%RPbMlV!f3G~T` zVf;!R%Kcic@8GzF!f>adfK4&Y<=2tYT;;B-hhxU$hK90_P!#<*vU1^)orF7^Md8oL z#yal+w_S^Z-~V1}NNZFr617p&eKu+Qw(b)(LT*5$+=~72I9v32=Sq8Jo z^$WZHF3)Z~a_qel+i#`8s%)b9IRhu|?bu>&{KC<8vefS(n`l)@t3vha(k?;_ zlUW8Rm*xr~Uf@1LLdFi%p)n|+N08=df!KS?>HYin-^KOLLwo>CV2Vmg{za?sQ{x?J z2KB3O@+1XVPO>@;>Uo?xlFW|qD++oPw`Zf9IPLcIt0{$vY-_YaMvUZ?S%+bBunh@ zp~c*m58-1`5oy6KS62~aH>P>_&tOfyiWeczEHkdA_aSy%&uCI!S{r@(r;E`1vU zf(HJKkN}K=|4Jc$V2wWgqH87sYLrH8br?$k670@gJQHj}(nPlffW=v>2i%-l5Z--T$WcWj= zZ3bKv0udF}Q;SF7{z*X51{4!WbN2| zCn@ipZt>aU)cLG@hD0_lBH4s`bb9UgG${$Kr0{$&j<> zMN53Z!ADU~4fPM_caA4t-*WBQs!d_WtrY4o}6wYbVN>EP3n|8%H7 z-?w#GJBA>&98XTp4Q1t*S0aZu#d|_t*PTaR$Ne`nwFjm`>b& z;UivRSWe#TB=RNa)(?j3+?k zs#bTh-2f00d2$eKia`31lGimgWQzzl_kFmsho+`lAf3ejF$AFch|S{U9sS6A5?p_N z1c~125&cMefR5)@)xwa)$gsWx!k$!8#ZoyhZt%|h7XH~Bm=*;nRWu*$&H~HD!?HeW6)%Gt$TbP zaJYXuc3r#zjMl~SGN6EhPrpk*-+M!Mk}X6&Sj#J5ukpU4BN(wJv#=_l_k93s-P%vj zwe({VAuO!QJCzl2XTJ0Xjf0iH~<#{jm|r^OqdHJFv#AWAS?(q3atjoFX{f( z(%znq%7Yhz%F1e7Pvk4BT-cq~R+a0hKEY8@2A2ji**eVnBJanoh?P~2oTzix+@|9` zy)qGUym){4kzv|>-p8HiQz&ZD+xvg42J=nn3vnJ@V!1fW)61A_@ND^iCwd-w9J`Sv z5Lg@p<{*oTuYC;cF3Cf+{?O1MPtHJxIrzY3DW81d<6#2khi|xS#4b9Bt>b{3P{4k#9ectpaFN}DC@C5t9&5kYg z`!S(&p|9lfrYFX*B;1A0d^`ceG*^F z%k4|6u9r1L^3pPP=xZ&rd5Zmf#NE{1!g^oy%+pFITaz2-Ez+8LezmZVU4QV{$tG^r zaGY!GNi?@NJ%^Z}0U>A6uISS2llAbZLU*ChnZix8&%2kt>eP+i{!@QkR`lk03cp&l z7v2i~-5v0eJAMqGyJld(xU{s?w;G(e&jj7WUn{pj 
zdJb)x>t|@L55<4)#CD(kVB%|WjA4x}M=j&^+Cy)4q%=DxHxF$fwK(IW)R3_NSy`<6Fvn6WIz3~$yjh^Jww zlgJPJSmCTiv#%ECkT-e&v>Z(6@XSEB)pT@dxXj>nWJO(M!lzxD787DgZ>zs(I0Iluk0IE{B?4e zHsw%{W!>B=e1n2m`}+FY-38GepN-zn`~!IrJWIXYyjRYWjVYAAIOugqG$;7z07&i> z7ak^|kOx9g$bc6lycyuZ0bp`TU*bX22+f&t9 zUw@|aT4Xr3ZXH4#%imw%HH?iZ!#|ugG~5Y{(Z&&8M&%nnZ*>-sq!QSyJCT)Jd;)r@ zaHAObi~M|isBy2)$jjeHjq1Qu0RQVhOeNDid9g$)ib7)vsmLG59MCa6KuNQ;vxCB! zSF?nu@i83b!Cru`e;NWNNR*x)8N+qRq;_JKNzmrQllE$0U^VK?!3!`m!vRnTX#*7d znzohDs>nDGLqia7_bwHvO)|CMvUe0;*5kGO67Zms2mH1IDp1NyB@YUBio<$2lgrdn zq10$Pp+T(2lZs_lpYh58z{Ee||Q;X}7#FduiJ09peMXqL(#)H!8JT4z}x+vd!hL zxPRBxx%5lp=#e9X)3#SeRb|cA2ebO*Xg)L6xi?GzAlz(k}AtfDfRe4Mum-wUJ~pZc>da{uwlN0<;`0r@#UL0 z4^gBGFLiJ<84n{ODsu7!va)c2vC=>Rw;xSl=WktH_teW=klVlvng`bmO2WpI*olKt zFv@TcevP5KTcs$pbG{z6>H3U+iM7U+k7AKYUJsk9S0@g-#u&r1R*O{(JnL|wz8r3oWF zXuP`2dHue9vjsr$8b=U9{3H6wlr6owxD61nrZKPyYga`0xXr%WXklrI^Ra>ClVGGG zV<82ae{++iwa1tq;aE@ESIGl!5+5YRM52Q762 zWi+>I{Kp)oVruB5hc>G^<`;;1EH$)wSBX?!5_x&|&SPttlQEQg+=N?pH5fAcOX_MV zb0uYld*u5^CuQhn@1*s)192C}&-?f9n_F6P4SZ&zmD}g@CUovRIa{cGKOkd;ICjAv z*~}eHn3a2~&f|Y$fX5Ff$2`)Hh}Hg+=bC$_grMLStS-gtm9gaa-~ka2wc8dxqr?xS zMnz*keq3L(l%zc18s`|5)xg5jedPdVs}Z81y>}1$nN@K#SVL=d<4>q#d_6Sc2riRa z{!B&GYS2yq-HbUh1}TcBT?MdNqV`bTG{*^M+o$yxQWo)b`bC&&DbSL$@85s=`-`KH z`k+;??aiaaC&SE5R#kKjFH6ky1`J!XmM8H{8uBA4`Uo-JA<~f0{dnUQ#~xFg3q9Pe ztE)>CT9_@EV$uYmDd5jy3@avB#oFvZ)XeetYkR!>e=!^M+7Br|oC>}r`sB?*n-cBz zEvnSnQrTx0&gT@Y8xvD(J9|J_kw0i@?Zj3=F4mg^&xOexv9``OvONYVqw zUmTyx7zq814i)C*B>-H8;;9oT6DIp4mb*yQ34`v(485E?t&9)=3z)CMh z$dU{H8d>vN_&I1855WS})1zk?CO9bpSPWLG)CKjC!WL~>-C0XQ4X!(88AxcuNFQwQGKFQ^?P+X}N9?-dnv#spa*-fw2UioWSuA=^z_ zh7SZ6^BfPK)>6LvM}jNbaA$A%V^iBiz4=$|w~>2+LEjPu5&9=hJ7H(QJ9t+cpt?;J z#+IYvUCrnwx;$)=G~Mv()hWTpSFdN%#QWgAzp?*dAWlhg{SXS?23WRx$=Sul9zOw0 zO^Qy+_$w4MOao|#x|r4@wZul+9x$KB-xVvcA($QIf%u@L5j*~)Cr`33Dr(X1vc%Wb zjKPa#fvpahJ+LM0t5x#|6iDO9Fwd&T{U{hDnyd=2N1u#=pL^h zC>voBk<_U*gr7s&YHDTG08b+ZpApa#5qESE!}9FWZn8}ihj^V-9_&IS#R)>8mpwg^ 
zP*i0Z+_ki{L~m<6HgoFCnX&4o_Ziu?{@SYA#Yc+_KH41C-5{jqx65^6-*3+?`$d%< z2fIbzr>#a7=Ve}8>io4V#1OT;>7MVFZ{_)P^4sK(d6w?bbI{u%2KCOTb zS|7R5=E>VZae>dZB4)Clg%y7Xq$@3QGa3C6`rURxwc8wxuV0_$1LsoL*H`vvU1xv) zPKYegEMpM}HSlLE;xF#)Jp|Sj)&#yCdr-x9p;;%kL8v^+0JWf!*fi0k^(5ynlzIIj-NO|ytZcM=H0FT#HWCy-oAZ1 z{OJuq(^Bm5+M##6i<*b1)ggER8=%2QmrDLLc5FT!v)5)M@>4v}#8*RrLwfp`&^&d2 z^uUs>CcKFlV}e`vklZIwhoGqvqa7at#lfq55j+4^CT;e@i${`_>>V5eiZ?X1wvr>y zZLvch-|9MS1?WylKpbT_Zz68v{xc6gBT@4u{D(L>?O`;`j8ESoX3vR9B8l*IXm0=x zn?r97B?(YX4Z3T;kKhObvT6UV_psuoLICoxpqzpn5Y}?CClb&MsL1n|z9p0&Ff%jb zWKFpf9!x$JSPWw9hR_^A2qzQO32bck!6h2!4fLChc+Wv@kZ(`89L(sLsq?W(VD0bB z8bQMj9y}|Uye%Wl0Ew81=(}5=h_jNT9U0@$rS4N~66P2KO{IE~R zA={-9E%5Ww$Jc`U(l081(5g(+YyKE1%WtvAuf)4MuhIM3P>=?a%$=K4LuL z@yoq_FI#)r@9~TblVA^>U%`$C10W;?do_V(h*pw%*ZyM#Vm2ZgiW=I-p0Js1R0$KKlZ5A!aK+em{yMWh{v`y&y5aH4Uyyz?RkYooH%0Ahw2hgazqRwM9 zWweGe+-k^ydgLfYXI%1N&db4*QVE+XHcCvrk!16gaddrq`lzDt`q<{TM;QeZu?e9vW{?U zn`46yW{g6(*D!3&E-nK5E$FS&g!~@wwj74ZBRdePn_~E6-Q@`k@3<$Fv}K{V{fsz6 z(Fa4bW25xAQ}aJZhY8d=sOnMMiR0^TNV{4@|NdMr8$XE3&{#$bff+SG(ckM@Svb#WaI_u<6;(X-9d15i zoUOsIzkTOUe~f>uw0wV#7#J9s;S?piwq)RFg?`Z=r(yu$4i?uOS{w&Hz%WBfFr@N< zB*g0BGW7kdC=x6%V!I6Gvo`kOT{OLT@c_j{wCm#6;O&Plc;yupF+sDO1O$)pt#T#% zy!jGlL!~F~oYY1?5}NY?#7ri&qtvy}UI@Npg-4UDK*4T~QuRlZXj3}K&jcJnfJXoY zwV>9#>3wonpzt7KePTVvO3ed627&4kyRX~W%0arT3#%Ct!Bh*)71k2&#sg~ALWWmt zc&UeTgK#NmAOFMa;liC_89v-B#f}zLKKaGn>j*%FJ3z4x4>X zEau>rmPfM~0CYg%=VK&Jc3iVNtzqTJHuobjx31%Zyj)em<4ezkg6B84o&!|_*N>6) zZ)9A9WjJjFmzVkgx)WUd595>Lj&lWZyJEg*Q#x;qWf}I1-6%IN96$HdoBq2ueZ<Oee>fn8{m=~Rvz)A~w|=!pC3*5NyfMnq1V4Jj4!E-xr-|H+ z8%4{QU40NF+AsbKN@yRMM0!AI9Q*d2g5VD*FEl&^6f?t{`7fRYJK8W%WCj_g)wuYu zFir}r+gsedUt=8DycwpxVMJzgU{?@mE;T&cuy6wi^u{LGTqG?pA^~q0ZUTaenPJj2 z!?Ywr(`|&D)zdK1=o(=M_iizJSrTvf-Ccy6TOKjreo)=|!d`=L=d)ys^|3X;rOW`? 
zN%As19*k{hs?0Fcn<3UNYl*CmlSRcP2`{LN-pDXd z$Ux|tf{F@&?tGlK{~~Bahl4!-f4}JT^g+@@6f@A%v!lta#f6l`-NK~(_|n^i+n2z8 z&hFDOc%ws-q0{$d5ct==rjE4|uCvR@(v6$qy*t;x+ZY zwKQyIT|a+BO>I70DHXZE`AFT{1_H09xmgKQpg#^-x_wIN-V8F1Y}y&>r^tN+*=t1F zS{u4rCSlgrh~pry6;2O2^d#hfK(qDT%+ggW{1ELvrNEJ>@NhX5mCXo(PhI-% z8j2Wkb}V_ux&Nz7rO)E;F657dMn^v#)8o-P2APF2WRrlJU99b@F8@sKk%JlxI5%hn z*5*lFMpSg-j=-^Ak*E0dfeIycj}HF2DEEuUKICNWrncQ|*V}u2KZ{HIKS}3zP7f$t zY=7Efk&-bi2nBU#N%h9IA0sD+Qcv-QZSEEKTG%w6X{PP7r%kf7`ll2|r)21LM{p*G z@Aj#RkZQZw_N^*EMl#=3|7#(=-JAI`y^6-wNxA7z3Nm#d(ZK?v0Fo5R>WTUJd91K@ zP205?)7*0FiuU4PFGaoLLtWl)J|&pQuzSzns!yk)ww8FhZY(!dZ-@ZiN(tHny~Ux$?=_qhdKZ#jE38m$_mVmyjIgw(A^m362ar3)+`e z>myv;4m~>-5&rOe>363n*NLCYU;6iM&Jd$;7j%#QDw^nbbK8cSJA5P3SR#)Inc7*s zA@g2fze;cNA(!$ci>E)ePaJnjxa)9#eveF9KE_j24HC)&e2lbHx8{bQYf>&0EQ_2! znrMAht6*K4gnE+FsVx_7Kl2K97C|$50S_@AP$JdC1c4lJb_lL`u69_QXT19PyE*_( z{2DXnloE^!{2(3@x+Z(z(JaI@?-$N}6dW+g*qW+&Z8&JkNaB_jyjYasKFMkds%6UFI#MnQ*6aYiKNivzTo`3Ph4+?1jn3u z*Q{#LZ^Qd0>xj_Iz6u6(8U=1E_JGFEe@d!g;N;(Y{7Fi4YiHR5>r4|?Cymsk_Zz+` zXYlV|C`Tp0b4c`2pbhcjnPawuT_~+WF4S4EQ3>xF`4!b8qj@u=eBBuDq>f9z_EO)D zvTti9+pnAYX&Ss?3AcXVLaq8{En>pO;4!1H4_l9udiOKFZ9C6?9va-LSQ<;wO?+D3 z4D~NhQo_GM2;WI6LxXMe>}7G4&&jvsV*@yo7WkC$cCpyk?0yJe~N+bjKs$900N z?`m?W`*+&g@>@ zV;@-0t82YHvivf5$>jYnHU10pVb0%z2HL$UslHYvi;pL@1v8}i2b*8M(&8#Oc>3|C zD(MBuy?=gCh6bsfId?i{ubao`&+;30=y-e48144^atPbyf*N_vot4PDyr z8>Gp~DnJF6sArmeaN2A?af1iA5f%-b@XbI_56lAOZc*4HAH)$|jM59zt7 zYNC->+$QIE!yoc(i)T^OyvgY7?CiapF44L$4gO%$oXdd^gH`?9xjw3x{%meZk#5^? zQu*B7_p6c_y4Az{J3f3kd5U5nD=0^JXl7Ga@w1k#QqmXVE!%@X@3**kF$s>3Qq$uL zcgM}sE}!YKF`k(2WRurwm}fB=V@rnjE_2#DFJ zAYf<}v5=Vf7Q8zXCCR&9Z{4D&^-}+fcJDLWGHg?ng3Zh=?%q$`IL$KM zYtuJ;Whn4|s(dEXA&vBzxrRRrA?xZJjUum@Ppw%oABx;-bg{ET5lsSxmm#xFhfwm~ zz5ZLR$9n@ObGn;s9@qU9|4n)6e5%;s(X~J*s1BFEx`bDLL20`Qo5Y*-F3q<1?=M-A z&t)gCDCl}M@qbN*^P5=ya}!3bv_hy0TeDID*TejhT2V(aa(T#Z_=#=}GP3Zxur19D zLS45NW^#h(08z9@)Y<{HIDu}#<}I}_)FRWz8WPbt%o}DGE_eZr08hdtD$1%})=kh! 
zV3MGD0A-%gO;n%7#7WR)jHl?^OmR&$U#Pu!3{Q)u&6g%M$WAC=EEywcSaUNi9F;)* zBrzMtJ0A?(D}5I5TF1M0#Pn%9SXEiO7qQdlGO^$aV=z!ZY+%15;|Lm_Ktu3ZYv&k6 zK_r>=-2_}3057Df(E~Ujn~n_)`$a6>Lm-M}*s-Ij(seC5cCVD%V*mshycnO&N}U%D zKwQpOk`0Y#CzxDFlv4`uRd5LUK*fYU9RpiD9Kamj_@w?AS{3m6Wd}TbQyuD!kt;DP zOZm*b=IN!T%Z*J_D-Daik8Py(>`ff1wl;ny)}Rlcz;u6OH}iZhMrXSzypCQ3vu;-gO(O7p9md(C*nHvH%#f3pIyi9 zrIGn#vcd1_J$XdJJqJXkSpgjGv64O@CccBhO3Mme+DZ9^E&d^V3vH%}R1P*HBAkvw_(wIvDd($H zVihY_6mrF7Pubc&i`FS?{qQ;_<l#Fup}FNO(%4W z#Rd2p|8K?XI!t>wY2gsx$LoUCm%!H`u+%Eg+>O@?Pc2G=dPXk%4OAo+9rE23kSG+S zS;ymV1uTkW<6+n&A+g2g$`08Nurzv6DDJ86ZpI1CVp{|x8b_b>Yj!6 z*7WZYB?2d&`+B#ow>B}(Di^p?{8;+MAH}Ubd%oIK?QVG)+bXg2tfF?sT{KnrXWFqF ze@pgkIB@mlx9Ow7(V?^%MnMs5Cig7waYVU!OaC$}7Aq5+bzaM~dKBMM)<0HjFfLuf5B-gsUeD2Ez~8sGZS zB4mj+eCAbclO@&G8QQo zto@No$Fu%7I?2vjrvERUB#>CGj6LiGUA9Y4P2F@7a%xn^@}>vZ#*Fh0TEl!yVjYR6H{vfR zh&EYE0!f|QuZi>|IN9258-T8otz0&i(6dAOme=?T)Ce>5Es$Fhu_2Luk`hMt)Z>SM zFl9jV7BVUAyak$+j3nAq9j6Mm)|R8!rWfIw~0y3-oAiB*52Me zDhSzk4A4HaiiwHko)nIHe!H<&;qb&)P7NnVW1@as%w*PAc60xF`TEdbAK~W)(NlaA z9_9L?tKRxcmaksC9^_b@Ii9G^XL8}l)vZdYnz4b>f@d~7WZL3!nSDAv^OHfz-!dIM z$wn6AmI15zzv~j*@5;V;8C~Q0al3?m?}+F3 zreT(eWsv#k#w*Bzej^BpRi)%VhoG&1s>}?|UYnm;*mxp9i>1e-5-vfZu(e6wh8h<4G0~GX~_>KlXW%R2D3kZvcj`L z54-^=X3pC`U(wqPqe!vIEPKy4G>PZK;hJRCr%Fe5Tmm~fCpjzys;3*E?u99gobc$NVftqOBg+3_`c1)~d^08T(*%%G4d3m{@6D(Szs}8!i~} zY2>Co-@8vZ&x8u?4(0pS$kz7FUx(kpPTQ{}zK30qH)e1=mUsSmqPMH4ChhmDA}^_) z*S|NCj(i#XP{Zc;PLIzD$!FUvL)5}9Oe7_1{n9DIOPCLV8DyD?%F6qNgtl%A?Q{-* zpR;=PiWDdVXhjqh6mHcsWvV2d;u)l&AAf z3J;-xQ7kCKqb}aQOd4aTL^p3g{wDXzr`z_7^DUP^6(+ zQ}O*wBw~C})`;?8!|UAG))vv`1Tza{rd-V(Nb5x;Wev`CWiB*M&HHkf{es zCo$)FPNW1U5*}~blKKy{VP9aK+>g*^Na7$=gu>2TOhI75Z$s><}s*+J0tQz5QMn*=i zVWkys#3;)0%D2LHvVsBVI5%cQ&?8ibV0R2t;zA3Tj7}-`grh(H#JK+x?AcX?Y4|bX zQb9T)D-y1fGl#%5!ab1U0(~Y~3IqJ#4oeLDFcs# zJv0Rz?RI$h@g&IY=G9_B93e?o$3S}~Q=4VQ80+nVF-afB<4>k+@GG(q^ob$86cIQ&m@6wiI@Bkx{yuw8KU*8E{~ z?D?^T*LnSqbp!?+zSxcGHVZ!6lHKvXYs{ic(TLG_`|rQ{6}A<}-9`>M-{bQMyM3F| 
zv-Z!c-~Y@k1W(DA8kB|?jsS=Z6yUmX>(tD6Bj@rr?Y_Xif!&KOUtfSzZh$ioP&ZLC zAk!LB7EV=0%w7(eIe;9{Zl#*OzK1XdVir+V!hU<6%AWMJv;vDJMj|XikP`*TvqYDq zHRTFif=G3}nu{x;{UD2DC%ft7&qyKg8h14nh8YsM1j7x4uCK5*npRLhEhXg<;x>R- zr1~g9m|?y4e=J0<<8${nhol|1`aY0K6r?*?Z-I}Dbx-#(=quSf;hy2+Q$D>FhfxwC zRy3%|WV;gjPi&cba3d#vGo5;;z7mJHcnl5!GDBIv>caLad|cv$$>cRf2~YhF00ZA^ z@1yW=P-A;YxFXJZifVh3j!4Ss+4{y zL?+>@+{Y$4Xf9K2+M#gA4+wekM9^|Z%pvU(UFtd{3Ub-mQbJ(~gzIUllhnUL92_4? zqGLFS(F_~pU7$oX$?WOswZE>J79JtE3<=4}ufboZ+~THyNO%O-;14beSe=MV16pt= z!JKHA@e!>e7L4KBD1uoA8<_y45>qr$&ykc~1XAF6#riTY2%|lIO>o03?Uk5l8tl&e z@kU(59{35NSsX!0YTwLmv=gL0Z=Dq!#vFWnN+)LUiAjDu!9vujhaUR-gBzuyz*aQC zBE<2D21PF4eKXsMFf?`ZL-+R*30(2UjW2OCAx!#2_we)vTr5t5CqIB93zq6r;s72; zM|u$_l$X+B=ZQ`(G=WXHfF@ZL!ZQbVlrVOK zUdqz)9vSBqE7%v3wQo%-h|R`e{_ptFw;i6M+nV`&LLQ0l1N-rE?p#Au->n0CUECV2 z0^_gG{yH`JxLZSzFUqHItaay|*t>_`iF{EHI(^ByB4%9n%$OZQAe99_WJ1nI4zXvA z7?JNI!W%FVnSbeoAP?+x?SB{c{?3Y`^5h*^MzmORXDf1zUz~yS8k#M%@6a2(#JwUd zUWzwTcmA7p8L%I*nE+%-l>!EdKdGTS`v?<;LDhXo;;pQ>c`wkmDI1X3Ou#!f+Z^_- zS4zbW*zFUN%S{Uj>O458BLtceey#v4$iMg{kF@}oAUkOLCl4NM?k};+(v1rZ)%yCR ztaPRYS%8>t6@#_OwxeYnLGgW_igv8&-kC^e$n1A#fut@KIwYV*BCzQ2O7snduSCa4-1dQ+&qHf}?R6 zqZovrqomOMXMYL{tpwd$tY0Jz4ZQ17NRPzLZH<@So6uX-3uGz{6_jiO25){HiwE%G z>v31RCVjlUXIFm?zJu)%y4jaF;JLCd0uF9`^JWw{TngZ7_^2iHZrFgLg#L~M`dY)D zp0d6ZV->%UqVg2XzqqWa-{BU;;>i}c;k?lRAz%>~SHy<*8PS5m&u5Lzq?q3*$d&@2 z!H3c7nEigLpU;7EOVqo3I!}?0d>__epmiiA0iTcTAqQlQhieGQ;(YJKqL4R;iB}s7 zLa+o#3maM{&Ws`k8{Y@cwd0ruk-_x{##803P8cLHl(hrKd)v{Ghzd`J3*dWDa%Q}5 z$67c!+!ml}7I=njvJL_&2oGB)+^%-}PQh`a7loV@OFEwzuHZVfqW|v>%MuoDfBm}k z{c4++KxjJ{79>VAFw;ifS~SFSU4riD*1)vKXk z9ca)yarc_}QUyM)XE}fSE2nx=v#e(O?u*?5=cP1O~ zxmeTH(?$4tZJ|K3@`*^H6A|l_!^w%KX)Cdsqx)Yf>Byl#3*;8!HbU2kdEWHOl{O-v z7hPN4sJWF>bQN(< zlQ@j&C^&*hoLP67lCJ7LZDU#zLvsg!9X{e7Iod(?-psFM5 zb8&=@7PdFXr6vMOBSJXAZjlV55Ma?tmIj>|yiiEw(bHiFzGP#Al-T`5$coJuYDi`P zMzTxfVypUPB&wTa58@S>M}NT*GO9?FAQ;GyqhNtbb>!-&ShVFZo5}0yvY^ExUR8AC zw}Yg9Wjss23uosX(#*-Spvm{yEdUX?1~LFf)}UF29=#4HFm5QRAdrAO=>^cvjF=np 
zoe^P-n@Scy$SwTf1Hvvl-3LEJn$<#r{z7XRyTc%|<=}MQO>1n8s{()w;F2}YJ z`kBWg+jZ|}k0IcsQRAr^MYQgf=JBIa>4}c7&BA{Og>HPV*;hHtXsy^samnWL?5FhZ z>vr0`t-tKXT{<0a9CgeW$=@%=yQBR@oXYnj;q^-o#r9s0`{LJN{k)j%r^qx3M|hIo zoL&`T%Eh;XWi>W3*f~;zr+JKJak~PQ%Zt}-G^KItx>YZ@{U~_B8Xm5d-sK*LKH`5C zG}krX>i<>DqdeJZcu0CHx+NdscoumaYed5g4PGrGhtLjUsjm;*;ecCw3fx)9LQ-Ip zV5uXY3CESB2wl4h%p_8HfT$j`1MwRH#US*>+{{dKm_t{Cc)%cM@~N$@<;XTwLc{># z?-D?zVX)mtH(&4>>!0vebd9N~xgM?%*t+?a*tCKghxbVuJn$vjQ2f<{td&Qz-zD%c zZVWyf^!Al+YYt@$hoUe)4VS~b=`e<23d(#s<6Suag;Zbt=;Tpd{N%*Ds|k-ak%k;c zAg^f?w4reGhI5|@qOMo2#dM#MQHk$~FTAa3Wxc^a7NS7q~O%6%P%GQSFc+nY$ zUv1Whhd$hWPDxVf_QBn)oa=B z-NkaRi?`HuC2Q5yxw5txPd}eq;P{z^LA|*zIPSAuYhW0^cPWo{-=iy%dUWl$UqJ*f zgtC(fAfg#?R`96m*l(f)d|pQ84o)5hOd-r@_=lnCV?qG9P2UlCF9n%seTeHDM;iy0 zhCvbwbv9T@R+t=Eu;C^+Q>%F_8W0|V_t8`F zL+!x>$1lfZPi}4m8}I*GQHy_p74HB?K{YSSkn#yA!VgGC0Fr~13r>I~z?72NWP*xC zlBWVOCtx**Krlc1EdQ3ftrP&$VFc9UM_w-}TxFmB-y@-Fvhef=e4E0sDFX@=xtReN zD*VunYyyx?;?Xh#aZ6!|xQC5Rj0i}khfV~+PHN@n1BSC(0~bq(rJX2cfEW_I^q}D6 zhsLYbtM@QxV>CDbrydNNWR13F%<(FILPG&Yf)Hv=WN@|R0P<3B$}s!uaV}gW^6sNq zIqM)IFqur~RdWE;$;zyp7OPIn#Zbnu`;Am{_MdY$zi)J;CSEyR;aGWCRd}?2?y}LB ze*RFd)e;uP$R8&uX#|Id*S0*)F^>+)wBoRO9Vj!wc(`KjV`Js3_lwz_0T$nizfHvU z@5o*+c++wQP<5WMlennRK<$}#_X1tmGPAXIRm!{6O8Ij|^v<)aWsf>f>)F>=KaW4o z&ia-%uaYFof7BZ7Xhg{o1R)NJXCC(E{ROi28JN~#QCo+vTABU_KY#2 z&wvM@)MfC(D*-l_ir8+zMx>gv?cU89d0gF6r60jm$`Jy*1-alA3h=e@4I0rtU&r>Q zx#2o&NY{HxVz}{tD}a_$R8gsY^@`HdbFF8|l~)k*-NvEK@K%FJxr>=8)BU%_CxI#q z+=%Ud1+f$^H}NCkd6)iap>igz;#TC0JoRnocohx&+bhH?3<1G4hu@-9&B(mDg&hm< z$q03_@xu7=X7dp)F026&-y`Mb>*uE!HPVdX0QPMt^h;U~WD1We-}E>pZho`in46%e zR`=bea}(z}8nU+zJav4n-%$FP<@~{>O;NVd?kxV!3(J?1xLHc%HcEKL(6}1B(Sp%3@-y9!h8Sw8H6;~{ zFC^$<)J*6M2QM!#(p#l^Rxq_#0=@vsm#8xMIzZyfw%fbb{>=PSvFa4!wCg`oURr7s zcNR>dVPRq8H|1SotUr_p0!u`H22rT)S^rQ4nyBSI-_9^-_AMM~={=aB@4`{Cc z|8H1BX_8POS~63FvP&U*RU#o<*)uB?GK%a?c2=^tG?103L`2ET3>7809iZ4v&*$?R&*x)3osLZ&><*8Kd{|ZxrzQCN{kY1;tnFsk>weE_ za8NEOFH 
zJFW|9b#5MZ4tDVHG}eWQ%SDHud!^8Ly3bok*;@HQy>q}AxWjFWh5}35|`rJb!wvIx9Wv?r%z?${+Uro|#gxGvVLH zbwb~$G_+sLKK%KK?F**;$?L_bo^ka0J301UePCH{J?+VD9cnzF{$C;6Jr>xf_qH+$ z_@gmgAbv}VE*NZP(qvLxaN@;Dfw7(A2#i2co`K&4`9m)u7yH)eZ~P%+_Q;1sEcvbO z`SSOI&SXaVYc=XFVXQ_u5xmi?cI*>;Eq`ch6W^(M6^``n$Wo+w`pah_AyBg6an~lW*QN^z=3%sGrtRJ~)L^BPH0+R9QJ1?PfZ0>DluC||wb^IV) z_Jyg#`k{f&NaV{0ZJKtoGW>lZYGAr0-~;M|V-=Y)S?Ac-~A;9iBeXAjmdN)PF(oVt9mX}?3VA`$p(nePAlrTM3& z8D?L+O}v{puk>!8<+nZ{MTN-llkUsMo|fOLT>5lCUae=_XSoe;nZI?>AJh6;`QnPl z*{KZK7rx>LL=P6cbZumCJ=y)p%kYST#g+EgUEL8sOMkyjd?5JwfoemLni`nmFI)YKIhv)eX%LwD!B)xU)Hv{v{5|KTappCPIdu=dc>&T-P2r z(E4X14iX9uy2l?eb`b(xVBdjTNB;^P@Va!~>f8n|$4YskIo5Ff*T#b&oavm^OsGL1SE+U9<1?M!t%Js4VduT8+S}UR z;$moD0~3?`Z<>JAq6u%VH)>W^oI@hwr%$hp2`f{nS=7X{%bn*qcacMt%KN!w(4Gzn zsZq-;{mz_EjQHU8L^xdGjFk7Gfwu#;gY2?`yv2&O|Lf&HOKX$G>bv{AoZ9}+jBS_w zo1nZgv!QcJt#MO#dw+EqzwFSkTl`>%q{Tx&=MSp0(Cy{Xl-1j^%*b%EJIBjo+m`S5 zc19#CtKX@e>vS$ZXVwgwU9YBm-jg!s+eJa z+q5KG5rfbsG(AF>V>DehQ`$B*K2X(9ij=l;66!cXALha&r=)bm zdWX@#>A>noxY|&14a2ZEJd%Bx!qkXvI`9f@5$)(U1cJnX3=2iNrP;`kWmJlQC8ixvoc7F-_vnbY2iz{-&VSfqt9(RBq4jM)DZv@Wr~;v$ zLdCBy+h4IC6ueVhZSb;HE#F{tghTkr;^A!~CnJj;3ZmBrMZ>t9m%r-xn)C5}67K%9 zhv?Qoo0yCKtNI{eDNr=D9Fn!N;sxzOVj_*zi%g7-Z82dNHxr^}xyuF|eb;74d3)2G zoS1kfN;f8xdBcYW&P4;LXD{z?+_9I05UNTms{kfJgbjU#zHPylRX0G}bwB>W_tTQ}QU;jFlcfOW&6c zOvZkV>#VV}cb{@MGw54yxYw|BMMy?yYN{waeJWk5DQ+^((D1zb{mc8;hs>n!L>faO z?gU98-vbVkFnBz+ZRyf&IVrz9q_DyiY z4fnNfJKanY6Qmt33OIdm+FEcT;nf9E!|c`HgL{qMZW4966Cce=l2uBYdWXjtrlV}E zW|hSHg#TpaUEMrX3(q0Yg2R!rXU~F*xm|2O$Eyv+j`W7vPZFu z6at{r#;D8xS55(fG^vp?;yxN!KS}&{mZeg_9Arj64xMgX7u+xL3{{@+9?!70n$#=q zIk6X>X9ZFpnbua%mCC2RqGNH_A}QO#C9EdF-AeYb?Ltw@5Ze*X>X)vrejiT0EqRd= z#rBhhVVto>_>t7np!B>$Wx9PWvzRNL%%XDV#rA&rTFO_S)Ia)ihHAX5N|5GWzIfMK zAPJ#RY?HfU-=8tLwtbT==OK0@+eK~K@`GVZW>V|gxEO6rOt*dwv48!xyZ)u(E_XYm7s`-wR7Bei%;E>tC|}I($jyA7+>8B zcOKwyDDED>u3k`Tz2`#C5(HL=aQ!*YAYI2WaA4r;i9j^gX<)@L#S}yx8GQsFgFT2$ zFr&|bJ7!dE_Dy{q()7QRlwa7=_yuDSfdpf?@HPNYsE0|)kiuVMWP^+W>p`IFdS5Ag 
zR{M}!3%C#X!*CeU*_f|}CxgDHV3EAp1_!=3xIzdBYrtM0Tu5R;ha^;kqetAeEYuxv zkBB1r2xu&WTibsZn6E8eA)+4e=1>8HKUGSW@Bj_N@m3krX#oJqsHjjQV;Ct$^NL!@ zqtYFsWN-lq2F6lAQU)C(L6h2A&vHiih;KH8(EIy`-k?RHCwB%nN#3%e%;wZD5iBvk z))x_Q?{HE1vTm8vFcuahZk#HCrmrM+qi1&H^=B?O!O9hnb;o8 zEGA``+h~Dc;|)pMDl*-+QHW}Y%D99KoeEjUV^K+Cec4)T&BEs94o$no--gOF{wA}l zR}Pz5u7?t|wZJf>aG^=T$D+0|P4ugo*|qzjc~w@`!9R^&r;cv@RmhgyHlMk2U%Nk; z!(zt>P7Z|rYrD&d^NC+>N~qq#=s_u{My=%SX(TXB{ue?w7>EK0O+ zn@FzI@rA*~y|S`RQucj7c;%K67a5eWnk43;oPe=_@-6A7| z0q!XMAD)-{Q9(pBgg)vm-l(CWK}cY9BA6(1J~T=kO(6-ifCH5bK1RB9GtyI++X_jJ z6H?^K)KDM;gf+rJ1^`!#;7VLXH?IoEWSv{ya?U-3}!q;j0w0}oXmbq`{jZ3cI z9!?zOe`I_=G~V00Y>l{%TSR%!EVC4q`p)8)${h_Yn1t6#-MjkavDDlhp4{DPDMdRN zg#6BDTPeT$O~0jqg}+pNwSCx_vzfb*jq|9coNg5D-rLS+LZj~AKNZ3<@!s^Q6SCj{ zW{^lEc(aZ{UXFI*4)hRc6)!Jx1;5e zcSPjDA@ptyxM-j!hx(mLL*tD>A5UOW(jxTLM9mALX;I2fY~{+!nqM{d7Sm36qoP{l z4~B)GH9kE2V43wGFRDsiZzF4JXN}p-9fC=zCi?AC$=cqfXKUa62<+E{#(BIj?atEN zpO7K-XKbMbZJwO-hdwIoYQ4iu+giu-aqNAeY(vW6qrmYkn+}K`U(9&D}4pI7taO<^{_uqK@q-5RJ2-ZAq^~P0FKd1@WO8a;;r`N zriU$%dG(MHbMj(EoM(6}$B>r_Xq^f#m1G%{S4QICoOsvRq97q(LS8V6Jc7wVQvULyJye8&OFC*!8Ci$__Ar{46))y*`i{^xG1s4@)p{CC-NfojPOJgi`T z(!G01p}Zm)SHb9$IWFL3w?INf0euQOuYU=ikz_Z1pbtTd5&hnq5$@1=e0e)HwdMlb z2D1+(oTzIKOG>h0Vi^?V{}-*yXdUDfK#b2diC4~JO0+kQI#4PZC)aW5ey_m~$A@+X z_|ul={iI`Ow41lDm>-u9jrFHke@kbsAU^$3CaadkTrGoA@7;1>W|>m!<>B(A7fTKXQmi&WV*=*B!c%%W&4{&+#-z6E zua8Oj{p3*B8~rim3MT3@q1S~6cQZ!T=~-HGTyGnMALckTo*2R0j{SgPg2!Nl?3!3D z&U-mMHAQAE;O=AxJ%*#WInEQBOA?<{OAJQ5(hkpXL@BXvbl-Rg% z8$nCl4k^qUG9wjmFN$w6{s4hQ(l`YJ7=B43EARi5*tm;1Grcgv5AzL*z#*YK_eJIv z1O%VMRV1Fms|j^-_5jLwB8J0WA?g;?3ke{tP8nwp&(3z>1|!RlBvoL-{7K>Lq09IF z9~8%9c&QD+%}<|tBf9IrA|}+^Ak^*m&uIl>dkzlHgpeIz{~IKhlvd&VFa*5a)7J-4 z0xR@%TIVt$kG-Fl$Ac&nj4xouXostzE1UNdVHI}9FyOUy=hrwnPk(5DdIlW-2ttU^ z%ke=vbfU#^%fMAAV6H;m1D{MQo9%Z31s1v9g z3}6yj*_Sl@bz_}W$_-w<;r~qxDj1b5dly9BT+aH>X@Ku^vUMZgdkh#G84Vms8j<+! 
zaH_AZM$sy3;4rUX@MX?&e}#MDe?wF-{=%_DJKlAa;1DYGtc<+^QhJjOI_~p!3obXf z?=ACj_J>^XAPCXxa+=LHc^AB?{odW&zh{7fe}TW(!|+7>Goh65yyw;{H=mc@;)DW9i3!7K3q>1?|%j~9A$Qu1q@ zN51vmZFk!ovxiKA&j5`3z-VBR&7Kt&>9T*w{%UK&uIp%*FIAOTG$|&mH@;_R@t>5u zoL43Ovy#`Bo@54n*G-O#BhyMRm09dTqpV&y_q<}6}2>LFW5x4!~Uv3f|xp`zkq{CzuN9T-xOflB=25Q~S$(Kg~f`;Y`} z8&U5dVhpaHIY=(~aR@>2Ls1WfLJvSU&~X&F@^DM3K!zXMe5J+i2*n>*X0Od%P8!wP zJdJHek#p%h>U1)Mn3&UuDhaMhjQd$z|6b4Y9_~HpB}nUw&WVkB$^3+zss8ZQufqXE(r58nCqtr1L!V8$3!Iel;s$DFWc0U11sEoBH9E8^eo zfuAg0&>N^JmfsBy8l-$HcK>cDZnbq{AdM-+$Zb-pMB)YRU|qG2dyNei8aiif^Bm%& zFP}a=U}#&!`DPrYAS-(ZOoVB-%hI>~h~iDfSO1fxnvBN>4Mrxqp}|xC@fFGt6xz+O z#Hdpz;cky@A|LOEp0&v4#1MLE4UKSoo+vV3paFxmf88gvRMKUR^TTrB-q*msL{ziW zqboeRjV=XHio?(f*GA5LX6vpHHNZF{vLrJaB$D45 ztl_>~kS^oP$fn7&Hib@1`K2o7Wdrqd{z{)ePYXo+VvDJwyDhZ_El7vsRNq;U1Kejc z=VjlrrjKUM*Zo}HJ;*<5+hWM0c%071 z(ZhR`lqjjKvuvlQJR)3juB*|fU3!|iP>gb2xCpx8TTZLK2sH zJzYX{5iKKQ3uF)^knQ|opHtcAaaSH}KawOyjf)z|3(X>bu|vr;oSvx)wtEfrpL+TF zVOmm_k?~8X6KVzp`ubMeKo# z3To4stBm(1y=u00kp(yhnN0_A3S^oEnLq4pZIPfEjDDvJ^E+TeZ9s$8!V`|(2KnbC zLr&?-${~~#P4)E@(6mAQ${u+c_V-LBCRC$jJHp;2j#8P}#2`#0;q?%Ot3IrSrGyxn zh>{V<8JYY8cN0a!BjtPWq=RuMT|eH(25fe6ilR%UsK;b5Xa;WnnUnR~y5||*!x+Kq z7eg)E6NGnS+d{4S0vCF=YKm00(lhLe6BtE~&J#x~1Nq+<^IqbA2eqv`O6=4$;j|ig z%c^jUcyX_ho^fLYFwA^+@3Bt7&r2V_G<0*fUh@RhzGxVR_13S*#(sP@wh59pePdJn zWha+#>;voc_F%SMK~8eCp9V@CvU3N8BBbPQ7@QR3x&9zGyEXB?sKc?zQ%|f`UbeQU zjg=+cU~~C=eWJBybU{ARc;d-9U1bSvk+}ZqxtJ%OLBwp>KjpbAu%txr@ngYsDf)wX ziHdV(CPC)X9wjBK;#sd8^IFb0#I91T^{1r!@U8RELz&)Nb5P`CtB<>XeShXBV5P+@Piw=k>m$L&4`C47+FtjL*#cZM%CG z%jhuj-hnC_oH;9C)gJ?|$HPc-GqZEu3TVsNV28rm??U?qC4pB|)Wy&u2$%9MX|>!2 zjQxL9VOpTbPl`GCQqeM;My4eMj=mF+bFNpt@=XOhp2YSbW!D6q7=bfz^XUUTg%UG; zN=^mxzh~H~4KOlfUoPi`HroKK1QbkJJxpYT6MPk-utEP*;vsAp)Gz1pFo3u)&=D{S zrwf_m2}fMe_dXEWEF`@fR;Dy^^TDf~5!|p)^lBLe;VB%DxJ8~}n}ajQ_*|V#A%Hsr zIzBWn(qGto{!7Vw=%Ivdc7}?qS&FH5ZK1eRqf<|8&sA%Sg5NPchB?Z0TX&GpHOenvOvw64m3>Bw> zdalYxzTBlhTG0b%yX?A_j{jCE)9JnS^NMBN^025Kop8(rp@WCH{a~8_anG}=FZ_cm 
z=OVK)<=xFGbp5q|W5yS#Rn1NfeRLQqkZgRd7_DM@;&e*N>6Yh;TNNE1KgtQrA5wiz z%e~Q1U|{Z~6u&j^rc;a$vc0Z}Jh$Lm)!e%e9Nt~i@NPP0de~GK@D2>rZmOL4KfYlMC zLv<9^2&^4%Ojk!Adeq1f@(rT(Bm%4SqO988@?Wl~Y3sh&+YG%S!dWqu=3Pr#YHHF~ zJ10!u5D*YRr8snal1}G%13Wn>1U8%>ONS*ux@G=*O;wFcR^&0&rBQxp`G^7;Ylg(L zwm^&3TW+2Fp%9Lti+z|TF#8L?-?b!d$Sb5i0^PJYOSs^@Ypo$6yRJ zL}+G97yrp}A#i38M4G#SnSfYSA?^!A6buvPs8k$K4)qA(4@&q0zNf{PW+73)tVT@0 zucPTi8|L_5Rl|)1MnSEWO%iqi)1TZpOh=I*q=1k+xzBcdpAA9_^^zLOlLot9lx@Ot z-2F9zIloiqJu?{0?{5@s70mA1mQiL`1)RMoCY`Cqk0~J>k+;&9960@2S7PXBXEP5* zbkr3v#f+b0;HBTI^l)7N)}x!vA!dt*g?i2(XA?V=^j0=@_vMK%18+wYHR{_Wk01in0zlOie+0qQ&f4yh#$?ZJLMXcE^MK0 zC&J_cRb(|3&a^*kc&!sHzy0vZw>hVLTezZ2<%17z)#}a>la3BmJIGYA(y{6kab?Xq zeWUdrk}fX%k$#9lhYzoTjOs$h2|-FTPWojO9&om?K}7Xqq@ap8jC)~Kf6#Cj0=m4pLcXnBT^6O)aC|H-1%ZCqexRh415+A zZ2$&@Z9Y-pg6DpBiMWRWXA*ZX%7li-Mn7br>O(kOUhW)V$VtXjp+k8NrWPGf|(YG z|5j74(s8-rJN0MEVDkEPGmUp8Ts!1;G<46E{b>AH^_0Xp=ONp;jh{Fiqv;W*iWSMA@u7(H6V#rWxC>4g2><;<{xwlNUa z0%hiRQ_A*Ney%LJ`8rOlKy3R*k+W5nBDzt!Ja$)nLK3MzFs|0}oD`0A33F5kx4O`I zqWg2XT$7x_??va7jI`5ivO9B56h3_`&${LE8IYwfz3q$#%+_;J-lb6V+C0Lo9Qz+- zqwoBZmCPj_o^e;1*WX0-Oy9dqd{MNhsji zC@`RskVZHG=RZ0Rru=IH3R~`8BrBn2B-1+0Q3HbpC0GzDNjVQG>l+~Khex?LaqWKWXaqm%vsondp_4hldcd<@ zH8yU5rx_!jHy{=UCpf=(SGsyRiB^Sw=qd8=VcSZ@2TW)rU@r=eNI^J@35r-w`!GA| z!Gj0tO$G!SveIKMLf z3lS?%-@c6iuOP)apm5;CiD;DMZ?N;T%Hn#IsZhsJq&L`KdAEqYk{GEI$ zx(%*xb1zYM$R~lGqFFf^rz;xdDctwgb|*`e^hPOn?WDH1I##}tL;9y~#5Sg?dimym zmfXBcgG(e%|Ei&v->9eInonAEj#tDH2HAc6axZNiW$xvR4i+CmlXS(kJPZVL&uE5Q zN2zE=KXP!M4o)n*A{_GR$Nl8!j@y^bjWR~&$CVj1w-4?=A(}jO*VB-ybftgYdztQp zLFe$i?>bXrJQ#x!7}f3Bv3F8>M)Rm$vimO15Tp`cWY}eB@Y>E!N<$@tv7Y!i4I+?xf@&Fb-y{QAWQ$CvZ`Egw9IdJ?1X`Sj$QY+ z?;geO0OML(9QW;ce0Id|PzSBi-yoEYkCC_le;NTq&_x0sfg2h&+HpU)w3xC~ zR`L7NxAB^=akuw>)k@OCe{9ENbsPlzghdE>_WtfHCeS+_a zq`K+tnynuAzEX6Qx0{zo;F5pRqiYS#jW0Jar8%?rX;@ukvHct{Y%~&jgFSJ4Vbgg0 zc$cBnI>e>gl_5{t1Rz@^}?%%F<_T zwexP{22`^U);;Z-;HP{-6;I|B|B4mZ;2TRgfGTe?Qeyi%+p 
zQeI70Ezq-QPxHfrKKu3`T!X%MYst1H?%zk9d}crC{t;(15~8R-*&1Wp*Jyiqd8?Ig&X(-GlBvG-Cb=BcqYB{ULH2=#0{bzkPFxqs!|}zs2&wRXgTV4xMuM z1*JVD>69W>MU;_ko6?(hWT5CN1@9Da_OPlo4q>U3ME{PWhhx)d-W`u>O>x~coc74o@AoeF`TuL`pU~iw_k1Zg<9n-}kHY-IE+N5!9iH9a*th{hXi3~r z5ri{1IY~4!oLsy`wLnj8R{we+E8|Nv?<8)3rHR~v^yx6VJ=VuzS{DXcHU#V*Hkq25 z2JHHsl-a3A|9Pk7<>{}Wh{RTBTRd>W{zQpDpr;sTbz3CDyhep&GIH`-zc|^!f zNUmmyzvAU%>U?81dCkDe(n~ADQ*yL>L(D~iQOtqjIM5h*;d%!Q)5n^e9kgYnLAo*`WPLVsR_ARc6jRzOuJzf$_q@rQ8op zA4-@0rsq4RR{rhX>({4b7oX5io{|}*{d7nEmqGB9-y!T=tA&3;4L+(@Q`Jg7K78k* z9@~DQ$iU!6oZ(YXK5gPM%shPDyP?9k}wrjIq% zKJ0#J85zEnk1B^5)$emv)punGWT)oFT_i%H8yD3%b#r!KP_r+jkw2~7^+Y}@EAohW zKrct@LnF@3=pf@9a{XM`)@@aEly_J-mTPnhy({z^kN7=yNR^)Fe()}9W>!ageCTCc z>+q-c+b<3^*&O_}p`7L8fz{IzOcH}$b&n**LXu=-bJDM+<-j@oy#30Ffa;waf7HKg zj$FCAJa_KziG4+DDH((rDmXt5Y^J=+uP1lSKnU2To$$C2W=t5F4u{Sk+d{ETX^4gpCrZ<{6*8#)AV;|!JHYP zV}P849<45EPYBX>o*$abP5_63Y>FtcAOtxBG?kfI5w9V`?)gW^tN00&%hS>*i# zb(?m?uMz|loVua8c{6P5WUQwa;6F46WGDp%VpJ0mnNx&ZI?fSY3?@NSOlAgOdxSx@ zH}eouA(Hh#own4XHw-)g3DE;1A`5dd-X#)P2#u41UA|VaA4)pdP0pfO`GpKpJ_zuB zj*X0v2`XqMZkLq_fg3S~Ggw`Z49FtlCnPhOnZ1NvvkCBT_NBSO-SlYWNOuRTA>7V5 z%3+q1&;F1!46zJJwSh0~7?Rp3NP-8d6O2`gc;nw1u>}j?p?@GW;g6@`&HSF~wG>CT z9dWJrzHh?ze73su`#e>R42SG-IfOgLOvHtShMk#6<553YbaHh(Q@(fp8*7Gw0Ygj8 zm`2e!*&zWLEZK5ehx&OMJz3-m$_HFk*HRn_8eadF)C7+1k#CI9vnwRpe z`-k3&4;8S7-#BrvB;8`(E~4wEuix5(_Iq=>bNz7@#c4C1?7k>nWcEYu$2{$2nbCZM z?KV3epEBMSelh7Et4Ga-0j-4)#6;52(@SINkeC4q5Vc`J?ZAs;Bt0bnbwe{@52Z0_ zw&9!sMox$$A~A#V0W3*|(=MJ%^)O*zE+;L~@xg?IkSXHI#k%S(G~7s#Ixyr!X$&1k zcC(=4#CZ@oO@zCHYLnIT4<9yNBg_E*_V)k$Ng@>BZ~unDuK&bfty|aJ+M0$Fj~HDb zB)K0R{v1dk+erPtgSHNB3B)`Txzhru0SCS&;IfJzXP-Kj_?TkI3ig_u-R4{zo;BD(>XZukU^2I;+tvmLYb0F+8W zXX@hOYpDMRT@2yckVvWf`w<0LfKwQk4jO?NTpE&`_rz*SBad77mvD@UHhl|EMBb{` zwos-(u~%oCd|WU7cxX5j5PwlAZ=NeWHbVBA5d&w_@!l1VTM@Ea!(&`QVs(ccz6oG{#jV|NzWdnf;DASytXlpy=GFGHDJxeMuT!*C zDRMjSAEsKnX8j1)GkSWp(+2}7)gIrbx^(Ki#Zd+t%EA=Qr)#M=4(#hX>$++ExwV_V z_F9y+FE@qxvTgDS^O<*-WSwn)r7T?R)^mKJO}L}8*MY7ta8B@|r`6v0cc*nt_f#T5Xz2T>Y8?K|I(yk1 
zqB`hP@#jj}=@=w7Zr!lXq_~%RP%ayy<4}-guDW-8d|qP&Q^zD|n1B2X_z)LfpP_b% zTAc2SM)SM&7|vK?cY=--7RW;gD)aGK57E-yRF0$pPWI{_-Iww}6n$Fs^R*BC;1++GpfjV5MteC?b%iiY)ea+(W(}t+Ie883zMd zz0+)znx6h*vN{pIRwPg5T`muXR)j3$Po?>-BjgfFN;>L1;KB}xQ?4cr1lxW%2GTe- zV`7a(Z;>zl;&$s8`^Gec3A;?@mz%-li11tF`OLi%ouxrQGYEms$&VV@LKc^V7CRx0 zZ3V&b3-QR%@jQO-W>OrB4+)Vls-I1H&t6avbKQ1j@$W8YP1?nZ_$c%D3zz?lO;P{a zwHiia5V@Y#LCN2t;gcTo-Gh2(p2pEyuql_RQca2F8cw*E%up!w%u%$;J!M!`Jt#BS zqN3u?wvE+w?7eeOxfAQ$`xgtJeo(vCnigI@dEc{suUWIgRhuFGXYHTwu9}O9{+!Y` zeXYEeaoxx5&R?Rv^kQo5h4%FB5pWh5OKMy^R`bS&VfIp6bYI0F_kiI=yV>w@YbB}?to3NF8?OAgYq`iXY44}c|UE6-n9S4$V^~uDn@(W zdq4T?nG}hbLHHU*UN(KUQPt5=yq4oEg%nH`MadkF&6q?*;+1Wvy(fB$+412Ja~+7F zPKA$mk}g4+{bDYCP{XZWou7UL$=$tm160>!)*%NkZzhPYnJw?y6cJlUMsOYu z4?u;Yx#YqCPGe7w0~W@7Ba<~r%qTR|W+}>k)ftl8T2LMr=*YJXDOX+r$HdgMYevW*`!B_J>h1sv69p+LrYu;9{;Z zb{Q#O5CRFQ+?TKB7$qbR)d^_jh`) zlPEQxecPY)ra8&cjjCv4;c~%8p1>tT=BOK#-pU044LzcI$dF`c`;``Or zgGyfFUQ5qzI8n|#8!al7G>}a3PC4Xvs4GpEdcrANVZp{Z{(Y|6QCC5c&ZN>{3lC*u zL9=b{t_xX4my2yrGUxp^q)!xamdbjw!(i*v2emd|Pj0c#-@USR+=20CQl;Tux04X2k9-p0%4zM*;l=X*Fc)&5uQXZ5$t3>=+3P`rQB z{zuH37Br*7r#4pHsep&W2el*#u2MX_6Nk+UlVX>A=ZSlzIRE0m;T8{vO;ts5g;cUo zP}5-1d}ExkufxI90NB3`r8Dt(JT2~!Xhqe5glBP_3PdW2xqp>yPV^&oeDUpM_B%!c zgeN8{L15i^t<5KLZ75Xg+Of`=@Z4{_rDDBEDr|qeACTP+VwTUCew~c7n*=;q#1%b_Ox8 z=LWdhT}xw*{nG2%r(moz@%)ZmekC)P*x5|BaGvW7S`Aun7Ybwtm+I1oGrvt`b#y&` z@v)D$hoNo$SF6^Co17biw$r+DMC3-~?!LKTGeP7npcoie?G#RPOVO@XWB{Kfm2|+xx22ozK-%Pm)Hi2<%8Q4ytl~)1RY#n;A}aP z^4{O%kj`H3+7|&&A0ABAr3jt#erw6Bcz&fmh->vOjb@l*eQVR|O={D3_iYoOc6$)M%*0D&CPQyl6H2d(gn1vr132= zv}Xfiv)6A-KIdT0A4@7U93^8&TO`Z}xsk9q-$0MS;_+Rr?)L~O0~ru&-hQ6v^%PI+RM??-%gJ5VF+Dvn=h)c zFN#^Wl_M@_)@5{}wEm(%pAxIu0E3gNN$Y}jj2h4_}1o4{Tuq#iyL6wM_-K+ zBNP&3RaMnUR%^7p7qS2yGXpHmswhq#WrUWU{Ua1$SWkeOwKi zOJbxcsF~jkb-~*DYo%cc32kVsdtvs%{^MTk$DQ&NxK%GSy?FD*mZyr7m@_}I{7}^u zxr4%N9z_M_Xt!@UT2qgP6-HK~uZXv*3P4**LqlV;Uoz@VRo?Amh=W3IP!!e%>}t** zm=;jKL3wowPpn4gJC@`&ljnMlq6SbT0$$!=-i!d_Ha(9Y^teZLBflX7E)~yA0{sY` 
z0g8@LA8TyDpDaAva@6Mi6=#xGj+E&=P$!Y{2*^6d1`tt%Rn43XD7*&r3>Yn1&i-%# zTI4qcdXgm53e{MrnE(O}VxgOgj@EU}t~s!7W?vgRoMy~HgPkO41cIWG()YS)4<39f z`K4dK*Jl$}95&DQ$@kPYV`jD@i8}V}f-&tTB!Wxva5%fBGFCH*uttZ2R5}`)jie;K zoyw7wxLp(&0LgMn6m^CKOkQ7qU-?HTCL)CflBF2S?~5i*$g2MwB#|v!*NG;=D-P_TQ|_=B;79#ge6l){ot1>p{riH~ zyD`jIv_M}QL@0<^;soe->a#>d;NbC~cYr$rLPCx@3$g9bIinhLbPrwRa#CPWh0|3B zrneRSi${gnf@w4Vibs05EZ?wy9$yk@=BwbC@_wR|v#NL|HZ6zic+{&0Pg6Fuy-Ys6 zEc)@xZcF8OQGfNRnS8svrQM;U=c*bVMpU2fPg?n<6)X}GW~I_(A-Gd-cZVBw&7I`V zkP(IGFI-NA`LBdjQ(BXFYr+?{_pOxeytXaFH=i+aEvna1>#@C)E_s z^w?8@A!y)NQ*kgJ^~CEtCg@TazNnn?5B0K%RZ{j3`naXuX&-Hj+H>Pd^>)~}Q7oms zySB!;`sYmbEWgo9nu0soQs?~|3{v2GLM7LCocmigIT8^vtAufg@PXm8;3IL|2;bdl zL!B=at0?y8+jDn7Tq83F5tzgbD62C+8-c4iUHc+rP>BLo;ND z(roShs3^?hVT9R?#HBHxwDH1MA`bBvaEGJAO?KMRj*|V@U=QA11MwJO+H+1?*K@^@ zZ0DZBgVV9)!wn7TZxv9Y?$v_)AbE_Br4R4?vn7$-534IHHaSm6EL|w*c@BJ{vuu@Ab?L+e|NdeCW$zFt- znUqOh2$TeiZxF{#A> z4t67}J1KxMDL@z)#|RvU`vl&Kdt{|@=(tUnas*1D9$Dk8e;5}XYR`y0{7yv(2)`c_ za}MYngq5b)kJu0lM|Bmk%t3yP~)plS|koo&Zd&xtGh|T+(j;12i z+oU6@oy0++9K8X7WGsZO?|(uI!E?E+?fB(47(~#H2UE2wh(SvSu%QvZE6_%YOELn} zuVmlxm_{(g^8I=N`CqF}0)bIw@7d1V2XI}~=hip^eYuUtM| z=j@ltv7|E>RQ||#VypWEATAj#q(Au5%V1-MZm8~BV# z+b}w!;GLha{PtvZB?MKBp#UR}N>hwB)-ApTy+S7 zMDvra^pCakabqJSSM#dX4qKDJRH8Av7~}s#JsRu}ZapLtJ>RoVGSvV#7@1-~5^(G5 z>*IJ|Nz2|0c}22zA$=Sk8M#nZJ@QlXO-KjN`WyvE>f<63#j0O|$AoGZdhN1#i}HKZ z=`%wwM#wzmizrjG>`g9D?tF6ZaN4okYu(mR$mEAUv0B|Ctrqjxb*)oYqS>HX-Qco# zFwe###g~hOjgH^!ym{A5XziVe`q84R_8r=~|K+fd)pJ4A>E2%1&~-oj zTi&co&h(+xUHNmN=y*{;iHvK@4UTU#$5~^WHLuqhd?~ue6Qu1YZJMI2w%0vDf_b}y zqL1*c-GY9~KSXU!TSYlb=!=9rj##FCpc%KfT76}Iq_T#8n*Q#iqSftrjNFE-tP1yq z#}j^f@J?OaD3lvMB*nC6dA`$({`-a_!h58d2ND%Jt}1=o!Z*JX&m>&Ky({|7^8E8A zP1#faJF+y`Qk2%0(%#w>AY}Nx+NRmHVWh*Vd;!fC!bJy*UkN(!J_(O5Uic-}vdqGB zid%n|4g0aIgz1*+UoOz=n7&%p^*gfBN$K{pE1%xeFiLFqD;^(C;j3cr&dr;AR2_`S zOM582=j+!|f5SGjN1W~)@P->Kc^;_vQQ?Kh#JmJ>O%ygTg#1CE&&fpPzbIR5TaT0z zWgGD1W3h^$I{1P7k(e@|f6BN7;WHYZ81fE5hehGB8<67~*>N(EM8jFH?Z=R(Q{T!YZUMxlMM z!!z>%h0(q7nG!pX80Y!>rWqm6Y=6qwxa2 
zIEZK=aPZSiwg7zd10IW4KoGq^3l}S7$p|ee$*#f=ovB^aonKfGA7i1VmBRc#;dj1d z3G?OoGF_sq^+~u__PLRWgBV;+s2NP^T-%L{6(qeKn7G%bK_|t zVbRf=CXzz&_s|;hA?!8@*8I(_$5DpkPv*fa;@Vb^D&Bl_OPHJ;HWp`=J70jCRpR&f zK~y(vRQ4=gZ{9%L6}(>m>^<&8Ewjm%GgE&~_`Wc)6{szHMVYAea;7BwYV*D{j^{hg zX+1VtgX1*yi{;);$9>fv zmw4T(qD;S&mBylx{p5+n70GG&&`17Nl#b>m*3dID3Oszcq=ohSn$^L4z?VmZK_w=6 zUl+3+wlc>c$J(uHjQzxZ?fFwtH}GTZM=xdU?TW8I&HvCdGC@gk1JiP5u@Pn_VrM(r0BA@Ka_~h`gB7s+V z`Q+GA_L#92HhAj0>n-c`OI&@c^`;%V#drDALFW&%_thNxjZKLo-SM0N@Hzu^6-3@Po&wg z<8x)e5e|+|U~GqRE{{H{o>|+!Q!L`=fun*5fVHZ4Dt>$bH6<}KVwP+QcBXMlB|7<} z8$+9bA^8Ilx_4rL+`wWh-`|W37>$mMbfOFEIQFaYp2}8P;_kwAvk-AaLL!Oq00`}z zM{oBSbB+l5h`xTPc@<+hhBx1qpeH;vw1S_#H?r(QP59!)*#XzdBP8_=)ae+fGr)O_ zGB`v*fOT9-gp|dk@lYNfK5!k@gENlBDxg5CYlR_VtLmV~-j zSJ&LCTVHwCtV^fQ2-S^I)>dPTjZ!@9n1m{%f9}m*3(>On;+y__K8^eD1yBcO8J|^P zss7EU;?MeUp@W>|uxl>%+z65dkhshqk& zar(&Z-IlwR9_+tQebL?PS5Lsch+9(k#r+qfYMc5~?DzEA4O1??|Ijevi#k6;hikoBO2MCBWezpnnFEX-pXyw{W|3T<>Fdv=GEnOpX}O2 zNGg57U^ObVW00%@g4fQE!c})1&;zs(iU<~FUquFb4!ii%Y$&Xhv9thH{I2%^9w6>1 zD_iZB3#(m8KhrFflgD8_D$zayDX=Ln2SQj~;)&uv9Rt1xoxaS!NYZo&Re@@~w7-i+ zjSR6b_^hU!wKMl!V1vO`Ky-+NzOnXLMsjQKhZ>jTtA7@#@uL=aG%gUX1Hx_gV-E9U zs7P_4B<xLflvj}la3y;ssenD{#x3l!XcTp0q(SQ-sN=Nq_I8 z@`3iyn)v~4W;x-teD{mkeZ_+ozp330dFeRXHDSixw*B}M>D3MON*k>Vt`A-OAYgXp z5nJ^4w6DyZC#BBo(=keJ`*{GNS&5m#N_-66jph5Iq*m!Bb{A>UKJiie(I0S~b$XRU zyFI|(^+%Lu>k3bv6St)CAD_U$?8n>dC=2O+&1Ut?((Ewb5ph!f*hGh=zm9Fnz5cX@ zOIs{1ZJEAw#EDM&DYJgPw$a$9pB8Q25=*-shmGmO9UpsLPPPi_%rpJjbK#`L%odkB z?F}2LS9NFj=qRzgWh57O9$`(}&KkabWo~tmZ<6o#w!3x@#w!OFD-VGG<%9c)nR(qxfew*|XwTb8bLz zv~GrAK2-&is<)e0ooQulRZUD;m`2D=k$c~7l7o&P|qVc~p)TmP2qhY$Uc&hsyH z{t>&LF_xrIeEzd(WsfSo&~P{Ih`+j;0?jY?I`^;n)Lfp*u^HX9zVC7phOxpoa12Pu zaWd|$Lt>c+4H7x_=olFIv1F{!4Dk0|aO%17b@p%bqI^cBlsr8H11Y&|0WAO_9D8^1 zZT{u=3yaycLUht|?%Pd&^K>lg$h+krMc|>L{+GAzqNWQTiTTAJ7kwFn9S%bzW_=gy>5Fe*Dvw%Um2B4*=*ug;0zk z$#4EL;tDq`JLvfnzbpPM4(-G?#px>@EA%d; z=<%4UD4E&fTk*kms_NRVGedqf`kPN_ran&>tmLJe8t) z`7uvH3y_OSlW*RSDK-iHd*|rVnXeujidi$5gtsWbw7=^hv 
zkivR(QFa|@1d>xy>@jW(jnh7~J`+*dg8&eUPF@pISkO*J#50D^LL4{*ZwNs8UY=H< z>{Qvtp$}9dzoXdM6#{Pwk7wkkN$ds<01n{8EHA*=l!i|Ne|vlT;QV!Ik}@DI`2OPu zzqojd?YD+@Rdw~VP&vhKaVV7^tADqCZ8?ov$>~$&_cRPRS%h})PMej0il7(PTC*mv z&%V3w=r08}C~TL@4(=H{i|LS<$@8_pzZ`{9XoJmzt+$vIgI|)A2aK1+nR*2bR3 zCjWypbEI0Oy_4nM7Z&tl=Az%ETyyV^&nnJ~d@S3Sxu3K18ruewpN`RY=b!oaG`T)r zXt*C{ru2Ki@QU|q9xA8!r7(aPB9S4ZI@+fa8iaCVg>t^XP1X6tWOuPPp~2gMO>o?J zYH~Wmrt_$~7$=SHRce3jpa#uQjZFUhBTO>>*WWj<8r_kLB~9YwmU1Pv_A|p$Z|m356Q+#p zzI#N%U9w3=bnvfp$|D{#$IW+V%P<-*>r6FCh%H4q zF!Pxi%n-Y-{BRt?Z))pzqxMA2J{G0G8+g^p(A6uks%{)SayXS35F=4KU;m zF|pY~N0s;~hbJ&8hMv9+&XQE%+x>X2!3uH>)I7#890RZ-0Hah?@Y9`6^~AI&ld2Zl zpcn0SP;sFE@1?5<|41K>JC1bx1AJQE$uD%4zqG0&B6yJD?iqTZ-Y0x*)Hp(6x1E7R_!+T?A&_sflzacZW;u6Qp zJX_}S1F3cE-G;eObcEe})^@FFl2KJ^T;z)Xl~1E$EN5B69$nliSH%~pJBXucRryJj z`VvU`e=be%nrgTmE+5p>)(@}AJ#4BGe#_fKkTbCN#6iZGo3q||uSU4knm*OzeHg>aTtCSW2l`Ux{{K-}6Y>KK8Q%{ri0%EMI!L4WE$C z@VMWqF<+q+Ug62!l+jgs;j={^He&RZfv$);-{>l7TJS$h zUim{;=6n){M+}SikRfq>eU{jGK?doA_m1;^#gOIIokj8lR{%2ffz;3_Hi0V$Np`y@ zSuf*9z3!Kz8Pi2x3z&8Wd-QSBh&$jHVi7DO6AgLQ1?gSc?PQ|qA|81HeZ%oiZn4+D z6BrQ{4vQ>=1~ZJQGxHc^o%NK*G>YljFMu(LoU9W;fzUY6jUp8>G?)|8gpGl7WEjPA z0vsP9Sc(ABc>`D1l1-pW!tKzMpJM^gsjRP$mTf(bNMUag#3mHCGic5=LD|O5NDCe2 z*EqNCwB|g$DugOHW>{ka^?1jdP*|J5fYJL)yN8Agt|I$V&mjVNwp_3RbqVe-w?J2` z%l?R`oY0X78ujAET~JM!xw!rsIb(VPbj?G|0_xujC!rSYoxFVjoj3IC8QAvp2C|za zAQJ^HN(7}Ku#AbTr(=+cneL;=p@!W>^PzBcXh>8h(WMYwH`v?+5tk#Ru9?$B1_%f> zQ>iUk!mh1%B+K_0vxJY1-NC~+ch2!JIidD`NDC*>Q$AU{Eq;*LwaPs*xWrrvGKw7JjH+lawe7UCm^y{1F zC(^yQi1tMvjIT~Z3Kw_L_erF2_Y)_vu1Ooyk8JwY=PrLQ{Wd*QOm)k>s7|>(`-yW~ z4Gt==2=NJ>n2LGr@VDBoQl?TFjNGn~1(A1G?PQaUxFRc{z>{#-_kGB*)pOB-UtdVK z6^@H)6fmb8by>)F$X}P)PaC56xpb9XC1cqgT5-mS_X2ON)_llsKDY39QpU=81+?91 zB}_$}Wj2&H_@G;LKxcY?w9CWQ8y*;O5AfzVl-CTQytXK9o;y2UN2{+)qje-*SBr(l zTKe5xR+?jPhZzTomP=kX)nh<#$vS4Y!nSL_{(5ON`%%GcS&4c@i`nFvqP7_w=dt z>fzXwy?|RfO`r5v=i!J zKK3+q-u11n-hu)wSNi&bX?|1x6wc}zdFsa+$*!y|r>kceQ8x&e!Ijvc?WbzV>_Uv5 zdxS>mN%p~)PES4c&*TF3Y};>=r0H?I(o#yb&Nh0JzU;%RqRfm1VSQJ) 
z_;arO*guvPDb+Bl@oLl26lW!CtxrNGu_6m$9?vxV6A!Co$*$`^(^yd4eF9vaHpPF20{U)}!meb*jX)d>rOi_WE6pzsHjRWCaSE;26-v%p` zyE30r;7~K83zXhz>XfuRQeBmthWdDy{Z;*Sw?xnEei4BmZJR$Xb9ryM)b~pgsoXsV zG_2eDS80jSS2MErIZ^^zeyr}4=jr?F7*ng%l>s44oNoS}TXEu(pL_eD1mpVC8% zLD7P{UxvBl-mRFoN=cS$f7CO1B)!#P+=2gb&M(o2&I3Gqt1O%v7pNCRGpC+jDzSIr z-N=;d#T8h&h&1+|TLsxt+H?Mg^`$)%uGZvI<#$Kad_5(-Cnme-Rcb`8@qU&hC;TFDcR+L#Eyi4#)Bu(b?eY&3;Ew&?ZVHxcBzJKh~NKRV_&jS|OT;D1e zZ~+gwkGt7bhEH9z5FO+A;mK}iz;64t!>ej;!M;{0N&kg$?6+ByXqU8;t|t8L=K~XR zS98+Sn-r~$3uRMIDjjs_bI$i?^z>(oGaFNRFy5b1&7FE&`$2!{HrDLojn}>k*cs5` z*bHc%F5g#E9rrNro09kK48Gu^wh_9-t6r}U@MYie;}(vP<8q0)o`B4D{KHPE+AjYp zg^Cj=_MXn{-jm9GddTNl%lUO7V*Sp!HL@n^BRk`ozLuHw%{J3$WwFT!ioED|%onWP zTYCS}_FYBm(kqXY?iZFwXHGJ%f2}{l!kryIKhv^gqfHgm^=qkn;i;E<&&Gqlr6*id zLMVJCdo^EKEaY#`_&q<{GZrqxa-Oi6M=okg0OB!&=Y^2`{HETuXsKPEZ2$OFCy(kD zQ`ra0EP^${6gdEAEKV_NG-{D}K6!dmK;*8f!ZzVJ|3WdQ?Ey^U{DOitGhWWgnAmtc zN+cY(5`d;`L~|kObC<6x4rIAnMF`u7c5G|UIS|$rtqz8fNb8SXe^+19(T`kU{w1yX z%fq00e&d|QwVfq3CpA$-0A_x0hvKM&{Pk3mv-r zsb6^Xw#<5gIO}_fxlWwQm9&ZZdW;h_Rr_jad(7RY*upY-fBl$xWtOewcpyZJ#Z6>~ zaypYq2kCi7{M`JK5&8b$&F<7>!ORL~Q5u`e-nW+BO)8rhcIEpptxdYqnZ%|VGQK7t z1_30ZaI;^MH}wfSK4tR&Ivlgd(!bBFT>P=7glfIO1nZWo((NF}j)w~;$~;XhT5aT8 zsIi+;8x&HuSKmjEUxW`hG*D3p{C8v+2!qE>7(FPSlK$C$v=#tRkhKABKcIV^j$g^; z)7YG$-o}SIgcOOrJAzh8OoUjUx?$lU$}Pam&;tgL(~1Hd!4L->%daTeNQJtmBc0l)V&i?1$Xk z!hYKq9TObhF4pVIVkxf*IdA*Ttd|z3WoEiOe`@$7wdc62?9f%m&3l9=x7;e>Sw|&T zbd73Nx#pt>#%@1{+D07waPLRCHQ7u(pNAM+3%fMSiXi!IHoCt^Al4>l(oD@HZf!9W zAAZem{AG+k~vBp#$(R#4Osdr?btl#zz1AN(i?Bo8s zOAg;q*3{OZqx4nSzhO3LedKq>0P>WETrc-U8m_E?$h%<{{CleSib5m&&Qwmu8@tMe zXgs!W>h14%@S7I6^rvN;cN+W35TZ5LGDwfGUD&zC@=eE$OFSWo+FjJEo~^x|m>_w9 zE^};n@~d+ylfO#PyVId{X*xOsTV+c53c55xy|}bZ&5wJ@oj@0YSF_k>ch^xl3~UxjO0&G#`(-=o3&54n#^aMtcz#ud2;%wU0{bR zxBGnmH?{h?XzIZ!_N`AemYUi@iCBMbgG)?YO|o-S|K25^zs~`N_e{F;JM^_wy6xw> zAAiq-j3TgSPlico9=+T(d3|jow=X%B;iIbWe;67ADgtf#v-5r>R*Iem2L|bhAJ1$( zh!Klumk;L@0#7A;4ixqg?dRYqY^Ktv3J73hcKem8Gi=ruq~_mjsr|5$xcN(!-|nB> 
z`8V~l?{ja6vi*<{@5~#bx+uOTA`xjBRT*AtnI(F*=5F1#UFstaPhbh} z#8?(<``o{~WV3!{vY?HAczj5iW_YK8V_9%h-hHRCV2V@2(4YRP6R2p=s|jbz?b=l> ztoi<3yyxg{@E9jN;vPNHM0Xa*GR1j&)?WTNp7aAXs(WYmFr<$i@7HJ$n{*#pPuJ11 z)t~jz#C+k~$@%MIg+rZn#pw-!T|G++jzSUU+4^0pGVB!5z=75x>x!sy-Q6 zSSSbF8O9;nc^we$xu0A%RLyKP{JZLja%;X)#Tw6@Iup}96>1{YicKH)2QkE$CEl!((fxIE-?W=VvCFPYPI+)6K^h>eqfSvmn&rHG9UTk(k*uWt z)p;XXL-8?#YG2nD%QgJC#{O!Y-P3n?%u-m>N~wyrub}+Mz)k~uei0e~p{OlA%oavl z7{cCFfWzkj9z*b*LoP$IFmP7QD>q`_XJ(pakufOU(Qc~hq;<~Px_@e)lhBQ6Ru-06 zkDdGrf=saC`(a^e zkhFyJowkn%Wa?{_gz)dFMqDkrv9%k@OoJ^ll0RA-wSj949b>hPn59HaS5pDQwSp8y zZS759Na^gWpwZPX($R++M)AlIY81=xCVct+^{~t>eLqpfb(ilkFftASh_k{9!)A>N z95VQ9ete=5b|K$J5tqiyZpkqg*!CLCFAx(kCFbME!@lWcicn=`B|6M>Wx)(X>T^P9 zE2XBUI`+u6_8g-tTgP9qGXa)M2Tnak!`4XxQ`Z3Jq5Vu5NsKiZx%&iXLY?Cmhs%?Q zpoFXeWldg&qtuOPmweCU*0J;*QF&`DVCs$$pt-kj>l-kQ_<8=Q5EYr2r`0<-CoRXIDSjEKDwg{EEw4i0=)b_v|J(l$8xS5+J18 zy?+B*Yx+&f%!@+cAemDEm27rU z31I5$GtkmHY;JCTl0B(Pr2uah?(}-%dZKDBJoiPsL?Um#dL{Pr7Z~7!8LvN}vj#OA z?Ef}1F=;99WHIDSGfWhU)x^?vVc%*tmLV4Y@{Da*ZwR_9q_VjsN|uiJX6~_Vng+wzIOSAuei(>GFt>hxeDw4k_*Z#ZD|MOMf zfj8Fmyx8dHF`r2$?hZsb;RW5$ohzXiyr1(nk+NbKfoD-M9AGE&hIt|PyV`Ciz|Vgz zCEhcHm~Dl6fBVD_KOpe-Fn^Q5(=2i5%3OXim_wMgcv!G9L{1DMj^~kjHB&iy1_oj~gd_(+$*agP$}X8XEFO6^8t+^h~1u7vrKrs|o$kv3Yrj*)eHPM?Zao?Fy{&#yT{ zKwU~8no{UfaV5h}oPFDLy4|F)Na9e)Q0AKbad{c zn`oG6l|1(J>C^b(_TZZz0CwqPE6hPHU`XvkFsTJ0-AqjEPK;I;5Z}JNJR7iqi%^1y z(TkF7I{<4GjTBv95MG4jV$?!kZNW}2;HNYuf3`oNNe;;O=fzT1(t6%p;sXv~iB6tb z-O6&$$~1^`5EiaF?C$O!b@4Cgir^p_-RfA-_3`LN`3kUr)@|4z|ETZgO)VoMqbMuH zL^L@bVIQlurs-gIZ{V2NAt2znH`A`@3t)tQ zhV1O(a#;RshbkW=+LJV*L!2+%AS^({} zj_Bubf}F(p0`{pDhTjq?F%$p>T`D&&EWlqY+p3N&-uL&?fhqmhW>_##kZbLx0PA5< z73VpgRZXaf=Rt@HT;_vCkg~EeTvC6VZ1DWjY$eZozatO^#YU{XsC= z4)ru`+I3IjXsr5|4VNFAQYNm?lI%*@bfoDj5=3lXxjMB+ya53L?$0tQLJOTv8-0Jx zvBNujhm6dZnQr8qM8IlX;JO!y*wy9TG`R=sV?LaFf6hri{!Z#Kd=uL>2=)qu39Tt| zqxN$e_VPpg;mB?uPutf!$rc+ELuls|3Xx%v1Qv&Z+!$u3H|055V+xTX<`z-@)>CGc zg~=iq(fvTMcEX*`K~vJ??-d9-!=D?H5_PuWxZHHj_AB7bQ+1+X#S~$JSk=2CqQjWF 
zvCzY*{v~)u#PkKxI{mxdzKf0Lf@k=#2em{7U5|Pqt*Bsq?i_13lH0BWbIuZTQHX`7 za;{+A$n0hUHQ?AP8!H0dhRTdc1@X(*(9+KI`fpsN1=tPO9xo)vtQ`mG1!lqXVe;}% z`mLa$pjyT_{|o(t9YdpK~0j+4S13y$t`mI23cKt-vj_^?M*500Sf5^cGKk z&8e5$E?$>1#4*J%qGK=lWET*2)2|T76E{GTFH3eghgmNXeJmd2&{Fg7d%TDo1oIjaErMUF`EdTKpyyF2qGNiO`HCEAytl}rK zihq*dsBEJ>gkOENVC)^dJdv!aC|=v51z5e104$OE^8`m!n`9OfVt0 z&daojw1g@#V@4#oa7h;us!cE0RRZ+|HPC3?JYBsIwR+fi8s|FjV^?lhQ0Pe)A?X4i z-y_wQG!Vj{cmBd?`<}5Pu-rl+&kKT=Iw5P~Ddd2OC_YQL+~&v|@I|B_D~IgnBv3PI zA^psdkP!0k$ec@U#9l(ginFW?&~(w$eLSM(RJci2u-=B-F&!$F2fKB_%CNxbOnuL{ z><9lnp-pWfoeLh*Bcf&`KY5aRIB#X|>9?KP-FWU&C2T$&7Ux($ipZS&8+1*xOHL~x`WisP_KWmjcU1|zw2c56NwAW zC&Z)FFQ}{E!g;SXhqQtH@8$~SvrE&1LJ80Rdh^kCDQYyl(o!2r?Pf%Qmg?%`7Xe!Z z7Heb6kAg(?7_)&W0%0l>E$NY1kIxjDWHrD44(0OXMtebu(8y!t8vmu5(L07%|N9~! zbuCH1{>yIvUw&0V!w%_+3z!2{R(QiT$Qhn<|L-@?j~2N~uJieN0QXAT528h`qefB- zc*WJ{xP#Q*Op1C8*d5U#e&qYXf4*OSk7Ur7KQi*l)xh|5-TL)AK<&jY!)VTK$5#0+ z_Z4(Nnv#xhL{5m}su!m=tQX94Ou(R6k12yyHe#T5u@i5*4b}nyOn=gu=gxl)pa7=f z9l-4&*>S)uBa(|lwp09To(l<45|Y`*WoaBi#FhsBGeq}BkWW1MNlhbUo(&2?VoHVH z0;x?g;l2$DSU()m7Rc zJ7-yVU@DDbqF;%rvBM;_iYUbRyF-aGs-cVO9rE(sv&|Jd0Aq)#uuf((j+L z=<*%={1JgW+|p%8)C(hur)!19^|oO@;CYn{xZD2E7Pu3Sfl$OCiQvVElZX#i@T7q8 zRx~_@qDA8Ib~kU4yH3(n3{6YIfJB4mrth#Y4>3UmJ^xC}tQQE^AZggNPkXuL{ xBD5uhY_flAwODe*|NnSo{&~;;TMy;R#+%a%lO=SXK2-QmMOj1XxuUt>{{aU*`a%Ez literal 0 HcmV?d00001 diff --git a/tests/triton_tests/plot2.pdf b/tests/triton_tests/plot2.pdf new file mode 100644 index 0000000000000000000000000000000000000000..56b835edb943c428073df1f2a1ea9cc52a593485 GIT binary patch literal 16044 zcmb_@2|QI@)OV(EO__=c*OZKRzFbr0S(%ctOxKW^ONJCeWR^JG!?g_1pG7>+E6ewbt3|zxP>3*g#D~5+#L&370*Gm)62ia3tK* z?kG%N9*!^x^diC$su8NBqAKUPB5?~k%&GXa2#|ALl`^TI=FZ^!SO%t`gqzK6G?DO zFszyeV1*b!f+Muu0ShWWa@8NXHr)C*c2Iv#0OA|s-iHXdTW23(MD+Fa^RWl>f%pgG z>k}PZY*jo1z=%ljheczhWZ-Bl^xqoD089+D_JtGH*;DrL@B~*VsOs;+fNuYko*vP| ziR27N{TNZh#SP3Cj?i!eLQo^xdpZyywtY!HL|b=QKu(L9uIFxP#)j_eJ%vh^jcDyV zEDFu~u5cw5ZN?>or6f0e0@ZK~pG(j8&nL1kV=wRoQ}y4f_#~B~n#{NFOe*t5&o$X$ 
zwQ+i9jo6F-Ep$}G;LJk)jQ%iY5VYAiZw!DSMvL3 zA5FhqlX~_N)1ez1L{ZI3ne%xl=~GOu4!=dSvF2=E6Q0^wmOxLdncd%|>xUZP-Y9MxadqR|!ya-y<0Kz-df|j_xmeGos=ld~l>H89ok)J;T|vkE->2ptE4uUq zAMK0wK%AWj3q^1VxQq1)2o?q>+!Bi6jlF(o_KAH{tMelDr}#0B{5w4*`hE*ftNdR! zvH8VKxGla4Wa3}m2j|EmsJ2(?`dTXYMiP1v_wb$n6_KK zg*8v*<_?uxCCqt4eNqwB?6ic8h#a-bsHDqRH4ozV+=`!nIycy2 z8ZzY8a9nY?!Ps-n^w5z}tI-SJs9sK;Ytx}{?N_0ueOfv`xYv}KUx{vF;r2r2dH;cH zE%)UqWMFuzwyn7BnTAJZ)y}Co_!|0XA=&Np_V&%fo5b7O(jVR~EfPxqu!8%T=}QG#n_dNUH$mClBAp~^*NOzwz_ zMm^GQMQ3H(N#z~%Ni{HXtNy~-()ZNg*p1F8sh&B&+@5Qv5Z#u(6(NQh-za`$jA#Xq zwxDZpCTrrh=IojYoi1eVndfre)7cC|$GEnJQo{WO?@hd?i~72UbaBjD%5PWs^~t~~ z#c=N3aTJ%cmKArqQa>B-%CVvLOo^%3W2QG0m`Hz9$V^E>QH{@unyTG#x4Z0#xli9e z-<;~pe&FH3^gd5_Nb#Fvyw$?vF_&2DJ(9)U?UAGVTLtwUnk9<6J6`vehn=UdpU~02 zR$O8IMGM7vexxiVw&K8)LnbmcGHc*qlH<$JvSFmN@$(@p%8ZBD8w%^*h`pwa&j6Iiw&Q|VR#*CWtn^Ub0E-8kdQBf-n zbeEIjM|bqKZJW2N?5IAj*)`)Pe>YBGI5@&MOz?$(#Boi!Y6sp_mTOJ!7h}9t&Klfk%{FKvXzR-BzPYqNsO3JzuJ_ZAwqPIIXdG+JNqu~J z{J@ET8A~=Bp_LkOieN?)drGg-9(nr@d&L_O4vDf9wg)d6AE}`Ka&IhF)hRc#`igSU zVO`0|l&ZMhev@ebbV=1bSg4cqxTh4w#R{U9m7$jlVH zBk@z_xjnosIm4?BW?H&ZX^%C}9XoSSHFzK=)4{u-)$3sUhZaS3&q|->Nt+J;>Tiz- zF6T6#X6jywy}cT$#dT8Dj;{H58^L(6S?Rj`hde0)#|_bDedXrp80y9;=0gJ!*yf4w zQu$~bCazSGj3TccujEG!jED|N3CZz$6L09pGTA>rbxWsztc-i4v#hsxAc-OGlQK!s z>Be=1^vD4RRvPn0*Bj;na)Aa@wxnvkArAAb7j;$xdbe&_9J-fy;J9uqR!{lLm)glA z%XFV_LrCF=oBM$o$UYJBlPGB%{wI(iJ5z{g3Ag@9L;vp0e>hYOQW|*C|L~>?DY_n1 zC`Q#~#Ym}bjh_cEd??skn>l)NSVZ{Encaii)sLpv>X{q!CSlJ;oP4w3RdIe%&aKOCEOC<_Nqym`x_gt|2#JaD( zCztmrPKsWdbm$UM6l;Iw%gM*VU*phYNA;}Tj7o#~k(9@d)BL>ci^U7CFAcb=bgz_) z3omI6L}7eH-Z{5UK5c)M^nEm*S~u~UGyjLN>mGelHuc7`>UIwMeac*75AI(!;COPo z=?>$zk8bL!qQ1$`W{xyHVsK4&+a93vnTGe$$R1s0)4}25{qtAl2YZ(JEc!`&e<}$L z0hf(n1}XF3ve!>gW^07gQO!399N5CqyGD{#7+5D* z%(Etp94in%&$4DODzypI5c=G}G(d6x&9r*5z7GvQUBvQ*e8j?)i9-tA1xDt2@9r{n z)eDJRL{3sYp=8%wU{jrqNn__jA4g21k5{AR7GwvrybjcOSTzxwFt(Osa$*nV4yzuk zGTL`O@5=m6!l;xHjzo007?hXI(3uo{jn$kap5dqTB=w!q%bjkP;T#`z%0$lRVy9Uj zEF4R9J?_VT;CjV`Nf*CG*H!kNXUg3Il%IA8vEE{e6-oEyH(b54t5jCX=B)6mrO@ZO 
zX^zK(jI>#ir;e9(ICoOwd(PH5?P$%`y2L0c{!x?IOG$R@7zHokR+|u{1UZf@W^ry9cqX!&x zO2Sn&!glt8lnB%cA^Q1B3hvUL6HY#dzDgX3)VzOnOq(MtGWI%+ z#oW8Mw}*N5_MFqJgZ$3cBbVLni#2rGp4?^9*H58aRik&R!oJ)7{bcu@hW;-s`}s{+ zHxU@LdHF+NsJ{@{0YVBjKjXu|NJZp9u3@FSmoHc02Ubtkrey6;THx<^!p@TwkR>Qs z?$|bL;Q)6WGuCUl`{^bN`<3M}UV+qD2HU7)ufDQzj|&Mq*8J1D&1mQGD`(6PZyUI; z*LP^UlhiPcy-Q9rp~EB2n!O_4@9Je^GfiY3V+5(HDHZN(Qn1vKfS#O|Nfx;fJO61| z!k=)0`7HxJo}Y7SMJy;Be=2`AaGr9fkMf%J$O^A|-Lx=#7yb4=L%Vl5uOp7!@h9=# z+RNsvlC3Po&foP$tY}-dL0RRcTq8rT62;5YWfxDrK4K-giM*iQ-9~v~3I8oG1p~PR zHGamogh<4K(4wFhuSTh)*K;#bFYQnja7N?@#^iV z`|wg#Blx2rmT1y!5k8q+mL~_Z&sN$#p-~;RHRlbfNP%__;mdtkgzqlmtf`hr%oTxz z8BBuc1+n0Z0ox_!+VPemDGC$tmqtc09fcQDAMA5~kP4G^?4_1qyFKQ>t}pYjPTMGs zI~6^{TjVbgQYPPiNp{*n|52G>*5H!^!&xi4y{eB4IqXksBk^{w`g`-Q{WD(&~v76O|{5ZX>}kRXPD`wIyg z>3Xy%(G5>=Smi!mIM*F$ba15pi>IrqQHfZ645}3w@*wHqdlO9ixYxzMDjN~?z#))1ScW(yPP9VcDrq+dT&r%<2F%C51ws{E_qZOSFhHrctHR2-gG@l z(^d~-?jv7Qn^X-ZDj|aY!hAq1M>V4?)m0s&BFVrz#)2L`} zMk}6=7Y)jnDBqb~uYtK|yN$JPxv8CM=Qp4FZP?4M0yoVB&fT(8i$CM$t8j<9Pj9}3 z(AUI%-);VUU)~K2MvOmvs5a{8TP6R>h7Vp<7x!VN^W8Vm>F-A+}k(dXp?gc^xr(E`i9zPsDLAw;{2?{fB0dOgq^1=$xa+2b=f}3YN!anbN zXteOHV|j(=I!o7AtDHY2h=ERqHaw?IM{~{>A374C}YwTb6FLkh*sG;bF?;3lBX}!!y`f&3Kgsv^m={@Tj zqqlJjcxYVY@Syy*u%J+X>VXge)v}_mSRWyQq7`j=8vnG!Hdo!U_}S!Qj*CO@$19dJ ziUv6?KECTM?8-OndeBa2-tQ|7^iASyEyb7FZ_>1L552>SI&V2x z8>@Z)(qfQ@f4MEqNMxN4KPhSWm0RpZO-PWN5DRwf&j@y-`$$ zme=;ltqFI$l?u~r-Cnt3ujku?=GP3+B7LN(xuO8>}T8vvsSG;l;thjyX#XG*2IS8vy zZ+GafGXgHIdmnbIg%6oL^P=qzB z&ak4?W%T}$J5&}g@1&I~49)N3^_N}TWU`w8cj>=?o)~mBSOm+l6l_u3_rN9{x57lD ztdtr`Vs5KQ?a*Ruen3&l7m$r4wZ6_v?O@`eJQR4n$)1{uQ>q8Bv zQ(L|eE*ScCLjJ4{5p{bF&-BOv@1IfpEkZ}Ih)rUOBM=$VvDoEM_2wACRBjgyhx*>k zj)T(`l%;aU%~`H7N8j2LCu*L^ZEe~0$eK{%;C(61&%{s6ORqcaqv6P^zj zQ0BYaZR~-26PJj_xvTG(mt)`uM2t0C=HP=1v?B2#X{F*`=iNT_>LQt`zp1ciPQ)s` zWejWWGG3Zsn;jqbp9-__3}n?B+}57V&&3n&f8J{CRp08JHSZ|7?pxf~cF(^$mf7du zcEm1))0ZOlxoGi}tCU4Qx9sRyOxr}!m#9a#Z8GXk>CLb1>+Amfor+h~@$)94+yrc* z@RmIgN0xN|DWe1IKw*k79`^=j%EiEenPlw9IlmIee#^#P>I%) 
zdcK2)(S=r0ejQ@wwu^T&RZXI*qTQO| z+om5l8oDK=ee8Y1{NQ;myB)8mL!RZ0sxQ(kl_6^sN4SF>Hc`bUXbz3~izisY;C4JU zit(X!$Cqh5PB;DyE1 z5ogPVmwRCELt-X5mPekYSE^z&oCzlwFk)AB@jTJRXM2gU*XBNN!hB|OCsCQ6=B==V z_Y~OlSHXSSShyJPWwFN!FzV=UUAyL!c=2%TF2$zzsvOBF%#*k^vu}&OOxKMQMUp){ z8(tmGDPg8!-JOzMFKBI6A;re(f>YM-ZnTgLt*bu#eb-IFT1KT$G0ooDMhSU$UA+ps z5wAsz2R0GWCh$_`ZywCJI;KV>omy~_4fTmVpA-#K?~mSP+9vE#BQm03#Y3MZ>Vaue z-FkshBFfr3Y$+YdV5Mo0aDMiZ?X!j%>iCaEMjz%`QZLMwy^4BEL8-FS@lV#Uo8UGy z<}Y@r`UV~DKzm>6U^rys^zJ%d+F1D;-%hGNpijwg)_s+?m};5+fVeD3@a(94t$KQY zn#-a`PuKN|@0S=U{nrdkwp|sj(CM^#>`{T$MhrH|#AmqPN|-7MjxfIzurNOGL^SM; z0{Wg~oZT}(uAmDYLh;w6vE!}J{8VS6d%qsI=^4?HCM*P>EPoD zTNGMGgfh!oS5)y@vEr8kk=CP)&C#Q((tWTbmT z5Fp8A(V9k+tVERJfuh5A)F;0WCzl5&=TZw%J9hJ{PNwdgF$#VFJA3!DgypNE#rcjq zyI$+OuMqZ)ewNyOmenpM?lRX5_pNYl0lEqmP27|aoKmlAuPu8-cco3}-F`cN_kGVI z{R=X#rN0e8D>;gmmFreKGHAM&C**$j)bR^1FeZd3i_)EGb~QQ9ZT!+{J2;jHBkw7W z&uKff=i2e;#C@WKn@lhdo=UE~aJ5Kok?=LO`*lIL`PIIPk1-NS$`m|F-tvZFQYBjb z^Sg68a+7h$v-kOu!!O3@>$_ph3d1zRbt4~@$)g4;d$-VXdK0{54CsAwcb*#llyI@F z=)y@;QskUBLGi{xt4jyv{4Z*>P-V|GA3fsfKVp>j?nR)CrQ?#qt2+x_-OH$AuTrH; zQbBJUj`5vVJ*jrB<|xDa#yQUJuE0q4tax^w-uIuW&!}wWe0h4jjsQ!IDtW&(K=_mG zMA6OU&(pBi6J6$))IZn6zr<459CC08f8Dm7#yK?{`#s_Q!LMIwS>If5-9){cKv6XI zFCJNa!@8^B9R96Bh{Y)!3u)*WIPGHE@>M|W?R$VxE9gI(5#H<8^Y$QFb73m7kH zr0dfHXB_s0S?6L_XeZ>~kZJ?&tW~9)Ro7s)Swg*ziqSc44zIq{*IXUnsTBG0dTD25 zx7KjGZ1lsV*qRi{$Y3>z2~@LubNGDedtTnvvd)f5+E&EiF&6PpPa}x@<1O4nww!IY zoD*i{e*9O01B?rrN`wZ+W7ia0d?l$Rk8AXV9+|j!Jz;D$WxLCAru&m?LFIaKR|?8^ zyB+X7yyu98(`{MaZrYAd*9KUP>L@uPZ(|<$zTEr7RO9A$f4tUFD>u)CZ)Q3@ksekW zX9SZjmI%n_7Ibv(fxSx1;q?)tvYpAkM)KKp(EIX*PTaZJYK*qy+Xk(Qtowo1GL_#M zjU8_Vigf$rKP{n~uW-`s=^>%#%Ul!lJmpk#DLc;gB?b&31v@yf6Q=>|XSZmLQsKgZM<1?SiERMqdhB;N^^i8>q6qWHx(!c-hYOn~Ecup& zD^x}k=dwGET9&t0a&${x*X(AkmV{9ovMZQnAvgt|b}voK7JKPj3_C63q8d}=6*W;Ep|kiH_k5Z?Q@n-3oy{DwNxixa`Yk6*e>s7InC2aJ7V~|*WaWp_-Q~*qB^o8lxL89F}Q0{DGao+!$?tK3js(-g(dEpoEyl zo(q+QwykZv79T&bn7lcDdJ{!#wrltc^oGITG(d(naj0IyZsD1vGTNl)O)xqN^A~HV 
zdZI2fH9y$e<>V(cw((7Qqu)gOuFZv>_h7eIh%^bi9iFEkV0)ddV;#S zqDhdEDND!L>9W>`{LgbG7Q(f~qWBGBGj}FS5Bh-> zB;QiNFJ<4UW--L`@mikt!oIrE{Md$*1r#XyO=hzR%*6oo^xr!UBZ3D2N^ifxLg1ku z!}Eh+$MN+pQ^_C)l{d*KRnEf1|KaIKlUuW-)thNUY|<0Qh>^|o9i?n@Mp0VD^``vK z=Cjb(t8q@$?M8+7CkiIKJHq!h6<%g&Tdk_N{<@tfk=Kkn(-l@xJ=C8y>q*@ zz2agFl0*H&hO`(7!4m$4d`plKxN9!8RRaAc$ZU!9p|9-}19m?D>(nzwHY_ndrIq&T5Fsi*%g{8STJ1@*H;)h1q#h+g~s<<()jjdnqmlKT0{7zt5eMmibj>)8X4;oH{#PQZgpvtd+ui_|=l?!viFd#|lfLt8 zUpb>=8Ozr=(dW}GFHzi1)&0b$+h(n=D`Y8mlvy39fmGti%MFJ-#l18`rbJEB0zWQ(e@|f!Cf7T_f2eC&h>jJ z;h(*FQ6D}2CGPsv)aqBNBQ}=Ne;PP$lbrzSZ;s{lleC#t!STUz&ix(55tzm)`MXzY zo#*dfRi*U{`#J&ll9%uC7Lp=4OvfdNpW8j&SP8d()hn~sQKQ(5QXa{dG|TXB8|-BB zy^lswR%vt$%I7z-v{})yIb`rPG{Oa}6IV??n4ewhu9h1~ti9=YL?OR1W|pp{oZ?XE zx7xNa-TSMDpG;(k3EQ)uUJRQNFv82_ZM$VM>vQEy<;Xd4FV3Z}y{hl#Xl@D0zu08j zn*e##-y8`mB;I3gLeiw>y<3_+CkagjYeyRKw6McJGi zZ+R?i;=01oq9iQco@K*aSn!xfGdajkhRcfg=B|WS71N(t>rxy;dRkR#g`Go#%Zy!~ z7kRTc^WbP^L}zj?)o;6fvBYTrn=*o}O!2K@x*ObW+`o3YwEp2L&7EUX-#3}nCiqqQ zFLoyiadOO|NU|<9olgkSszis4o-7o-8T)NXXt<@<#v+-ry0K$IvE+$Qw~*ZNZRhwW z4(+nIb?w39Y{uB^xVua*M)dm`iez>)Zo<3 zzqjlA-LAM^Q{H=4Q8!chvlflc7W2P0Rt(CL%Bg+c-+2@}m3~vwH?*dz9&t6LF~%|a z)IFEX8j5Ur3f2DJVnfv(ygg3~jrW#6r&c^u?ZGl)e=97bihn8wZnu|PLb){KzA0JjXXv;9#<<8Uwp)JBcy zYwzRYMe_84gEJW@co$*pXGbCjIzwXscf8-$9fS+3Z0vfyy~=usup|lzLX#x1;5ZQ_ z4c_aZ(7=&{IALfAeOr=`3ptKh3W7_GmsC61b2obT)+hw$PJEghao^0U=S}j!W)i)EEwnk zbms@e0y+x>?0|?a5bKOM68b%CnyiYSZ_9qE`Zr~fCHp^U28i9w*2x#h z=C5?5>`Rsz0S!)dpb;PeN}@4fHBhzn(pnb<5{5AQSp;F`AOPOQUfILRjR+V=7?X(Z zrcgoOHh?@MI0^$a@GB&qJU?hUmT){0`Trxq{$KmS0MTNB6=Ly7xC}-Hj>q6(1mN1F zWsq>RG!Bj-Ajt^__=N65^)bK%$T=`EQ2+myU_b9;KwT(7ef&%~3Gyw+#3696%;W!-V1Uy#Kz>KA300Y3}$T@T$0}=@Emj>;j zJ_#}ekgwCAoIDwfGz`iiCZI_{36H{%+v32)*ZTl=*SqM-LjZan@BtXCe5GV{969?Lm zh4Pa*C@0$wG!hoNKtdvd*oLMCiU@jm4Wlh-+I7GgGwMg6Jj5-gMYCK5y{2`dV*U0(+k8D5(y*U`_TabPk{eS zkRhzw0AvUMPGngFyMS!sw?wuF$R>U!vaS5E3CJFPPY_?^YtRQ|8$S~?b376UbO?Bb zY+#c_kO4Y@>;kM?zo(xz@iT{n0`v|e=TP~#MAr5CeW=bq39<`Ft5E(cr1f`PX$uD$ 
zBYUBLmTz*{C2&i2aG*o-Eqm}C%6iEG4)jT;5#eZPqY9kI`htaCUs8Xp+~^;!4O;1e z1lOC87k_eG0_fcdn8Z5S8S*LQo?L+OuG8G$U>4+Y+`)sgUh)9;@S_B|3Gl3wD|o_z z@YhQal{^*-)mN0pfaz?1ey-=K2+IOjz=uz)6u^9rnk- zz|s8}b={u=$wC!~R;Yq3fUCM+4aW zdRM>h|9(=WjQZV6|MO83LRA@bg0Qy*$2^dyQ%3)f)qqRkD!WBuBiB% zclqt4=ogESgnmrn7ZDQH&w1%Xs|>gXD$xg1{2^<7G#uFdIu$^fbvgg)fZXC|59@yU z=Vx8NV3UTYk2}CiKY)?66b>#4{ri~!D*73?4PmhVh}*^(w7)_dEKu4BDi}vvNm*~B z+2d1_l*w&nwcl#X&3W3TO>pAx>iOR@N%(6fUp>Yd^xz(24z|ChOorz)6~@{)dV!;T ziY+(mHf8iT&Hebc?tdXoY1H2{3OMblL~f;^v0fLe$q8mQzDEyUQflXYNUrhPgGWVI z)XJ=C1(A9?pQ%?L+EMxzT~m`&vt@}{Y?JqQWIOr)W){Ht{r61VlN6Xk;f#?Wyn#X9 zNcn^n&z-1P)cfzHa0_)UEf?Jvn#0QAg-M=+%}rz4ich2HHa3)%KsbX^=Ev=00a^lXtoAoK$;le{+k_yz*^ zyVb?@Top$2__xmQgVjYMOR7&IP<2m2Ul3sDH1|6nV? zhv*0cfB**r$l}Kj05cqL1CH>2XpjX!Ke)#a8VZYtmeqgLFgOswvyp~Jf{AaaCk?Ie z8)iy9l3M2F9+yK_u*cKd< z0ATX({!m!VA3TFy(Vy}|5s<*`K)+uyg*x0w!=M1L`gc7f8V%mGZlvMM4m#<1N!~LGZKphZ}K%q4=YP|1C}`)jfd+xc zjWlWeM!6ty1OUJ`)B}5mKiUG!`e!{H0H7P&68@a8H2P0jNn;_*1O0x<0ASTW#+3n2 z+{St`xIgS3DMR?f){ww7{*W^P`-feDT3{2iffs`GA3URwD6qBvcRdsm+T3iU{V8V@ z5(3N{>Y-47%nb!L62I92$;Z~kjp#%Es@&KmhzQmfIKs%&6Iv3<=tJAX(GyNy63OQR Yz9d^85_u^?VP){(<8)#50~)aZ1EOsslmGw# literal 0 HcmV?d00001 diff --git a/tests/triton_tests/plot2.png b/tests/triton_tests/plot2.png new file mode 100644 index 0000000000000000000000000000000000000000..94659c0a41e63112a0a69da5cfc3aa655e9219e9 GIT binary patch literal 51996 zcmc$`2T;{Z_bo_JR6r3$FrX+%5|F5Xl2i~0f`ST2l#HTC&OtFCAX$)Xbno7K?X}kKduP-XH`4B;rJ$hL zsB}tRgMxxeje>&m+bzzBEjijlxAt52qrsQ&?mv zr?=6uZ;^lD$3w+lKZf@@fOOuS^)huP6Q zQHvU3Kfe&grKKg8<>{s!(r#iD_*w5)@vh#VAAQfI>BTS$Tr{$FK43D#3SodBQG#jr6VF+xF0yRYj&K z2JcyW;MspY@}=pO zPwnC1tH$La`{E>QOW2E)9~|(QX|c$@*0FHowkDPfZ^WbAd*%5xe0H`?*Pf$CkG^*v zQ~UAaJf(!&6i)z)AQvz1y|-7IzP&jAl8WWsBcT{oz4s0qSXfv(WK{kwPH7PX<31Ug zJvKHrfnVH@vX)16GqJK(mjf^_cb8~ZVt~90PA6_nhho8vE z$lMmV^Pg|RUNS#9z^j`xEIov$8U6d~Q~qqHyG{g`Po`Pbw<2%vr)u%BiBH*f96Wf? 
z$0<}#U!N{}TH4C8sLh0sd|3HFYTo;?2HPo8k3rKSC;(tgZS@~rm7@o|^G&tzWa z33pUP4g6&>^a#Z_36G47OyNIt=mGwFvO7SqKXtGyNI*BoNND@iW-*2Pu9N-STsbeC z%^shf9cWOO6^JO@NaH(tO8me9THlYwAAQ%+DMYp>$=K7DhMz8EN~NTx35kq6nPvV} zJI8v5_v*4Z`6IjZm-=h$Ar1s0UT}T1T{2+mon1PM;**p2{eptNeEm97Pn&W^ihX9Z zqiQI<$e`KyXHU-w2M4hqFSP`%&yH582o`*<2uZ`!o!-+sBA>b>&CsybRYaEIuf43n}ys}|K&2fguwwF2&;dnYwI z@?9tYPEGX|-=O~f>@2l~g+*(wjk<$FcB!j|$j6m{lvI@?0b3x|7pW zujK{n<~PR#6NN4v&r!A+X?;8Y)0e@Hv&dD+Z17u(m`x`uK8X3CUUbT7#!Kz*Z08mh zj4dqg=H`l@xVz~`mFU5SXxR?ALxy&Wx*Vhw+OoZ!S^qKH{)14bb~WVM7$RzpW4m;+%V{;``&$v5NPRWt=QytD10RrLwv*UbrBOUL}irc1`y*!bDG z=UWg3eQ%o#8aC!!%(nc9T#^;%zBFfx51BaATggYh+|p!i^+Rz6U0q!R*S>9c7}(PI z80c6Inj+dgR^|uRtX*sSvzT&Wdgw07h_w66nrwYUH+2Pz}_9bu!KHcnPvw&e?oVv{XEl!olGg#GwhK1FS zPu!jGC@2fsU6q>X^`fflwXV3Hp59#7>dIXlGyV4tm&o%^|Ik{HM14rIZppN6%efza z)cISgI@`)z4@2ysD_;_%-L>C4@GdXTre1B!S!n#e)V_66G~cdIxnp$f&ySa{3<_c$ zhZ_4Xdk7ZJv7fj@e+ggwy-ADr%KTlN+)VSYJAbas{-~ZBY?x+~pm3iVzV!F!M|?0f zstg4)zqVYAh}r1i-K$G;;~rD>imDF}-;WWs_*UTVT$3OjaQE)I(9qC)1a1EKmt9k% z9S@_Uw{O|9CHl}6W*mqfM6gclR>q_!Pj(<0M@L7aRk^UQgfCcm0t>>`Wj>mU*p9RwD*5zjx^JqvS^I@nx*AGi zMotcwj7-9Shsm+=_K^ufh8J@9eMbMOpWphxlN_r{ zlMH|bQl5)5rInTb+1X;MT;ZK}kB=vFu)Dar+O90kdn_98BU>357%&Wf7{75MjcGrd z_mc3+$_iOib#?W|bVKzbuVo4EmHZV2hPrpQCQsE9ui2&*6|Hha9=1Sk-yvi~fh{Y= zW1=RBUMg8r;4$2sf#3HXNc7(0W7k({JKQW>xUz6Xf81qtd3JgMZ%@xA6Nli^&QI3V z(1@1x_DYDeevy7&%LeOnZ~8G7t-`9aJo^#`5!-ieCzh@L4Qpg<*yILQvpj5 zvqsHtF1f`IOE{0R;u~aJHPGzZwadfP)2i;#v9vN{^SQsjZr!;<_4mszo4U7Gv<_cu zKZPx({N0cmETdxf<wkjE0G0vhF7EF3%L|U4G`ynY6X^tqN22 zefw_A&AF&@Z3TR4$+;3)T3gG2RWUYZ@9gY+_u+%Z4i!~Z`U@8>kgDkezCsUEZF5TQ~9VHD7;YhHQu{xi|d&+oqGLuJ>o4$PE(q$2%EiIy)wD{4|GGDmq zfG7w1el9M%=|ATwl%ozV%VqZJ+KR-eddzk>ztYYqVZYEZSu3j^yhnOMO3261d-(wt z^>-PN7F-}vdPqYmmE~Wh5SM_! 
z{kB{i-R)zuLruZRm6PLTY$jqmb+RoP6FnxB5_;D?(@O7@Hgg>>Wtm6r23TaDd#=H(5`lGO5;SP<&R^;b zM=n1LWN^?R{}d9JT}rg+Q}@}?DSR@()fVJ5AdT`W!J=2e`^0CrWp-`__%QICKHpyx zuYzyXT;wgQ7{pqUITaZb*wkK8u|Z2qYx&~syLUD5uK3~(Yfs3(^IASKGc(iD*@Yxz zef6q9TI$=khb(FnN*f!Q@qWv0EiEmVY;4j6!{6uS4bIpH1qE%te=@)c*+ofBthF{dtxnyCX2t<_UFu29m))vv#Jvy3JyQihOnIKq1mZ@?i zzm@RNpZ@+Yz>)$Js22P|uV1~Ia>B-+#y$W;pf5T5N?Ta}xhBubMkHGe4UIRsxwTPV zmvWBg17;vaR8Jo`?bNGc`yPo72htS%5mK$byKa{G#+O>@=E!15!dy*G;^ra;PMhDn zd2_D+q0KwTp$loOn%fRjCATzR|M=s<>}js!$HbG@jjUZqI<2}lmrexOe-*bw+IJZ1 zj7Q9F-MQ2M;lo`jk^GlX2brhiNu@#V@y#c;FClhdIw zfhPz0qptMlSI-=a+4$pY(Qna#m|FJ`_bp6J@(6ZhL>XYup5OIPP!f-n4TbHFKjL^? zUK<_Y^iY$<;>^f<5xTq&>5>&qfH0Sv(!SoNV_8o}_vHM`3v=`H9hwR5`h2415AkSi z{D%$JDw(ep`(QF@ZyE`ptK8h&{QD9!HY8K3Xja1-MQHr*-@S(qMc5wtpXkeq zzfn0RAwts2{hf#W?%(G;crfVi-|N`a=SgflOs>WC^^A3$9Km|Fv0_$RkOIt*Zjm{a zS2N$aPI76cs+qNBTd5{U#UP#c>k5e;bsB!88Y9B2+yGQ$YiGxfp5)HPT_U#K@_+_2 z#h%;NC{ko?Xq9;qT77{(MfpJq01uK)9l)%0+gwku?BB^rld0eJZU$cr%Na@le+PfO zWT};&WqrhZ{OsAY%#N>KzaDJK;=Xk0QgdtT1p3~gR<~=twRC7Ra5nF$#UBy2tlLZ4 zgf8#Z(^yfo6xl*AWF7x7A@?)1TX?zDxqrig$}CT5l0yk3GqF^Xm76;-KagngtK@c^ zjAyKm4+ZM|lc<6I!hU%y4)VdCE|1BLC6y-*BI|mZWZ}Xc~z_Xp-thvtSnIkEr+CJ z3;^ENty@36x%7!gD~%O(!sO%4wclTA1$TKZ-rvS|K|4D%AtA`GUCU`w(AvI$cH4nB z9`kl+>4HBR8KJ8}Dgahd&;|Pap1z295vUyHF3KBSP8$KHRZ>?^Mdp+^a%2Li_1Djz zK4`!2RN0pGH-Yvv(ha0h%*G04_h7Y>Q&V-boNINygL4QD~F-;J*yD(uQDNbesv?FX#yxi8*X zTv2fk5EzHj2*pC6>`1>^=jvk7YERxk0tXk@N%8gKh6B~0dcY$oo;$|`x?&$f9XoI_ z&LrVbf;9|^NGNhlbE0XOfHS4vM=)sOaQd;Q<9j#UYh@>?Ht5w%4zh&s;pf zq|V;rd2rHyeln&bP}86=d@@Bs|K0T)zdqd^(C$P_&EHl&*-(ZSU%o0*Kp^K?Fq>>= zQYYAk)9A>eC2W&Q)7z_6bX2&ip92zXU5{*C2?A}0NC}wLyMW z^eX3G{xJHvMrbAIB~_LOAQ}C3As=jri1Fw3I09%6pCBzDThs9I@g;aISRpl@x_tQn z2!$v112z0QnesM0-=AxKE-me_uco-WX*WCC_VlI<6GfTTC41zVn=fDT5lCBDSO@@a zZpDXI66HJ1u;|)ZC)xRQy*w&Q|1lf^yND5lHR)~0LWWn|N=#%MD_lH?j)C$KP*$9j zYZy4bo?pL?CCYd8i_x{csz%_l?%plwID)iv2Mp&0g};r+>17QGges9<|)5So*ZSE|>k` zWL#X_DHW9-vhsFg55d|7~i|1_8tU-&uPwP{8>qvUL6*K9`jZEVd&L 
zUh69Snsc=+Q(mT`>hhgqBD+F+sRRWD(H%eFKiXEEzdX~%ulKIbc=4J6DJ@vTQX`=^ z*u_P+zCB`=n`Kw#o~*Ec4H`C&^YHKh-f_F}`Nh^pp%1pLqo+}AQC-(49y51PI^Lh5 zz2V6*`DYj39zcS!K zyaEa7^bMD6HJXDIEvNVH^eH&WF@68!%DSn-aIJHBP9r-3t$u;6m>OxDH5?fK@CY^c z67r69$NLBAg$oC&ruA~K?Ivj(dCsFZGv(=1yPhu_!C9zZzka>2{g`PP*gd7?h3Smk zTyCHhe2p`irWGC|cU+^RQi%io__FB%qlNa9mWDkFHaFaIEoa}cX! zQd3q=&Qqi*tn&^WL_)nhI_B$#?@*tlP^|y3nh+*{7`Efx3N-ZFqX>FIL`3;y@-Z<8}NF`4AIP>b6d z>dBtT&8c8+-O2B+p<$ZnH@qdZpS>jhyN-O?HtqDa40}11BSHWcnR9s5*LZn(b^rNu z0JH?3_1Vlfu9N1J$!OT^P!XxFIOBk5ppVe`;PjgHh{GuWA|9@A-ma0?ETP3mz3IY` zGO&fA_?L7#=xPyQs~uvNA?O8^$Gc1Wz^~Bxe?nB99XWz18 z2d`y#o!g(5tklk~)#a{LrQ5XIJ~uXAz!^6EmaN1)02WFQEZRQ6H;pXw)AN9Ub`jn{ z6W_o??2xoehUjn+Z7es)>@b2rbqP;K9@*Bxfk|9k-0J(Y&{F4VR1E%O$BrRPx$)WI zAQqLdqy}CSy-HTC+0UC>T9k~89=+1dRin_!FuvU>vlooYf+PsC&2sD<9Mk#e?^MFM zPJiy~G^mUe-~%-iimFr4UR$fqrJlf_UN93Yar9_FaBwi`-kD7#L6_u1cR@>Z5Zx5{ z{tzyl)5^kPz<^bv+keTHnwnY_!K?YXtIH6zF~$6AY^HUKNWhg1X;@yu2x>7M+9#DER|JIy|Q`KLoS49(@m(Edu#Q z>g{aN>asc>`6ZkB!^5X5qXccv4&c+nK}u-%rlh1OVdYE#HvffQbHutKQsmYLvOP`CnqNpt*6xVv=9@wC|)=L#H^yM?4JhP4hEtHfBhZTNoMlLB$CR3E2#i#@El%aHc9skVOu39OscE7m;;& z+1WWx25tt(iV?R-0Iog8`}sID3G7JfVQ6_sU-!ATb}T_~3B}5)@l}kr zj!p%77Q0tbv9SlPUArdYF?U_kb0M4csFNyCMXI-#Dj54L)Hd+!jx z7U%UEMBqWNG?BjR7%BnoIC*$@YW*7q8&XRvDwJgAd$$5UBoDXeX`<>RYo!}3NRnfU z!*A#@emm}nqcVi611F!H)l%;6=}7?<_AT2=v!&4U1NpnigZeM{m}hsXIGJ9E<2oI? z|NM({p`cEc<>haILc13q9}jA_)c6&;(+J2-<@HY#QvmR93j5|WaA(JwB%j_t|YQK4kpnq?78 zA>ur$_gw!&4roRt^s0ShMXNO#uK|ITaX?c*akzCN#85#_UPw0(;0jj)!_1FeBZP=f z+)?xD?St)kS$^f=k0oLtY9J+T194SJy`02UFT)@6*2dmHFfamOQpwWt89wCR{rmR; zl2w6jqAgOCdh7Yr%x3~-s#>HL#kA$O(CJa#PWoH1#Q9*abK;_6{Ai7tc35~f)AsF_ zxsWyA3u}%Tx z^S9Bq3lN7gJ%!P+oS;}PO^m^Il0(fuQMkI|&a3}E+s|)lWpT_HtP6{v!S^a}P=Zo? 
zYUBt*jp;~AL}m%7xBT8WGdUTqW#FnHC?2D#k!h+lk6J4)-wQ>K8M-LJ5`|yFcJJ7h z3JikeRxR=gj0D?Y>)V&5+TzyDkD|aPUZn73CA)gemqsVM9w}1awz54vW=V>Uot@>N zUI><}FWT?F5~>O-B7fd*R5MenT*%#Bvi0@ z%^kyGD{XDPm>xbJ55C!nSMzD0-iN2JYneDp2XC;bZ=|6{#ZCqeVHoBBYC2FhCWB}v zlr^$5A=NlJZ-Z&aUh5k zVn)!AUV(6$-Q7fm=;5%fZZtvB_#M_D45NL!!QsXO4e@{1*1=kNe-8(nl0~`b){i?10 zL?yJ2XygP1S@6y}#$#sro1nA+^gNWkH#}Ss6`lwQxsm#>lcz+nS&~(y^;c1G1>}QE%P| z(bG@^24y{4tl&<3_g{W(J1V6Ni*5_$FlBq3K{AnROysuH|`Y68%cyf=B8DsFJ(Mc-Xpoor>3QMl|H+| zq-$*{GSK>s8#%?qA^-+pHmE?R#t*)BKbPh(=gyqvII2D;IsY0MT#LfNA{C4U;;OIN z2@E`z&ha@L1lS%fG`$w@7<=D|^ON~V^@=UzmET=OMYkXB z5Zyto(N%SHc{#1mH>i|#-zD^YI^ERfou&~{ix0FWTAk2?H5Yoi2lXN0vBI3uH4*(F z7cPN6t{~bX1Px`Xf@!b~`T6y*J4lZ>(>I%!_GZBv#+biMs_3)Xk8ZKRmI zv2IcS(HSObdO3reYc>e&5Hilw{o%sBx?@l0=(#LGVbKrHV^0BOPy%dGnXk3K`v{fy znfvGpP7(i~hXsibEuP8IO%f^|tnwhZ1U9Z& zw+>Oh8@cwmwy=9D?~NNb;$*!Oe*d;yUi;pC_FBM^eu?|__4NyBg{T%6(Hp03A`)Ci zV*42&qW<~iG8(aZxEYu$R-*JW*H-un zS=6wi;fgnz_ZhbdID#}NZOk}cCXg**W9N`HMe0s^&P4KW_gPT;lOZSVq9BW*=6@BfmGT7vgk2>C(EE8XJOm(U zb{ZNQ8o2yg3ndJ3)s1&jjTO5G8JpCP6NF_qd_J$BoHqUPdhaxX=0Njzd!8uvF9iZLtHgB**j0`n zKR$Xia*+tMV2)kEDxrLdS$(?&X4Pr5{T}!Wa9Y(5;(#F^12hp%3(X z7j;u~(Sr|iAiM`q-V6SA#P+sT{t-@n6ZmPY)QWLKwB204baJ##NGb{yv##QL%$gsgsro46qK~u5^WDedH2|uAfgPJfE_IxliVg&iAPN_ zLYHom%Kg#k&2>d0fPx;{3jJ<;diuxx=&v*%ceJ6ms?U8U%L7d`5^S#u(N@e8OlQHi_Fh%(UyMVhAj(63f)=MfKN^^t$p)W zEv=@K)oqK>#zNa+a#{;D|Dg5RJ#?)(R|BG=SU{qCUJ97dKD#;bhr1_tp8~hb&y@F2 ze!#~Cp}#3iOD~sgEZ)AC$z`p?#_W@#I}^)4>Bj!^)O`$#jxGlkh>}CZX_#8FE`6eb z+-%sK#Jk7GWA$?@#4YqV-@ea1}EIKZ);w z?-PS;^~6W3j`xRPQn0wO{+W+OeLJ&*939uvm-tn|ZBPGWom---i0aM;cp|=ztStMS z&Js^g&pZdEf=0S6f<}k_t~FwKMSb~Eo(NeUd>c!c-XbC*9MlcBg5H~FJ{2lTW3=}T z+cWAV>ZNtRH-2VG()(x1ThhfV&;5sEiZYo;v&FBM$wwP0XT%blYD*nES>0O9cAF>s zpDS*aW+Fuba{YZE^3IaqU-M1H;Uj2F)317&UU*Sx%Yn>N=4e`LN{<_&a++dVG|wdy zpZzP=kaZcKq)1}tBv$K+`w!cQ=xfG856Ubl(gQfFIB9lb*E5A~Yw+wbA$Az(0ef7Fz&UNFzr&x=!4!PQpulb04(^0MIy|k=_%h$VH3RXY znkcgy2o!Nnh*aVl=Z$0r;{TDP&WvEk&b zr>E~F_)uXN#Q72$Y*8#il3@1v(yD{$Ul;e<@0@cn>XV4 
zV;4(sn;{Uam$l{RK)504@yb7E3g*Kuaq{yA0#i*cjummRXJ2Uw0@Xr+a8(3FsPo2% z67gpSK@6%}CA-Ar>dGvEv(pYMORfN@Q+takz+<3pZbw$~S{xBa%LQdq^t!Wi(@_3& zqwXvH50U8juEU>n2kbXC5q__}Cm;4oh9=6n;_Hp*FI%04A4}p^$JCB_t=qVBzc01r1vV_gha9@n7y-7ILvh}Dw4E0PEs`)H=&n=rKRAh8{?3yJf-h$pfxr(`$^xOp z*4{oGiC^j5xd?+5!Zn;Z6G8|dd{I-hZDFyoFQO8qTs2}3Up+u-)b{N>c)P6i6A#;p za{=C>@C>9SMz4I~4*d>xP;YK|iHV6|0d)v{3to#EF*Lt0lrHea4lD0E-xpQd(V@?} zdp9BVG6BK#f;mNFe3zd7zrpzIUwrf)E7jkps*>W6I2 zpO8gX2Bg(!i!-zwQ2x0}kP0jgyGc9oP_VSZNXYN)+qcBT+Y%|L z{YEbEq={Fom#I{Re|U^9yTF3dlq#!iy{{B2hDmRdNpN{Pja5|tj z(3W|ybE_vrp#CW%e&Ek44=Qm8V@N**5hXjKeIK#}+%e~0$O zv9#OEq#~%=A3!@i84W7wQ}4lyj{XnoOgJ`AJMitLPj}G5G6S(aL!CMYbtO4B!{{Rw z2=6Fl8D$rjcY$oO24I#POL=8v4B#kY2Za!Yh5>DMWN&Y;hb$d_z{1U)%BqPRW-lc! zijWntZWiLzN{b_`jdnUhaw2hciMwn9F!>^i`IS%>#Qmo`8^c))T#iXeOQ$%Ev@$0O zGKPYWQ9?@W1(6Qv#}qFiIOd_aP0jRn5zXKJ;ov->6TvR`rZg(ZPb1FXzk5G>7!SKm zxqct7ocziG)f-H!J4L*eVpQ4hS>>tcYA7s>Y_rO1(XhIlr&Cq*Aa*fFM3q@5)3cRBIQGIce_EY)DO|*N-gZT0uEciJWgirr?0gk zJ!Fy1a^B+>PqC%oI++6UQSS`kI}u=Khs5^*X}SuKV|oI5v_P%DCN#j_=_bS6-^5t~ zI^HpFd1HShGuY9XuObg?YsHDoUg46hz3Equrn?s*Bl$EiOGXk_B11n6 z4Nz>^y!m=<%TvUeKaM2n0$L&s^IgtDnKK3YWu8ZhA~ZH3ywOlqq09)tH6UWsDOI4R zp%I2+aDoC3LuNQq2!)jSdzb6uM@*?6yEtHXctE%hqzeT9+^@TA@H^LlNV|KB`Y&%C zlApH9&Q7pV3FZy1a^$Kz;G=6l68i z=kBme!PrdvlgPwCzvG~u01HXFRnY%$nD=S-v1}tDGObZY484|b7yX%h=xY;4eg2X2bpWSYN1n9DQ1ut~0_xzuk9FS`H zse%6;o6`%bZ8H;-X=qf#^2GEsC(e4+rb#0Gf-QBQ@h-~Aa%ZLh_z8dMG^PM1+zJ2) ztvbLbA@tG38 zuMztC;m(4GuswGV43vRXbRCFy=7B5~v@CsK;fSz0rp+UD-5mu%%uoBIjW)JklUgm=s`A)n7>+1rVxzosCll=f&bTBRbNnDVIl0#* z!PoahT5E1bM*XTdpN)AbzbFwHAC0^aZDJF%JQH;F?(N&cu$B^_h%I!5jgz>dFn|Ha zh>+KkO9H?z|A3j7RsStl@wcXahOx1i2Xzj@SxCfjc-{F278e#^>(AfRpAnP(=^qEg zxo;n(4+Kt&vnPp`9yDoJG@)$JHVRW^TkgQp5J>l6obt{rb!a*5?d`;0jgN#QF$?xk zoV6ECP~^9R$?C`99f00co3oqo(r9~L{=%P^umD|*cyLhvNEJh1aj_g)flb9ojjj+F zNiPHOjERDL_*iEFy2+mL^KsQsCRspS8hDH=K#VVk2Tx3>cl-+&Gxs zhdVw*!4&k{71iPGGm|SG!Itl(#84>pz(p_nw zfqBlgj!MWE_(wf56QtMHJ$tT&IzwGUr%rLYMeX$IuicI1mSW7&8Vh@w{JMNV!u%~` 
z^Zxr?QyugGo&Je)AlxifoWA;K)t6*XQCy%jq-;_5z#<2I49V>i5xpO9!iVlw*%j|y zEKRL$(Q!wegK|6F0uWnS;7c*kQGskkarFY^!$@K{0U?5}x#aV`tpYW(m`Q+L;PQ_H z#QOjkkKv0za660~CylOLQ9=19+B6YXXZkeiz;lUUm?FWoe+LFCHZU*@AWJW(cp^lI z1`9%Z0tP@Mv7H3h!+UuN3?rI!V#8_3vQS6&by81{6;Ye>`w7lDNM#|UMyQ0zD%6Gx$Jhk^jwMF^f2;AB&a>tn}$qr^-Fc z>||m4*)HQDVN*EI#T9-J&=iD26{rzgN5>yg;iF|C`%e6BN@pX68W=PPLHp@0gJz+J zYe5iP3#fcY!IHtSvy@Nr8UW?*?|zDva{3cV1(@@_yep`{0?}y%iM_9x#`X?ZyghgOo>3H-BI$i@?SG zpwURFgoK)j`2tEZI^-6=VJnwkR8$mNS32i4K*w9~Oc@y&jNl48{93Fx$7S4yP~jP{ z;A=NOVMatKY{Es-jsWhVc`^~6Az~Wp@8^#nj{%3JfE8!{5Dx}7MJb%i_{tSNWbH`2 z=Hm9d0e(J4>g*IbUGO2VRl}FT$!4Ewe149IU@O8ZhXZH+Oe^UU=9opWP}u9 zOh*K#cUj`V0p*p+#8p4QR@mK?0mV#_UGny*n2YSPdZ@}p9LSt|_sUsYTYqk7&^&(P zgrnBcJjVhdNF2QI-$0gDL4q^;UH5nkE9`+x@V_b(!$<^j!Bnb*J>V9&zZ^=eyb?vt>naa&wD9$4E0MP}bJgmN}QFcmjt7 z0xFb12#psXT}^%Zlni+g9dF=yeaTPt?3DMK&Ou!?T+~jI36IaTCz9DB4i11>y|k;Q zdk*0kzA#~hJ|2rwisP!HwRksByO%#rKTo{{s=F zhy|1KBTy6lC#Fa+LP>?@obFRjoX~c9X5k2?;C1nJ5l>}`FaP{{)O&!EVRWB7-q5P| zQ)`bQagOdfY|UdW%(PZDjfs156Jr+J281r%z;%D_0FB~w%PJ=u%3XDl!{*e2O#j6X zF&)m+h=VD~uGXY9{4V8oitIJ-sTX}%ZX^_2$bF1$;QL26xr7MrN(E-=XsBpJl>|7@ zE4;12w418VL;C(nckM#z%a0{%&k=5_yXz$84I#&3RL!5Y3Ohl%=?Q^|B+E;5s#G~i zw=M2b#Zr$^KbK5MA>>hMxFUXGaR+s*Dwi4c^qOyLB>eiV+6g_xc=4yvZNf~gZ0v2@ z{jq@Q3HT08c&e8Uzvd}<_gJ7wU~T)2BA>nBqAuIqC$CI!up6P6{4iG7gFg+y+^1td zi+6d1r*uLI`9$y*M#E@N9#;rWrfap<6g1-aYnbig+?6y^smZfP3pCU*ATe@o*sKT@ zg_(vRB-t}C3n;ecQb8~*Ma<8mSkmb-P4hF^+DZ2~O!l8f&|&!Ivyl+--7YuB!I}f} zE}0<$OqjS-`I?6+LSa&9ML2f%_se3rYm1n{{3$RRLD@qzz;7^rY9YYiKM8_LuKV2}>!P)SdqbT`*?=gvji`F<=O?f5W#M0R;< zE6MIaFlbEA;-Rfi)RAcpiv(91pd=IsOSFBV%`TBcJ5te+`{EBZ#gqDSw0g>z)7TSh z8TpesR|2z>Bh+oP65$dFMELya4qooA0n zs*S(PfDsNB-CD?3(USIDjg5^UAb3=)-yx~iK}ub2E-mF$(M5zlLfKe=52|!}Ra#1F zJCr=Q5I&+d!y-4{dL{iV6mFy{N$E3a`Lww&t*v#x%jJtjT5eJ5cIM>Kr2=hZ@BD^C@-3CBw1IQDag@7Jrif1rHU?vhC{Mpp^z3Nus39HjKs5o3>xUV}Z65J(9PkT+ z3}j9T{dxi%WCcsW5n@*w=`nEuovspD1a$Sn#04kOtLqr{Ag*8rOBonx&hQ0>0N3RM zH3$7#AJD&B9*Q@{gdRJOb;6nXb7-gv9|XGTSIJtSV1JCYsi~=vW&+(quN(g?rgDTK 
zhXRw3AzRQt-9TH}?;;`M;wgzg8Z9f#QH;NSMYfp%M#zIu1k3!`rmHZ2arDE&L>p$& z=0-n^&7mcmiirvI+A=u-ryKzu5Aq~*2)=N>_1L6<%8OW&uR zeN_W)wzVM`QlPRq%E_tV1zsH%J>I`LZ0ipDo!Y^8O#T%?2Fc7X((i#hVt^LXQ$CR9Q1)4M|>~s`L3Gs@+Skg@56W?>AQ*V$GWRfzi_@E zv&2XdOmhEs?~uBNrv1vAl;q@7h@;N{QK$~#CIy%_C;TLO)hdAJKIn`8-qC&Lm7gEn zwf_$sM5q2a*YI1N`8pIZ<<&1rU;pVtK}>e+?C5HlkzeKkE_$JDX4(&^ZJ?zM1NS~S zfMJn)ckkZy00m)(emv%g;|uUPlD6Hq{QO|hRe>T&>_M4OF5t6>0Hi#JVa{N^`~8%{|{$*$Vt_yLd-Z`nD-GHY#tU7tV2 zM2khD(y3E^gq%T#!v(G__u0U}05>K_wtze!_M?Au5F#;Vv*Y88uxz4H&LFoTNX}gO zFD~o%Z^HI&WgOA52Z4z^b_yAl04ZpRd%Mnv9C5*dG=bm+LV7qP);E_M55N;*8b2 zt5un4rNNgtrc^Yg!e$ayXb8YMNB3xhL*Vnlng^Vkp_Tvd?MI-y`hUBkL)j*WnbCw1 zN%WSH$P$&%v2I%tu>=)_+}exoX4}%f{2!H`Jwg`G9Kt#LtnX_ab8hfo=@eSM3dDB_x_l`jk(d zS`UDWYd>~l!U2h?3~5vXjUsdy&`@##FU}xhgdt=yXZk^&J5`fum-suMQRycr=md6oWk5uka%`&PI%*IAdV* zTh3M80AhcHCy{@`?!6%p$fJ7GNGQ0{uEWyy|t;VwRr*oY?A_;eLyE-m5G5 zGp$!Bz@}%w!iK5EDpWzrf5s+zv_1FW0U^E+^t#Yw&L+uk2si@WW&j-`ncN1ZIa%v$ zYx@gT9z4?X@~1&`I!|(IEFIF`nP< zrDdMIR+!JGpGC=>$bb|bUgbtYLY8%pSlc$y(V3vBA`%)Z-~^c3EX*CCe??9*f^Sj^ z)7LsRRj-ewwfJd^H2bDJz4(O%lU$FmCiJ_Hh*iyF1s7I#r2D#B_YYk!x-$RB1iwmQ z1azkp5n4SX5%a<~RFb*T9qk zP$w ze-Yo3%eSo*Gymi;qW{a$O8;Hz(N#5W-UJD2>+G1s_AR^U3~O&rIS6f+>m4=Zk1rHJu(guJ5H$c`UW7sg% zOmfwE;)VzdIArr)?84Y1kO10h_Wk>B;b9qJCN4UAc3dqN~yj;YuXULkL1C-O)?9AM5Ol#`vkq^A&5Vh}DMJA-Q%p#^17^uR@RW7~dj9hAQmbQhM<^iq)vaZLR z6v7>}J|UR`w+oZGRAPw(8$oW~L7H}T@y$D7YPug|CjBuM3H62{scP}x4F8yX3H*gS zB)&8?ML_|gM7sc+iiBP7dQc=u$de>w;9Msu5`Iho`fLR@0-npt?!rwk4ATrQH+gO- zF3O=8v>#(mxCww zZWtiVAAqd{5KiI4>*&4!rSGCS!1pRa5p#iMi+CO#6z0)h;);|DhToR3G|{lWfDEl) zj6LG`Z(!~CZCYp>fO9)Z%R%I!0(?sTCo)WjpOQYJXnAHAENfqV2(y9stCSXIUolDn5(dP%I$?xh z@D(>cylahxH5p8#Xu-rEb?`JE;ai@)1+gcCBlm}x4f6Dh@h~2hz_WXx8Hl+b@;$%P z);oZ`2m&&f0l9%6uX!u9iG{Hx!{jb5;yFRQNCX~Ybb+qtI%6n$a7_q`eBf9~876u? 
zk!pyca{@A|cDdop3j$C>UU{)%_GNXx)|)uDh*plGMkWLPY328`p}Ki245mU4!&vfR z*i(5y!em0+XRL_dm+v>jifncq{8M&9x;3(Ll09w8wj6 zto5_8RtIM8k|D5Sq?rqR98vnwQ-y+0DyYH2>=zO`gV#D)iO$n`& z|6VhH@kQ0|R@|+o?)&jf@-}XTUk!{r>sa&cttn*WbW1cpj56)8<(s)4)z9RA_ZU<( zge|+67!#liv!FpWDWKbx`Su7Rp07c7e6vd1*)#&OEnuqn!m)rr*)uXHJrUkj$2e={@J^b;Og zV(8MPSD<{tJiYDqaoX!auavhIR~+s8XKi9SG;yfMv_HLQ#h7R+3HD!sRfxp9Zj0J@ zukqa+erJIboyUKgWXC+!ezg9U-9LV+fT-t# zyGdQ5(ekl2eX)m_P0xRxk`pVmjHW0b{vFk5p&J0e)UjpXU;+eRD5tEfoYt{wVcL>u zCbc-6g_tDG7#Z0ICs~FH zmJXr@F4P>!>QS-3t^$FEWK%fUCfY}VY{_L;z@9%r)209~lRG5P&-`{IV}C_n%PxZ{ zQFAaB!g%}(gUVG~M3u(pMAm*S99$z*5Mp+g5PR$#*fB6j7#(d`72kOux3Vb-I@{a# zAY-^*o5eeufz!#K`ldq0;^0LlSNq7U%$an|kt<)w#FJ{AL@1J}$dyL^KTYX2kM(%> z*!WW>w$FRBUw`^n{_GN+RejQ0vGDedvcxpGD|qOIn#`e#`jm~+7rS4{+WKiAGvt&DATqf_mk(rWo?o& zMYU3s^Z;dK1pmE*6}e_3;9fXRBbgzTTtpQS{oPz?gUBq-fL^a}K23rHs8WSpCpXT1=H1yZRr2^WRFvO5I1*!0;&Khv< z0Y7~vhk?I!7CtScCVJdUNBChM{EhZ^Kp?A&%AK}y~1ImaqGEmnnn zVJ?(4)pty`W*6OKq18=F`UIv5-jcdtgCGToRl!S<*3^Rg2O)R6 ziA2RL!iKU1gR{mkX2Z*U3$uV3Fpy9+NRKLcvC~7DufiP=`!UO1`%u{A4!VJVo)S1G zv;Fob7?QCoj57s?6~O62x&O$K)}&74(6u3u!TOsoz}5hH;13CYs&e34PEO7gxdw>7 zy;$c+xYV|}w5{e$d0mX%%ihTqt|v$}@hoDIiG_t59rFYHy^@=oTj96T^MjcwDZDE0 zvyXU>n5h9YgWc!AzuXKkUSdXSE}9w0NDsP81FDb%1pW}C(4WTFq?l>oQU>DeBDBP} zh6Z50<8@q0)8FAD#cc#pAadsaequ;L+!6^0_vnyoEdnNS6u59%&}P8i=eQIGhXmF& zq^5)Di6bFEo+8#YG%piCB+780!@72@od;j?zq77YfR&t@Nc37zLhl8-QVp~oH~6v# zFxnO@LE;xe}5wQX=U!H6zTDmDli5eSg? 
z0sLu_oSW{EeY@)Y< z*r-ra$&jYbdxz(xZ`TlW)xWXN>(_B_#b(I#(Fz~7a;a_x4MRpjvA_zKdmg$c^4VJlIiw0jA>CKkwf ztAn}+_1+c#*A)mmjyU|L&J_nefvMv}PH%h*5C<~X7h4`vNyu`T+ns9SjzFBvKqdlU6VS( ztv7z{5%$20+4O3|w{PT{7h<#kWsdlx&Y$mph$i$KbmW2+a68Z8cK|>ocb*VK1HWz# zJt+SE+C*|c61o2q83NgAD+@~*s6%o$AOOuiz;7}q05T6NfR2;G2)|Ft5^$O@m^xB5 zAe?99)GRpSb|}XSue6V0WCR1A3?OfxA{udL03!>F#aVIOg@hXcLEX9iMRy$@7FL2^ zM3=f3cfr7h>5G2Jdm z2Z5PbUWn#s?t1S?CpX3|A`*stCjv^Cc-LFTrq6*8)~M^(Fx=wA$Z0xMX!cJMef#{!@@R(>}ss7A5<^>A&tpa5y?rZBnz5 zp;vP0k}Z@d#Hbve7+E{ds|S~mQDX2Ow{EJW+ayHJ{QX;=@fGi@!{Y=(A1!zt$U!Je zbn>GwD>grT;dyMNrQq(Jj~fxSPop=wv~D}avZ@@nk>&{1mCC3AYD2@Dw7Pp0(oQKW zcO&CMKQ^(kQD3+^*OPbCZ^D$~hUO=`)x4|Ot+yVWX!Q7ZODhHx>7B*+07#I*$WlSr zAd_vN(#frFaLrN`Bf14ziBbaq4t^OA1QL#(0=ILfLdx=)*|Gp<-8{_MzuR&JDL|63 zi=Csh8W|p@FD7#|V5cdgam^4M*gt`;bX1urW`)O9uAeMDA~4mGOwiH^)_0 zHfLvX!Nopo7fM6|gbGS>|7O7w2FUw~bO8>A;yW04+mFIuJyO0g|GK)~(bl%UNDDVYO^xThe?I|Wi2@`}x7Hj^Em9nzi_ieSVblm$RDCbX z!rVwHMiIyv1wU(s+ysZ4FUCLqbSITmG4m^bP({EQ1;PlBZS1rk!;S!< zi7K=YxuuJfCsS`*;XzF98)NlAMFmkXf#bsF$lp4@;+Fzc-(b=IE;56Cwlf7K&`&;) z$H4(?11A~azW&%Oh5v#}i$>54fto;15Q59!Kgz(j=ILoza|OELkh$(~+u3SA;QvYf zPY%|PAXOq7hLLVbXA8GWKZ?Et9c5vsd^~&h5F76HHhT2Bry!Z3X(zg!!9M21_32-_@J-_VNp1UeOG=#EKo~p=ehVj)$4G}5^P}KvJ_wa^;E!# z|N0M{$`JYC2-)0PeD@iOw5zW0S0E0koYqWu|IzwWw}ptj2#LfLi5(oZt}uA__9(id@`-2M%ry=h+jV0c^;FpnTcG@+~s? 
z+ZCSjNF5_1dMHoNgANpFkp*gg17&nN#4oEc0=);uhisqpfuiSM#rZUdRp(zl}MI|f$J%!i6_Temag;4m9aQ3p_P0=7f$t?jT!$p zX=8&-{p+=*^>G1JC@HaZTv+c z)-W1|8xSeuj;C0F91Jv#&tHQe^Gtkd(@1*x;Wt5P=s17hx!!$>Ey;JlRsOSP9hL=x ztGR{*TTqEx#_rLBS~mjM0UiasU!;x(8ot%4rw8p2Nq|Z2;ZM_gcU-6!a38=ySA{T% zXX`&C!-0iBvY@#QK0S`OmvJYM7CJ!uYPbWmKWL}VK?(#cgxg(z`u%RhjRNsi2UK?u zpX0+~5<Kg!;7+i?5(4vEY?H0mOg(ZJPr-HKNm)&&d->I-^#z<3{ootqO z3vBAYW<`pngxNRyRYy{qme{({B#|=rLw&Z`ub=}#A16LO_=a5z-q63&g%U5655NV` z3ak1MLgGF1!@#q>B*lQ;*t)+G60w*_ofy%vh&~=pS|KD147@N=cDh7J$il%90jKw6 zP!_}TzeGw3lml^SG3Y>=<=jW~Ulrgg-Q;zL|5E@~;}SSwDf*vXv869N{~xZ{G^yX= zKyL%qwbSjF6<(WWuf+e$6`Oj(J3=jFtH{f@5d=1chrk0?oPp@2O0NlaJ%?fC4TQHs zBas7oBj^!8#BBv{uU>zPF#+lgcckt?N|FZ^EM%ua)fV!r9?6loeof%B2Fhx;a&ldfMeLhLGzu`tTXJN1GtsV71 zi|Z35L&oSVzfpb34Y9M}*Y{ayLX%3 zAdIyp^7yZkaVtRi;E;rOliM;3Hn?_xg++rCm9J?FODX#`R36-r$}~s{Ak8f>sTTlA zf;F6n!1THeG2Sa_=uAL}Mh3D=E65xm%hzssYzrR8`*0IMj+qA%1W>jh$XGA~_96Fo z@kcNUy@hz8AF7tHDpcSgy5l;BF*7p*_vMrq=p8yb!k~$cIIqB^U=JYyTwQBZ_26GL z2#80YA=x-cEVH4;h};7R9}JFBqACgy+C0BF^MV{0aVdrQsF;0~lt?r{S|SjdKxM{aHS$7g8BJYh8+MY#HmNjJy1@+q@sAAb=hg12aZ(pB zHzP=AsKp0E##7}m#D=8Yz(H_fjs@^4*t_5ZZEr^)^d-n?pwU(dIWWZS&m?X_rt~Uz zAB+7%I@ZQlkV@QuZxl$Yc7+yxxB401cUShO$Q?c&J*KZP$k|J5?BWr9)%4N)XVNPc zQ60kS)j-@6g?~6$BEB@pCn~xb^p7|Fz!g1q3HMQ))_TU{82yH-lmji3x zH}Jn}DnZN|84&@UMq7vfe-&tJIQ@Km!1jWquCOk1L3Y+45e${tWe?A_AIuOCBII`j z;{)l@0}Rk#46mNTsf{!&k$e_PZAj*g*tPgT2|P62{|?a^L@`L{%fv)RzCkby$gc{` zPDB?2_pQc%!}R=*8cQ1p6Ioa)5nK|yG9?)Vb8grv{=YSs8!;CDliV_uR>ISn=085) zml0oJz2-t{f^c2H)PZEs;GK>G4*0|BMc~r9NYFtMvTC@@)EN1-K}h7wNOuGhcigzG ztu5{kdb@jlyf#heboFCi+GfwNLhiBM&BdAnxDS$QL8*xP*$^%wMgEVvlIBO5IL2k# zioSoo;;_C=Q~Orrbc2_=)+{92VBQrkKTjXeelD|DUYM7ZRM9=wYG9FMPB4RIOEiOR zE0PTgyYLFN!K$Y$5|(ZN;weZiphYnAbh3uU4p=}N6mo8X|5sBm$CiNyK;)phoCPcf z-08nos8)>6XAk_JiJ@km$IuchO*zje3m7$|;Jhz${G*Gt){uGAk5z~>ljrxy0mjum zhICd+lnvTp?yiMNrb;$Q3diIBC-?BL#m(F7P+$XTdS&kg939Z)!GSG_xO>TN{8)UD z_DDW9w0HQ{=xtTQ`+nZz2C8;X9I70-O>cHL#YFA&SGo!h$lMPfX@L5N$%?6fjEhr7d_MJpnU>4}uMV 
zb|#?$nw?ID-GzO|>74v}`7N#Y)oa%?8shVe&{E88}bjxsb93w+|@#ZJyzg^Sc?iQEV^zZgF%F@%WTd~Sd(S#dsK(> z!HNFYGCXdnz25N=(|UCMDlz)EtcL|_1+B}{eb7uof&Zm|GVh~@bQv*?i8K7IA1=^& zTw7fvVu4qe6dejHtMU|TxQ{kTerscTU#R<1_O{pgWvT#~1yq3yxnk6o&e8ojp@|ej z7Qa1Q?mXF)aTSnLtzya0zwJU(RL@=~?Vn$$7&(u|0{kQWVWn$_mUW0xH&m4sQNhQ& z-AN19F$Zt1fJ7tF{-4OHMSBvL&(+Taf0g-#v{=p0sqqAOuE5VqH_`Lv`hKNRTH_SB2dGQMO6=RrtKf@UZdS!w_H1zN7`D`i5L|aBW=z;A)7#bqUWX5RPaA zcNruFyRyNz-1FY?NE$1t8f?^w{_w_SdF-I-8#v|mr2jC+FT_`)&M$;Umdrw$AEURb zE=7gM9|*GG1o{b-6_B4%-Iuu!1qe_*K*S!I;}6yTMLFqW7J2BEfQ%EF`$4v*CuvT~ zkK_kPhZ9l>A84#{cO#lLq=F2^6|J04Ceqya?=o;}On_Mo38McEO#EIbViYTwcF*|@ z@4;8m&&|l2jQ=w;$Ek@Mfhmpog1ShE5aDCN#-n-IKNPH_PzL~H$aHFA5=Fk&Sc!s^=+HF>W|$-Du&&K3a9OU@EZSH3)L$YKsX|BM@Z|ayY2KcHGjt7 zpOsAf_%Zb~?@y*5sf_>OmQ=_!_}yNE4g6QWgCeE{zidWv;<>R_Rp5j-ZmnsE)7Of> zO^M$&NQ{y~D>Z`rLZoGBD~Mb=|8&W}RkG{^Gst)x7mz0rJ|(cH0b-kVNhd+N_{eA& zWc*0m=1yX{Y2!YMMUsnWZ*sdW_4BuRq(^fXIZYeAwV9fRq)HFaq+O3_M~;{slK32xHVi6mMALNDWapAM|v{ zB1$i(c(>I#Zt1Yn>M2f!*(vOroZH|$8^?XZ5IT~?B)#{XhepaV>*yI4E%6PSEN-Op zWez<)$_sO7 zSns}FpXFBHknhmei_#Kt<7NfjM<>9OLB6;K1)L@zphH|raM3IR^l%x*I|_nH?pUhn}NSsNtjxapYs zSQogpS-}{HM0}xD>p2$n9qsP~#|hyxzvd1Z4)ES@)`>p!xvy$_fLE$2-@&ppGbd<) zUBPfjx~x|)_^(#{zU(da$dgC>vU&zQBqbU3qCZ@zYwg^h%&ivO(u9AM%rZBFCZolV zl6~bcM!;eA4hv~_Equ(sWL(4%$)c}Ru!_lAY%j6P2D{e6D(LA;avI$7Q-X_We3|s{ zt6^(17POiSb(m4yCR*?h<+pJ%Y*VR@I!6gpd>#Mj)34Q)@%&zYH6WSvp>A;@!}E6U zfNd4U+?&=X<0CiflB0XDy{>L?GX=gBv2RXG!I;hB856$sE{_`dC_V-F?pUo~>E%CH zQ^mJ_ZXJl)AW@gUOQ!a+XtnIlbktWRx19=92Xo47VQ|Ln{wwc7XmAcOQsnk;a=Dok<-ViJda=&HkkMT?5 zI-K_DzO1e_IF*Z^Dm`?i2`JVk;Bo}>kWoK(TTJY=uv+pfs< zj~l<@vg^dask0{)VZ@UPT0*;6Trzws0w1Btv4P0;!Q}xE`g@?x0a{NEtUbi|^Q8Ic z1;l~DiCWvCkTL3sdS?qfD0(hx-fu!D$!)HQH^(J(lBb8}Bg?;ve3WGCf9MX#xGYm_V>vuDd2b_7wZ*M4``r5s7i$yxhnF&j{gIFQ z8HPUtt)JgrP*i|u&l>a`NCgghP6IGz3z1fX(!Izwlj;sEx)%4jMGTuYW)TvlGZF#^ zvhuzG**kMsDWCsC!>G?iU$a1rpy1g-6s`)PK%#~s5>gCC%!YwLN0al=`}s0D>>mZu zM^;3}Z@NMnuJgAr`%~cDbe?x+iIO4wlNUbpfpzOB?Lr2ZR4VoHe^%j+Y~&E)BLJd9 
z0i>r;7O{c=0dZdf<^-7p1eX@P6l{n)iRpz3GOT$A|CtFjG&g)I(1csnT*fLO(ZMcPhAj6Gp7FU&F+Vk03e+Lj6DOCKB3%-+)s$S z9guUtWV;E?T9xw2B0Vmo)(ORYBMbQ}dXJv8jX#Y`FEbR)&%Y|>N17}73)Rib-%w=x zDy9g2XZ)geR!(%@vF$a7S=!{bN zFd47F=l=e1*NEP}+`PADG4ccPL*w?j);awW{+s^(_S#r@xW8xHS1BlaRVeLM zhNfE_94g!>Q)S3@;?_LVh0M5#G~SNt_P^|!d*IEw;KP?OE<^3^{0Lb&R9aA< z0&h0pi(`Wme|_v6xLmUNZI+IB3HL)IkxU@ti~89%_7CWk5pp_$yuZqPzZZjBG1@uh z(jKKGZ6Cy-_pnB7M==9!x>XgQ>S)sO)S8N9{{8oL^DUtOz6SX`I9GyUpI~T0WQcS? zpfjq?)F!1~M?I9d_&q&@T4p-UZRKxN9#fhW5xa)GzvS9<@l9ojUxLuG90uoQ3%UIj zB5j9`{L2`rs-Xjc?fG(Z79JN*H^VU)GtS(E+}zrEyLZWfZ`H0S8%BdGg)55oHMqLBRb>ERRA;Z zdbryN`Y>pwL#IC>`x{q2mN&+e?2cLq?7s>+8J(DX+CkEi|2~-?ugksAP(>K8b7KYO z)jTw7kPI351M-5G%y(ZP)LBs5nD$pn)6`LiU|y(YyQ69FPDOj$n5HI-bYjv2XkL(o zUzZXJ>mH|j6F|r|#TQWT*-eOK+@xbz?DJ z&7KAO_lf~lU;s)+*LsY>atR`gm|ujd#7bM+VV;@%e@VPQ%Lus;-c$+tij%L=Uf$9c z^=E5Z^6)9y=~ccbliL058YdU=@f9*l1AkWuqcF4j+iJwHPf+4>t7)J%L2965zWeMR4YcVs#Bq!Q=O;U`!I30YgyUMoGPjNyPL^}2kkvNoWL}#h>JH(MBWj&>W@8EWl&m+- zUm#^Rf6r@B2aO|S%<=CPcjsrf*p&UZldAw!iAy(oCM8J&1TOrW+ccpcLKnM)yzuGx z8sERKO*IfE?>@HgT+$%a`ed>7yzXt;xJ-1BcAd(j8xzh${?&7>7=4(Fm^H4ZLRy{k zFnUlF*Kek?3 zIgp*!m}Ew(WW8?JBh=uOP$S(PZDK){Qdxbe*Q2!$t!n=~1B>wfzqdRUaks%9<%=|_ z;eyQEVciI^fUbF4v8e{!lk>qhpU+Ioe^GM=?O7dLH8$fnX%R@@!JSodl>T6p$Ze9$ ziAfZ6(4UZNgQl3npVO;m@cl)kD)#TwAXjno4F;Zo?Yklzd0%=3B?x`lH@tadp50Te zt<@}ilc%6plH+qbFx5McIJJdD#;jlRRZ?Y0(c>^4YF$O0MAED<7efMnd~O^w>=}Ij z|NXf?i9l+lT2VPFU$0woNkfVR9E%=L$*rGsrml3yDNwD*QZ^OBL5LN zb;F*K$y%@8xdd;4&@R#o&`(fwQ)3eL~ z4wD9)t{A5(t=qRJD$TqPr+ra+nJO4!jZ~F$2k35$P5ie05La!Fd{H`9=+AkX2)Fz;Ws2JHc-6$tOs;B$L-takWJQs!iCK~E zyu=i-D{25`bT9D(A1anr(p+H^g=tR~iq(=eF3TYQ)NYH8{~B#14WkBf+lR$!rxM2X zw8$L_VX%rU^mgdr@t$K~ZUz~IwD23)+$}>tSWB@%A07ZL2L!0a@B| zlEBP_wuNQ?Fy-;v@Nv7jS3{Z;nDUWVVz}#DF(k1iNlruaWEqr+I8O<-c(=^wE{FYs z!=F%*^rwoO0JN=k#G+wV5vbjIk~CvkUtZbZP-DY@Dx7N@9Zr&lUC`N#Vd z^IOWwF1E!P{InJiZ6efzu|?-Uw+M^kb}(X>wi7PLEV-TfRo8iGfw!13bG_S-#QVWn zA`F`Cuu_G?76L%SYvtxG>7cE64NZ5D1^IMltwHf^34zl>$NhU1Kf)=42pBR4^7yX_ 
zRs~5Q(gWAfvsaSP{*gQ|h`V9h>~*xhOMcWlZ8b?Tr+pN|qDN=Ki7K_5pTw8CRO_i( znvyuEUBlqh6+7!u=VQ_GpYX)xPG1#{|Cw-3E~=)+HqI=rJ3jhyr$LC^i7D2 zG33mO;&CUvn%oL$Xl6%SBgS84!5fU&O~FP_p^Or#tXLL&YRmSY%89j>**9d^1auDK z;ko}VG=(DXmVi^f<1QU2kM+NxtKM$d3?2 zKQ@8+`df-HSKm54W6k9Um-d9tW^j*6+l%8k+=22KQF-pyOz0^boE+~W8?2?Q;*4mjoqwjn6oiP4}}PZn{Zha&Pq>8BY^O}Sn~WCR+h|I56jT@VZ) zQ0mylg#`?kV?JTw8d+ zC}pfypL`hjVyDE4PXdnbmG=*#1e5Jq+t@--qNOE=H>A==eazDWAKebzCt8t^g*STb z`o5ho-}+Kd=PelL1Qn=c^XEsnA*~sFAHkTdfXunr233fKUDCxT4RsOqWB~8cfbl#V zB+MYJ*&VG#d?C+I54n#Z{A(N7zP%H}Y7ynuxsq~K-(aeu!HBrH_%Us~OO?LxCAI_a zLDkMTW-DwCH?8g;MBz4O&9kxG>_CNvE~HAlcf-%az%!Z70^L>kz`OLJKSxSIdZc?1 z6+H1S_1U*x&TOxK{61PEe^g#Lsh z8699kr?`RVUNuij4gUq`|MB2bw2=qY2$PIFlb=KQr)Tp0aUbcDaSMQ1C_G;#KT3`59AJDu?zLg&t ztoAXS*R~cywUYAcWOgpDOqa%IqvcgBQqwTI>gAQjoW-BITW4&!a!H>@ZzOYgcy+fd1~_686QiBvee{iDrV_glvRKk{#|mM0QeO-g3^m*;rs@V4~U)vbmJoIir;) zu(sS7aQ#iW!=KLYjUFzMhMd~p-?lbCYvs~;+`&!K8W51__2gk4`S}rka~k$^gIqG! ze$!0R4Q?@rq3EnXqTFK#y4TOdf^!G0*=P>u)v*d*ft(71K4Xd=2?x|+1SM$1sKFPG z2YEoy5>*wrsSqLOYG>bDR9SNh%U&Ie{q@f1hYeQ* zu`tnrL9;;tbb|_L?7!m$n&{numn9HSDW-!O0aWUU%0C;*b${2!k%=dNY!F00=#yc} z)(sdp0f7dhs^i)D)&&mRJWvD!2n^9#!rvtU-_pX#Hva5ixsx=Y2(+iRSQ5{BDe;)70D?n ztunFSWeh79Pe=y!mj8U?B#s z##X%4Zo}{DciwijNwxi&{r0?k!~OZoC9w^ zli)2IklLodV5ZKDkyty&UrI#@ezJH>RZ@3p`Ya<}R8l=p!xL>bXCz{Cv;Hc_529rs z_FPQA{)7TywA@^ugoC6eRxeu-_Jp(Ub~xK-xJIee;2K?rk`4GaelGuMgm+@n?EhiT zC9;b&dyj!{iF#lrH7Li4*?K_(zH9EyspMR7f+U0YS22{$)m~Tbh3>xkBJ%eE61uvL zITXC5217$#ed{Psg_XoVU2^&AiLJrAs-|S`k}TTRuG2K5?PKE7e165OZJUb)*p z#5nO5tGG)6wNoGO{z|z2P(6j%1Zy;l=T9ki88*?z4b~K?gmIA~Qb!SZ_>xCIS8J>U zk~jWj|A9mB$As-3SC01%|3g&vheVZ{fgOeFT+wOlJ>q_OQopQV7 zE({){p!Zy-$6VLMakjVnbBTUP;jn*D!mj4U?Ul0ZwK?Up&?V6#`V`D#bs_A(;<9wrhQ*8n zCd^1ax~Kw0eLw}bAU+<%B?P)@u;C!1h(RiVs8tXlBfw^^RF_1uUn-9|;Gl}QgK>Z0 zj=}c+3+aGNZ?}k7ArDYvHP&7^AN#Qk!SUY!SI)KO}3y)x{h{xu7{R zXi~O)WvY{l>)OcmAv)MWV|Nt!rGjStUue3U!;T7!C z@xmc?eZ=|Cv{rDWXmEFXt4MeSS7Bkr0=w9Ac7K&gPjZSK`}LUb({82}d*5dQxv(Wo zjf_XR(k~v4Y20{i-}$Qr^C~JzJWsekv{7QOXxr5#dl` 
zcs;#XTG&s0t85?YC;gfdW-X+nWw<)FeX(Sh z&Dm)UNNh*Kd^h5Fw|Z_aH{;tX%lfZm$LAL4kkVc{ja+ix^_nBeph#gXXGHM+l9it1^I$Ld_4R7!dN@ANUMVr9i3q6fveFaYrVJuw>mSAFf>s#}$O zr}^As1+uQ<`}So-6PdU-rF*Bx3wqH$7W-+|Gb460g$B4aI#3&s^mVV%wF5~nEiufb zx(eFovqQ6ddSW~<;1KEl2v*=3k&=F=J%2ITxO-U8U3IbeLC|>L`b`|}8w>)8FT9od zG;VUOB(I#NwQ@LXB1ZQ=Dh6`$T6{42#4D@%Hc~OSlqEQVFhY`$<0OTg0u+C5zBczB zow@Z{C#Mp`-q zm7rI=!4;a`n^IrC1l@c!zb~PC`{*dgfj^HjD(VHFdfTjt=@)Ye;!V3E-k(Wbg}xkG z=H$Yy%ffTyj%<#?Thy*p-aW$3XdFk|Y+vRoE%}H<=8k!R#o;pDx3ViAn4I1oOL7n# zOXJR;t+t39+-N!}Gn(YRp&;D7Hu$S4{V@Zz`ojRM^Vg*LJtMV`TV2fsq8c$8;zaXV zhD#;A;>1@z%7mq0?GgliB2cc^53^^S7>tUa;wE2@ymQI-3=6Lrt5#j}4W`8P7u79A z#=bQj|Fi`h$eCQDqHdi{Y|Fe`>OO(-P|ej8t4WaJ`aKjsQO=e2CkN=Y^@?@p)ih&2 zN)Y4B$WFIS5^-a(WjkhSqsCB9ngQ8)7=a2Snj^JlL*CvC3Ztry)bfX4Hs^GWjJgp+ zAT%FfGb7#&sJZL{{zXUI9VPnlG7qmKj!6^MBX(?cv3=4t?@N)AgqCe>^XYy}Hx+Am z4u8A5^_jsMKN4br`RPV~p8Sd6$EdK%i5BAxedV6g^7XU$IIH)MWFNASV zO|4_z73qQ^X3+-^9#nmBefrwigF$&C$#qMw;90>w?$-mk1&xt-pT~M{gTG@5ar#(A-LZjPI*6Eg}8@^*5$eFGv%piRAm)SdS)}mji*Z3$0Su))o9EX;h z3X=mHPe!0A84Kot9WZ&h1ELemci%y!kAU_eQF{ahDNr%=mB(Mn`uK#rPwl-U1M}9v^V0ODp1u%D!xby1h|1RN&vT4P1s$Pk=sX?Q zS+5EUM(ZA&vRNTr%3hy>n!oDV;t8|i4Y>xBA2dH*9A^>^HD)EgPJU^0eaRf{X)KH^CyOYqd|x>8 zc}jzz>SS}chO}1s<^**NqXB03BUCu!i%WjRpK&o_XjPt3DqX_bqM`Rv__{&dQG zq2DD*zsqZ0X!9Z6fkWl4`5{VIE0(DgoXj^}XL;Bf*|z(RqDi8eSOb=G39E~jW8G+} zOVsN-^+OiJGtyU{;ulV2b3H%rr;5n9Evz}`0-}BstQv8+s64~{mvC@$IX#})e&RhI=PILaIi114iqmJp9Cu-Wv zoSLwGjqbQcHDnTu{ic(}C10htB5s~R`8t++zH`60gFV7&ndIbMnL+PG->aLO@%S((lJXvJKa$fW*~)4URYU}JL=$)FH^C;%Uxj8al(-4>#}(W|L3yKasV zhHts)zei=(cSqBt<-5}Nzj*Ri1X4fB@KOo z6||`-sQi|i+IgZW+0Cl!(MO|~zor!D=h5V2Te>e+ZrKRwyj_|Ie0S_}d8)xeQfGux ztnbV?jNI%{ZxIBeFe_OAwwJJfQ5jSqL8&=?F)Q&9)Hmk#vakv{Zk>0pLjiNO}nP_Dbv(618+R< z&uCBSay@p}Ym3Q6tf`I5k`(a6by>TdTOXExOaw&hK-7?uHvn>P3~2mg2Tecjdi+U8 zOH0nkCiwX0K67CxFko7DZp3vbZ zu}bnQl_Fab)#Sfp?~waSWBsyYJ|N=b8)VuLlRV~dp2d7GQk0UfArq$d-eT<`t;n^B zHq#$Wf>SdjW$RgOL*dC4W0)ICS-VL?wO{T8SA%pu?(Sks$MvYYo-JiMD zPD-7!vpZiUT;W3*MMrdn48M3?vV{O!3zI)z^xHjc+@GuND=kXOBl7ndbi491(}!2F 
zPXDpEFpRP*ipd}xXF!vVZCT$P(WC3t^-xYMxy+Dc-dYH74*Tme9$>(@ud2$*#q|or z1$|J9!W4RHjU%%TY`W$f!UbT+m zVd?&6I`ojKyGgLL79<}#N)#xC<+p43Oz1Bmc6?|7k@gMOk51n1vkQVLg(`RV+9x*Y z(}i>|H+>o}S_+6UjM(idpZg0!AK&SNzQFT%m|L&|Lpe|U)a9$|<`k9%$w-^zKPTsZ z*I~ZE=a9xzz93BAygS$%=ZgIy{wb~-}j}Q@>0pZ#6K*y z_HKKdSF+I*7E`zwS6xIPre8#h06GbzIq`U-ooK*$Az$_%S+S2sBSN^H#W#(EDdft7 zs!>j?z0K?k!W| z;S&yb{pJha{EimQpFgQ`&AX;=S8O=~p+R<*B7=jI&1oV@{^Y_!lXFI1eNaL*^>+_RV1EuPtv zCDnxNPfx}T%E5EuWG<#Z`pfPN??TYd`T^;9Brnn=EEmUI_qqMg@A&VML|rFgw|WaR@MS#VmTZjks>81Gb0pqXTXg<0A!zx; zjP~Mux@$=to5K2cgU^|Rk;C*p^E9+{rNP4y_YCbH9}B!URpWsQ;c`tiq0MSWq;IMZ9Z4QE#PYo;xV=uwx=Ce!7O4 zi$Mb|np#3l*3l302+fbEQ^6ylxmhTVaZu?6>g1)^z*DQC?++qG?o|~`6#VF6w&S^G zoVSu4(}YQP5t=iZv0Dx2t~Wn%(y293Oa*DfBw-IN z{G;gA8TPLXe^X-M4V5AlK2O1O!xaAX_r*l3?c?>TD_0sOR7XwzeC)eFebY&@eZ+9s z1?_AR-};di?yR>#+=fGDM{k`GbROs$y8*04au5r4BA~BS+7D#0rKOUuM}N&ruYC0( z>-cUx_HgEG`$sP)Tn;XKHd@!qe>che4hmUz;@L+YBuL0cU!Q)CMO8rh^Q?WOPLxhP zaWu4G-6u{5@8Y{QYkv;?YrAj11#hOA5l*{7RZs|hDY-nYFrd3|}Je_nvM3_jp`!jLW z)F9>h19E<}uw;Ew$?cRvabl|Mm9jw;A91rZ9uf?e1?)&D_rNwAcNPIlJdbb&Acuo zi{d2Tq|58@i&>bt$ec0>_1a0VTJ)&EI7q2Vk8aqoyy-d@X9=AY<~!o*U$gfo&U8HB;^3D# z`=e36l;ZRd#+ZfRk;o4TJqNMv^Rrd+D)40jad%_vE0G&BKfe4R@!k!y(H6af?Lls_ zt*nNyv32#DZ?9~%svu32-;{P;X7WGECDb%-Ms2o)yyAmsfw?*aha=m@qS>uR{6`vD zp`zokF&-PC{4Y7PfmX%2fUhaqj)^7PsSv*K7(VS6mUKW)>(56s@{Pa8Vv~Wzy9=W! 
zkikn}OnjN=V4w10=#RpfN4o#c{l@_MU6M9n7QgmXv)9+F%Rp|)&++#B8DBwO$_Qq6 z36}1rbbqnvHSrwBjLOPk0p8hqtl=LY$tO^p%1(UPPn}p3QFN3fFV`#HnMLGX!Z~hu z!47+DW~GQQ=)u@w_uOjsZRBQKvGHIY(GxDjTtYBZ5u^0~tqF3?z3W1x68sp1L<-ayQ_mPk86m(sbgSuogqm-+Tob!gOr5c_ zn^E^rD*1|~jK-6^un{vn+a3pDK|N#bQX30AUvBKDmAb+!0JgQ*&FhU~5T$w5p^`)W zXkzD#kl83GrYLPk-{9@jm~yhb`I}BDRxEWq(?#*$=P6PZuId^Rcgci)qt-OcsKxct z5!Ad*sfStrz6_Y`oU%5yMJ~2a)q)1owvJad4GV)u+R409w{$1BAN_DwnjSp4O~&bT zW?lT5_#3~3WCh80{b3oU1nToUM$^1h@*H8qA9CHpi>(TjGG`tZG+L{zCl^uaF7GHd zlBoj$idj%F9vW428?*IK+5h8K6WQ;K3ko*-5x(#pNJw5l6%>u;YYH}fcJ#N$+)NO(m>){ohs|- z-w*USf16e233htuZQOL48r9VosH%I{f`k85kEldtaG~`NU_g-xQqHrFhsmg4kTD`G z6K~*qUG-buRSJ%`*We^#9-_`DiX5(uv!>1vH#L%x&CP8)-YdfEzNeh-v-1;qA-6nK zH}AbZ7&Buif4q4UM=ptJ#82hYJmY_s$KTEy(_aGSrsB`mz1|^bb`Xt;V_9bgkBE%9YG7{)CqGm@^T6$6Z{a;FDd7g@%!)@K z$kGjGo{vCzu9pIDdh)1B=6b@8T?kt8@(ur4Le#XBrJi z)I$DIxOZO*#lXQVIKy+%XxdLHS$cGf4bPL#FGHKJM4y|ebMJTUcM zv0H)4rC|PKdzyvRER*(Sx5`J3$9htP)+^!-d6t77>7znr8GGqJce%vELPO)y(>FW< zL7;+(S@|fRoTq~x@9Tm5g4)Xd^fp<>7XO+Ly}X~?r$%Pqs0N0lW2rSGJXEeG{iUPFKW}km-=b5>y7g}R=en6FV2N(NFrV)p=?4IJ{uuyY|ewxQl zap#{Dr!q@$*CR8bDkl0Qv2k&6iLZ&pfof$7DA>Z2FGWNP-6$CDD@t28kr^_|>%E+y znYhY=SCqf|O!FeUaCY!m5_4N;cy(t-H%LWIFFPk^<70cfUeL;GtgpA${R*#W7(eCm z@v*zDK}=|C9ktn_Ort*<%F(Ck(aN4I7van$j>j_igU(Xr2^G<>8V;C=&l3LCIG3&)tJ9BXy} zEc5A4{|yCYo0_}bA#vV$WFVY5Hvi~vYOJ!$htNdTPUqns66-c ze$AY#r-*-*;fSWkx)Ej0!orI5ycaWJrEILGu;g34!kU99c|G-83F1zP3ES2qj`C3H zh$O<;#XX-jeI|*C_A0=Q=HlY~QSBqR*csFDsD?syH!QDn=cGOkpV2=KqBjH?Joe|> z^C30S*Ku)kW4Fl|rN4oH@Cb6{@y~vuz>BbfcfM!#3FeNN`Tn^H%5x=~mkeyn zmM4l4q+Pb*=!y7GGdwM9R8JsQo6^G|@eudOCA&jM}M7TCB{Tm7Hv%9~KSj`uS;v z9+H@bPfevv`YOlDYh7NODaDh`R~N3*nr;Y@6kJFf!MWe+=w0M4ub%8y(9Nlp$vLzq zl)+9({`;n`2LX;T{_|k%h{XP!n+DHJ?ZE4v(=b5+L)`cQ6DA~6d^Js-oh|-&WEtQ$ zOE6HiTy$|5@P)`B4FsCvPtkOOc#YHFfi4tqZl=9>L%d<(!QqB_)Y#qrVS4YIU5+Bj zA(=xOO#AK;*NnVOCCzp9I9n zv%xj6B_)~p7JlD1eq0|0GD|IeeKr^JP*(!=Q0d({G}rKZpA|t3DLsP(ygMJ-zxHKv zA1VA z#YINFVedY*9376*Hen*(efUddY((ez%mgQuL1FM-yiEmb`T56em(EO1Ir^YTqIc$j 
z$aK|jHDhL~06T|WW(ll-poWqCPc3KBPUW8yLFUH~v|?MSae!c8_XF%seDY*(RoTbq z4#4*I_6gjWAyqM=NPS3>JHE1S`n%XEJ&)-Ai-hDje?~FSV)C~qHxoW+^$N$9>k8_8 zjMs}GbW}Kcq)GmFvm6gSF6eqg54W&dh=V8W>|_L^V$#W1Irqlhxbosa!g3)sspzp= z7tJO((8J@)ln$$DW?32zp0gIM~v_z7U+ zyMQL;1GD z$$!~g8O!Tqxk6>Q-Qu2UOuHxr3WuD?xyvkz5bZRNeK~9UXZ(qfih$5ylKipurS;!8 zY~QH=^=7YCDw}7Z#|yzPo_{94PY_)#BrG5(yt5-z`vkuRCr|zNck!=Z#i!#vC@4LQ zEKK`|&W%V$GBhga)2~sxX6fpbQmY78stkXqQW4(a6x!j`-4bxbptl|PLs2kvlYH-X z!BU^60|w`5Ka$(VK%YMge%T#xRgDp;VSqCNL054bRe$+#_tQI|O>KP>>>Hj$beq8; z#!Cj1z?`c2;ua>x_X6vT=OLL_XRXX-RVCuElLu5AA;Y4=oQ=PG27LNvTF<64lC?+6+TK7hP)9iM}+} zMgF;ZByoUz{jREMaYqMUY_adWfh#ANpvP`E)j~(DyQgnJdW>u%IrsDw!R!@(YP5q% zJy}i9*L5_cI`{&QdRF=vNbw;^G;!;o=Fz^^45C$DnBe1S?8+IMza#iI^>|QW!#6rB zgsnt*vEq7kJ5>RtuoAoNoA`)%@jTmibetRobM`4pMkz0=g`%@qh;VV;;FrcEr^vF+ zeiNT|gMBsB_rh(Msb2rR`F*`G$Ktu$s&Dm*UK`{6N_lHim1a`)d1egnoupmi@NWV+ z@igZ_jKA`ugJVS}ya{}#$CtrOE(1a=spl~H%WC_bOgRky1V#*KHTpi>75!q`QVlcj z0JW*}IRFMG{BRb5He%Cr7g^c=)!CVcL*0e{pJXjdwj@g`DN9<0WH;d{q6vc`TZ?_k zmhAhIl9b54WEnGKDf=?kN@R`6W6hGKu|^2tcYmJm@2}r~zwdKhuIIWmW7}y466Df zLTcEryq}%j&nJ6_GMU;@P!OwiD>D0J%xJ9)M4s5=KYZwS^0h3LStYlX4>^(ynb^a| z7*cbyUL5E;>N$XQfj&a|M6*Z7Dn9L%aijl; z$}!h`f6L2t{@;9bf17mwOKBXjpJ_o)-#}=DvcH^b$8ogy)0>MJ9P_bO`;h)8T!ah* z=kt!chUH4vwFu05X|;GK&DFK6D!GuqeiV1DGRc% zjs)dE>TT9)Vj|jIuM7R(eY3LC?4V8zZz79Ym!4KYaKAkVzz2weLh-1HY0r(fJ4X=7JbZGJR`8!I2@;PV~o{AwhF_ z{~ZE*1Vun^6U*`{F^9G;Ne7}p{nt9anZ z#+g0FBLSt!CoUSZ-oolqE;U^?*czrO+rF1cojI?54+U362oG&-Htg$w?pCICSP^-{ z#Aj1BIY+4yfP{b|ajbFyMMKbnq{V0#A!@}G_#ddg#+bWl@LRIkcTJZJXt zcFlGNb}%xBC_nW|_&O9(;YBS;mu1Lrzkbd9)>1~Ju1U~LZE7mwfH7Fm?7IgjSRebcpdT3U`HS;_92>iEI4 zzJ;+OSzlk_rTS1&Y?ym5$V;3hNFT>2gO7_ zcysmB*NielkSRR!j)=Wdd`1M6t>rABS5)17Lf+O%j_GAVe_by%kn#92#m`KCsNx>i+EGQjcZO z->t7uep1pxiWJ9HI$ut-P)=a-oyrU7NJ!R9)ASJI=KjWNm`+Q~J`f$<|GE8=@^kof z>0s;YY}gik@7MOf?YIdBbeQWBdM&T8pNxR4yZ<*Ne@Xjb1Yp4fddChD+dAUv>Bu*4)Mc$Fl5+`!Er^#ga zUZQ;$3=xp8H2o^z>j?yL=h{VC!=gJ+iso&TKS9fc78H@ZGnTvqJu`~Hiu$`f(BQnZ 
z#)G~yj@Cc%TkU!}^H%L0jp^f=f4T417i&JeVsv}!J`;6h$Ywz$WpTvQkS%i4a{3`2 z`I*pp_>iuty{M+Ra{a_mT#`%_yWYL3R^@!Qd@_vrl08kb8R1h!>3XbLEv@a*SRB!^ zTI|8+;(bYJ@^4OIzZKfz`fYI|Gf!+9y+0itUG+r7abvPiIjmFlr@c3+8>#*me*3BO zyqfi6%Wnvt^pLADj#+sM%l3frqKt(_Dl<+vg=y?!>Tj+yH8;b@KYa7KvUkz6pYu{2 zYd?;M*N5E@aoofzV^$YWLs`?7CUI2GmXFJmPZ#YtxbUA|e6HE?5bToep((?&xP9sW zBLc_iQ%XKqicr?kmA*_EK@fp*g)|?6=+(_oHggFCwRXZlDi@SzS-)4y&%oct;lH(i z!7+L(G`=&9ITaZ8S=->Y_k$w>+ZHrjhU7!7>~;cy!rL98f6f*DwN8$+4k>vk^YSoD zp@*QBHb0x;$B*K7w)AWsaFZ4ut?zcaY0uSxKV!k%%S8N{5t&dI4&IDv#!JC;*p($;8pNSY&hg-=fApZ+i^=xj$$TEjgk^~JrORw-`hIQFgDEPq@o@5+=?eQ3egh zekU4}r7+yE0ftT@OviF7*4_dMO4fL$^Zk~P1Bt`uUSr;oaK_0Vs|Nwi_xIa4i)yHZ zGa0GyP>C;`;Yt%PVD$0F5-txAGQ)WhDWZ4i@k6Ihm@EA}%E!(d`p*~X)7F<^GqqCp zLR;N#xy7iJhTY+G3{JPQe=KiA*P|yUa1Pu|=a1W&LimK@HX`1)8YxF`kf67Nql-(x z+Hq*6YXuRSDggJ<{`u7^jF5>z_Wh)WhQ?p2QGD_)KowML0oUp_<5!f2nV76>0%F@Ad5lo= z5TS}}QBzkxxV5zv7J0lwLsm^ggC2x3XCP(vN&L`l=H5JRn`SSs1G=GUXR7YwbYk=$ z#Iay<2QdkS%HM=uvyS+w0C4S5BXe)jrlzBhZwGN? zce3aj-0Bv5Os`tJq`}DVKSPl^Tp3oit(nUwOnm?EDU5d*DcWy6-eST^=GmKT)W`yA z=5+GAnjQyVlI$G!+lV-V_*vV`jMp}4Sj~uEt|6h_Washly?qUBJ3Q@NW&T8g8MRy9 zgxDpk#$}y4q>$vwIoWLZlDSo-=%iCC^UKIbXRHX-gBf{5@gin@&C5nS-m}UPSO1m6 zP}1MM4D2|br$F%uqBMzjS9MYEI95c2Vr`Lceb6itH=HUUiRxA8ZyRpW>D=?$tLuQB zNZpq=o@sQ+{k_e)Y>k<}jG@EY(Z)BSBBybzr?8?$s;}?)Nt8^ISS&OO?D+y>PfV%u z%+<))K6zo8n{)zoGe$l1;m8|#;>aENzMX^wf^Nxs?KG^u9skI1XWJKrL)De07fdAb z^&3L_Ab#Qv>2_d{y|CiCsndriksQCa!6W=*%@*=v$)2x3d|ik@2*omU=N4Ff;Qn`| zP?dO4g3>L1J?63WE(um5fe=z^Ls{AI=O>g#M`21$vg@rLr70Y`a-k!@(;*`|u<0lC zvrKt4>uDf(+v517Nq+G?6-)8!0UDrT3-0t@58WzVEpTS^rj$0Vp5@l1u zvU1k4a!AxTxF7d~AG|H#<6c33MG0_w3MP)gf#oh+(0GwOpKF5J&l|C-SmGGrZ#;XG zc%v_#KFfMde4A+gYO0z1R9)GlYeBADv$pWezt) zcFGx4JCGdXE?t}Cn!jn$$x%&XF>FPJuGVL@<^&hh-6x2Ae9+_EBuLV_)PErxH!sUjxaMIg@Nc%LaW$^6r{ER8 zW-8WxVl?6#D~V(fzCceuHctu(fgMHd=;O547EN|=8Im-wLqxYP+V`KW@mhX9ICV^0 zr_Z?Vu`7g7IM7iQt@p9o{Rc()mc0yBaYdVAbYI*cm8(>)btxt`3ny}+A|TUke~_0o zS#gp*sBOVM(30`(EVGB%=#m3|5sf<*ACGq}l9zC45+{WY;GmcSfz6wFje>Kd2MyW# 
zAbhiKeyZsPlgN3;GZF_qZ?{C1a<;q2WU!F9NY9B@>ZPixq)pRcT22zdE?HcNoi+R* z?9{V!Hq+W^uiFf`NDCV?1*PrXk&}~5t83)$sx)aZ6CimxPc9nOa9GSFlQ1`m-s4KM_>#> z0e`F{dAE+A3ggIL?Cl_wC|=;@GYD)j&-KB2J5-L>^B5<*)!d6^`2In=#I4(s{sY1E zCAt^I#Ed=F^yM1;l02~=WxRtUpQ5qy ze34@(UsHq)EHl1S`O732%oX%zDS<%i+UZr~m!GS7o@Zx(#mX!$z_N%k%&xY6b>Qxa zyOz7VHC*l8s5*|)C8{U(5$CEGzdsibsjl3VDGIw=c3%C+BTt3%yL>P%?Z;}0Tw7v# z3LO-Pt|kDDanFcEp0midXfP*8w=rmY$5YH8)hTln{hX-aTg3m|O!K??I$`yjvq%X< z{+MdTUtdb`BoI(=8FDr&-`F%51>2^K%!d7MqY5iVqk3Oi0OG3r|)?mo)zkx=mQ z&6iuj_bPvFzVjp`<*qtba_|uozS(u3T(~fwrS+|FxMMmoRYrzz77+(CHiEB zinB7!I|JW~=P_}!U{rRpa@r2`M5RyPTX`RF;O%+dr3|Z;L`MaUkNcfGX|i+E@31cP zCCOm)7Z&&a-pP@>1MjmEn^&Om!YPhptavZLEHc_}d6Ab8N;xLOvcVrH?2wa)nI=?T zVysx$adh!x(FNFtY99*VM11bN+Z4oTpGaE%xQ#wvZ>_aC(&&OexZQQ)xME3u7v4 zz_Y_aS{NIaI`}EL`iT==^ktFxdt6Nl5w8N8)e{Ri79JFm*P8OjJwO2T%AabO8j+!+ zI&=y!ib63U&AM=La ziu}w%AJ_A+HoP5;jl{;W(=$GcwPL*iiiJckH`fz)O4^g-B#n0hbb&>|>qv~U0S#BH0;kJ%sc2Rd2$JxxzGQfAEm z4n1T@t27N<`lJJX9SZpM)Pk?xaKO_XC>_?wMS%l-=ka9mBYH7uc?ofb z!2sS!X{d6m?!Jb+Dmy(>eeO}lD&ZM)@M>)@McL_`*jAJnh?uGF2Ou|Sk z1%9&<_y7Llb}2(X05xT1Cy>{P48@<3K5JxX_zCXdh?074ZY~65ktobM06Rulz(AOX^_HkR8Cuy9fThd zsi^^T-x$!D1r^UWkWq^4`-zCT=_vU0Q+dZA z$nP@@eBdsi7k+_>C00t^sX`J-Pew*YG}BkNraa`^+S@^Tx`Wb8-jWkwg9!(!)Con! 
zI$W$&n60u~Axp1v+3#|~ygcNY0jx#;s$fF=OomM=q92Z^oW6}!HW9N83lEP=O2W;S zBf^@ePv^wt`v-tF04JAeg^e#LHvhfAn=Sal0UNcEC0a~z??hE^38W%jeemb)$kPF5 z=@|46&>#&2n)j|wv5zC?OIljmLVs~|6vO9oarK>^gd%Ll(-x)IiHXP)t$o#& zc7=kDQrq5>{~Zd_5i3w1p75f(nbP^`a<^g&bkhtr1gaqUTZr0sdOQEmSBUOzD#*z} z5Th^RKwc+o33I8`W{4y@C_vNf4^7Osper&o+OXJ|cguyk; z!XF@HT!^~Rz!upFS=bU|$-{gXg1}e`KkC+q3IYTLGFTUZmLth~=%|@T+{8SL$W72Y z1<=6#mGLIOzrR{NHv-`uM|_vbNPGDBxFv#pfCP!g?b~q!%V|kTH*9UsCg-)S%AiwQ z8lb7AsCG#nWb1*rWA*uk2=tKWn9qK9HzTGO0S5%7a}Pwd);FZ!EC|sjknl(A>2ZP# z9zq-M+^!Fsn}b{Xjhi>2Fdo@{<#+t5%{PIL3<@zikVOP#{J#q}Dr?&=GvHOcBG?4A zEGZqhKW_Jq`;T|B&vOH2szxTH?_Isg~=g@8wqj0|z@6&p?ol*xn(iDt@ zLA-De?BT}u606WXC>WRr9e`~J?2`kQp`U*at{n22vokXlpLe*j-}1@W#|M0os+@zP zPe<@t69#=LUdjur8a^KAqj&15*HqdMaISY*B9ix*oxKL3UwVH1STh z5f)3?QUoh6zu(lSzE7Pg6E;o?jo7; zN*9pq^f54quLsN$M0x)-&haN*bH_X_ZlE}0Vt&}+=$L!C z$t@k7Y2fdHp3B?+Nty#k4Lk$`(o{z-7KHm14T$e}Ru~x=tib22LB^9qP;fN9@91Pm zFH~gs41?uGw@hT>@tj%tY8Cz z(%1a1M;65cAnSYqkZK^^Pq1(YVFfNh!A@l2sZQF7=Ye9M1pETjj_=A@sVEEUs3AWp>D-1YayeHBkZr0YK}=AvwQR^JvxEoW|8|qTtApmP3T5 z02DB2N*YM;)kk}9KuYUAs3ZJ^?0sQn_rQF4Sy_KcqQ0ss4a^pC@9(fn1471A1osRK zzPqJ^pkXfww&=9FvrYnCNI08idmXE^8pwco3#jx*@>daFIeuhwU+`fdFA1XAB4^JsVH*Ab zJuw7!U~2J~t}aBl6xmsj`#<QbR>OW{6hQ>RBMMOo1?{-B(xZuGA+jfHr3oJ`JrpfM$0GO9MRx zY6q`Enl3m@98BDmQ4cO73x64el-uK(25JH18?!xlQ6qJoZEzG~%F6XVwF*-IXlPKp z$wl*sNB{eMe^(f;#yQa9LfNTs9Bc+!!(fdk)OUFYX-(6xMTWuv0IA$`b&;ueX?lA4 zBGAsV-@e@%4`nCJbNiDGK|J=h5d6T4h{e}1l?PW>SD_SQuBHplC2y{}pSh*7>#^NUEYjZO?qj&(S)k>h03Ca#y zRQ7(cs}|<^0x(w}gTd^2!VUm%@5?||r#gY_1^Q$>xD(XRUFi2bRlVF$v)oz(UO_me zQxk3!X^TK|k^>5*GNN(|(2%0iWgsvZ7#LgtLM2!qNcsGjRoTBvqvZ2@9IVT!$?K&0 zj**^TGvMdjy1QoqT2iQAw>V$!Gu}uGwh&Hr=-H)1IL~K*M+awUH+T0oXw`+i&hzKh{o6bw6wKTeh1wDR1gOP2Zg*|o@+{#K(cr^_#n}$bdXR|XNg8k&$y0jwP8?3=cC~o|FPsc}v zGgT4VCQ5zQ?NxQ89%0DuFP`Ju3G-a}Ow62bvXHuFyfqzO`K z4c^e}MQUmXqW3^9hCnT(*g1rSEo`(n$Dw_j2524MR$LuK!Y=u4I4T2hmRHIotYrf4 z9>WTj0z{vTA_@fc_=FnZZ(-f8G}6_*LUDPVgTv(iBLr8kNgQtieTYY=r5ru99jm_ z^}i%299;ig=iP#g9|k1IgYxVO{rS;RkF?+eCPM(}-7}PbBFwh+hhZ8=_UM<-pDn>b 
zs;jGW2nltS+4TW-K5D|@IKg)tVnML7E&=%)HlIsgOLxR2B+f}mO~dcSL_~sNvLH63 z2j^`sp|K>6?83NXkkzYANK9t*Q1LD2m-yva%?2z!1fOVpUE6_7R_53qPTz3OD{6f-?&08YpPjf;9b#Vw06@bGM z@W7doYYPn|erAp~{Gkv0(<1Dr+yyd?kN0;sw|?FVWN8F{$*1fmhd@4&^FG|3ii={R zpFaoZf0blf`COv7?)}#Hct}9Jw0MoYIq20OTLhrwNu|3$On2;e<>cffTzfTKF48r99=)QL@PFWSeQYe;W1NKHX@UC>LF-1{V$dcSA#0y$Ed?@_zu8nMZ~I literal 0 HcmV?d00001 diff --git a/tests/triton_tests/plot2.py b/tests/triton_tests/plot2.py new file mode 100644 index 0000000..d433548 --- /dev/null +++ b/tests/triton_tests/plot2.py @@ -0,0 +1,69 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +import os + +import matplotlib.gridspec as gridspec + +cmap=plt.get_cmap('cool') + +if __name__ == '__main__': + + fig = plt.figure(tight_layout=True, figsize=(6,3.5)) + gs = gridspec.GridSpec(1, 1) + + + rdf = pd.read_json('tests/triton_tests/info.jsonl', lines=True) + + ax = fig.add_subplot(gs[0, 0]) + + # now plot the % speedup for different batch sizes + for j, batch_size in enumerate([2**14, 2**15, 2**16, 2**17]): + all_xs, all_ys = [], [] + for k, marker, ls, color, name in [ + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ]: + + xs, ys = [], [] + df = rdf[rdf.batch_size == batch_size] + for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]: + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + all_xs.append(xs) + all_ys.append(ys) + + color = cmap(j * 0.25) + real_ys = [100 * all_ys[1][i] / all_ys[0][i] for i in range(len(all_ys[0]))] 
+ markers = ['^', 'v', 'P', 'o'] + ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) + + ax.legend() + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% time occupied by quantize ops', fontsize=12) + + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048, 4096]) + ax.set_xticklabels([1024, 2048, 4096]) + ax.set_xticks([], minor=True) + + #ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) + + + + plt.savefig('tests/triton_tests/plot2.pdf', bbox_inches='tight') + diff --git a/tests/triton_tests/plot3.pdf b/tests/triton_tests/plot3.pdf new file mode 100644 index 0000000000000000000000000000000000000000..19e93a24eb4a38dcc82cce0729c3e8995a096054 GIT binary patch literal 20122 zcmeHvbzBtP`=~UmgtVj!hzJ7P%Pys~N=Pc*DGf^q2GU5kASkT}5`suchct+a3JOXp zselL~+!@gCYw)|T_kQkw_sf0`bLN>d=R7s%i8*KZ)#VjM5Ta-Zf8hwUpc;aJ!lBOQ z=OL1kP?+`w7b_@C&WvE@;A{hhshinYxkHga15K#3G{nlu5^RY3=>|n-Cju0;(*UM> z#{9gM1pz9yd#d0~P|_lp5v-u-U4Xh7fneq41jT@-5SW&&nWdeR4HWzRs++TgmK6bN z01PXy2%uu+O@P8qIRYriem}~6Kc0db{e%wUPXa)^LG0bE0PJ?rhiO{5JA1fU0Q7+P z1LLb&S=yP&I(q{n!hs(SgBFG35MppV8ji(7kw}~<4v)rSP{3d?Ie<){FLx+@7g-r6 zCuiUc0p6$y`G+;&75(jRRvywuj!wsn>f3L5J-9nM-Z!K_ z)c0}PDA&b9M?op_n)sc-uJl^;5iBCsi^Vh1WXo8mUT6Bufhm@UQpG~|RnbLJd5BK3 zFNQj8>x#KCSu9iOXpU@-+FS@l`q#vwZ z%?KnM{?w7fJn}62Z4R@oM_6IkRff;;F`e&L)hA1nN}gPM7|7&ZHOyQ%GHw<`W>rXs z?JY^2yKTllPNQp8`neSOIvsZver__s?Qz3{o~4id#3wZ3v)FXA*U^RYxi5~}t;L&?V;|IU=`%l)kU)h&*WYCmuzc>3~gMcp5Ok&>%nm`gW z8r~D}X*yBSYMIsq-j;{2R)_C2ebJxPxMIFh7sV>Yd=Vc}6dy83>3SWWu(}Srd)X6J z?MvMoGms#ek{ZGu>l~{)Fz3rLic{8LTlA+eE9J4hJE5z=Jc3}n9fL|rBd)&|pcFE& zuDj%!8734s*1AKQDcK*wxjQiTWs>>&3@r_pv#C8P zP}8Xlm(81e+AtL+aQ5fnvQGcy!3lJ2kML42x__@SG1gj{ zP5HT9Ghc6hn?vElis9Z?&CmH7jVYdCUAYmw0^%sAVYZ&)GDOk9YZ?tfX;Q6_RHMw{ 
zkH!)646{{ZX=OB?X%N3*t)sZUN^)QQ+cPwjBHq)6x5?|X*7o&6rdc&FExk%hP^!{r zUcHq6l|(mODyN{8n_&T0M~}gWA1o-XDXY@bc|EYH2Fc-JZy;5H>UOdxcBPtZtru^+ zU$9<2E@;0Ioa*pAzfSL*9tXRE4;B+bs5ohH}l3{^m zHH?>O9W^5GwWG2thImc*o%BLm)EXmYRz_i1&YAaDq|u&4KJfd&x6H&|o@rAHW_n>o zzg9YUxz>beBJ}mKpjr`^lGe3`F}BARcT_lzQk;2lh77hJe>GXjkf_h}y+vsma{2P5 zWuuvH-PZJsnlI3U#;_9t4_-7Srm7OtCkz8z2uja@=JFXNSGQk1icG z8+7<8^5Vdy;S;%cGUGC>cvcuIT9}}VddI_>nKIl$nl{G2qQ;ke5|7hrP>k-k)<3)M zC;DabB|r*(9TpZ71Ctt_2U1+v4PTj z9Y!8`nBzvVLT?S6@5ZYNxuMEfOUAdi9lyf;?dLN_9*CVU9F3ms>;58&wjDvf#)Zf$7X^ZXO%< z9#_d&_owy4#8fPkaiz-R3}kx|&NuG7O@B~ECUQec32SFf8ZFmyHP5GH6WMI@LPj=} zdnv`=T@8Ng%C%97RI`GmV^$I@(QDT&OMqu}Is8&zUIxO0FTH-;U)q0uf}JGsTF{Mn zno=u^Pxl@tS{f|87`u!z=s$R)*^_z1K_P&O5`ih4u^2|9nnE;uZ0R1cqE} zilufV#<)tGnf2yj+KxVSg+y1bt;KQEGsL-k4vL508VSa6x;>(~DpKWMYrs0q(;xn# z-Fz*@D2Y6N-ExIC#83$)`^gJ}4yj40@UHds>%PG8q^qs_GkvyLu4Edp z7^XwnBf@pq7lJRk3+gWfrLX8TH8)MULKr(qQ2~zkZ$WNE;2M)!4X)DIWUMus&yP)* zH1+_}^O@1Li#Mur@?()ZBr!xk#Qe}7x9O3=s-cxlD=ILwCsovyW(mM#>->Q8g=? zl(?^|%}rVu3M!P`#&HqGw)Bry_Lb+`e?gH4HEIcdnH#5ba+h}`>Vt}n-^WSG;ctiC zU_JO~scs0pFYYLZ`lqWR4d;h39sN_R1syz)HercM)C9skULZt6vM6KNxmVt8NxXd1 zu@$)RG3U_}y@2J%{r_ngms<)nLKzf#Xy*|^2%5kPgu-9c9*8E%R zjKk?P#~-dL-lD$$s()anCD+1N*5I=5B7+Bc5uzj|zvUYXc4_jzBnNG{j}}$H zWKo3Utdm_(5;d;CQKZQ|@=f(ML_|d-=LcJb{O0XB(pk=$X#zWL+lhZ-^G6OD6EcD(&MEwowe8)_g}C6TT^>VP$oHgtl0 zc(`3WN}c;qn@$Av`ds>GWzUf zBFpSUj~`;m#F@)R5CJ#${W zA&s@)P`LqDQ=QX<1R3@eKO^?IB}wLD`GtAY@^$0s1F>h>r1V}s_r0$x&_aLY9(lbg zfA>Dv+pzPLuWk_CY;>i1z;lz-V(}O`hE2?kJpRJDkG`ydr+KK#rs!IgMz_qGa5NF) zQB!v85_a52gt;;AG@CG^_im^AY&$Zdl}vJ8N@-;9m;}EyKU#5{n%8l(dz7A~XI>3{ zb>wd|U zi=QpNEQOX*9ZB}OaK_Hj$9Z(b$t!tiIOu)iC~GRI0vWZE@J%q1(oaKI=id14Oy8)kEY0fj(Eis}P2YF7lFt}+(^A<)I z2Mn!pj+ltO-(+VyL~u~|QI0Ff&~C1CD`${M`d}0=7m#=3$p_S6(VSoI$Kb-JIcit9 zR_pnMn{wWqRG{EtsUPg(x-Xo=`EpjEU&tJV!JLDumz3{g$jeiNOerQo1cMDvNQWkNS=rjrT2K1v zv9z5aKgpJ6wlAkK)!o#`trA)JWM3Eaf$l$%K!L&Z?}`Wi3!z4Jfr=-$sM1LJR*U$n zwhmp#)==<8pGG=^$CXqp)Q7TP1tpR=`L2&3T!(&0yB&+Ya=Jvqw=S;(o@Lue1P>fM#XogbDt)rZFTDT8;%Zb}`{a_@H! 
zjosukakDq;Pnlw66)9l5XQeVQXLi>?g!)6Ij92ICQ=g(`IQ`->$*7F9)1J_f~;$?rTwR z+My#i4L;K!ukFv^eu@#L5+j_;wh}~J-2HmxiSK3va_qc{k%MM|?;%NJd7YUV3Z*b! zi)7a6aT%)jr<-i`LleCeNMsg1m3R75EeyXo|M)&d{`Fo7LF% zCbM7I%l3UKImW+o@7C+Hsg?1M>j?|%Cv2Nv(s8)V$sua79xo_#rdM{V(*&Nn@%-;r{6;botuOj$o}z~b zTrBLc?Iy$Y1y+4Cv!G`(A|?P;r9in?nfl{RlfsB%%qR~xvb$B9>bp0H zL+`8K!yB?0+U;}AVzAi*7ceLM8!kxHUl=6J8GIZm4|UgtAZfVH!Qm72w{KTqU#wrL zj!RLCS>ov!WaLWmPT}J#v2GvMw}e`cX{of zZO_C$U5Z8Qmuq^)2VOL(JU_=`BRWiGVVBl|?{JDRVl0jHxOZDiR|%d;9Z0BXPJp^= z=d5%@A*W_=u>v>47e5aRd*UzC&hE!X@-R<-5%TfJ24u}$SR`h5li4bRMhMO1?Vf;{a#Ay10p2>xpm zB+#A1YVPbwfv~&0Q6ErIf;WYHZ+Wu_&p*Z*2*gQEK;LR=hIQoLO6WZ9*qHzkw;muB zW~dsoWK_j<)tu6dU`s%L;Lh{py;LaqI9q(iQnj~`FXiRnncOHC+1XmpvE~Umu)4$dGJvg>BSwd!1q{M4xvL z4({PnqVSl%;Gm|mQ|k%J;b|trj3-Ms`YvedysrJ~Y%iyoFH{?bXoFwsjH#MbHezrM z>Ud~eWFVP$QPe;`v*z-_``{P@^cB=_hO=K4l~v>zqT>wG@Tx}PgKOFEuZ8&*B+YRdnck+ zzD4?s>c-I|6_K-TPFidxn+Xw5FA~Uv^3k6XKK1LA_3zgWyTI@nzj^Y@^m7NEU~#AC zpTZ^&p<0m)y{5`|)6C>?rLkc>5Begms>)u%)kBHTT~x|vFF#&Avqc!z@A0U0qs%%n zsVcKnLtCU++R}oQPLCAy$r3JMPpwr%HJTluf4kjvU z{aNh?iWnOGA6vWR67AEdV`KRfBlI3F?q5Bpsv4)RkpM<%nt9^{k8ynMZMhV=>FuFT z+8{hBRdHN=(skYabcn!1!ct#idECXmaa_)NS)=C@*YR@u`ZTYW#uv>yqUN$jp&Wgg z^h;8i_6&NZk3Qmsvs2vbC=LZxlAx%gGS$UD4#15J z3){3dM`G?=7_dlR+i&rePjU||_wbq#e-of8#I(DT0bcXsU1EN?A7r(#eVJ}{t2lUN zP{!w=EVcQWrV_(*u4aVh=!YUGnk!0d3vC-ni`r+pP=Of|>e#S?ft(G`R<^q<6VJ?x zva<|f;=ByqapMIi_%4dyZ;!92Kz$<^ADZJ?WZV`KnszcH61}@{VqNpl$EVRRH9n+f zMhhLY=RW`HT%ZyC>0xQn+eo$7k4W_2K1wW*dbN0*+f)409>nh9ri%TAqlH2~1Rk>z zG!a8E%Mr+wHTh2PSeGAGB`eB$lDegnsEpdLoIlM!)W+{R)bT(oltsQYC5%m;D~%|& zYwR@1KK@~rjy>?%!}G(T{$eKY#xwZADarb8YHjibKIUC>0?c&tKnjMy$5Q0 z_*m$_2&VRARnmd8qO2;y6y3xFNiGoHn~jO&`>Jec1ojwqFLw&_7rvcBtg;&^lu~}} zrazzG=7ePYX)8q4HkPK>66=|K=Ud=uzB~ruX{HPCOC6tjEGyo7;m=BNkTEG7z22cS zQ%YPQq1BRNA9jA$!fK*&P=dUbqt^(ZZ|RyH;i2szIyNfwgC+VNuoZDGzVo{N+_T-p zJdW9&D11cl&a}O#{t%n^n|M_FMBdln-YV1NngEr>_2bX`HnvE(1+6#sVCB!A*6wPF zAIlzKS{~+vo%Iun&TfuSnC$M_0uclpiGYfru_!1)44A4ANYu`fg75hQ6dv-gskLz= 
zXp68IoyU_%F)VSJTZ1s4J0Eoa;T;16BvhD}As6?+bq}8g0if}(<0-^QFp%(2%lWEh z@_peWk?ylb3S!rsEnYZiHVo=Q18NxkpT6=BoAD2+;D1fFc#qV^V)H%L3Bxc!>fv`b zXWTYH^HNo-T_WM5E!O?m{(FpE*0eqe*=!jH#F^jBXHB@}E%dI-xvZM3)|_u-botnvj+vdFau?OZ^cw3ER3lpWR|QhX4?)B~??LMx zwjPQ23l*1AuZkoEQi4W7?79s;^Z7O`;mUd2D<{`K^A4~mXbNFc;sjwH3H*9mZ^SPq zwa{J&o^-Or@(U@z;!F6~`XL`Lg-tQ7y?&NdCWlV8#b4f!61vO5HK>eDbrE8$&KPM% zZO}LpNX~|Eml{C(b4-USplG7jgEZvH686dOl7h1u5q z_Q{>5SSwl}*2%f^YJ(HqQZX?yQZX(!Mc_s7B3$zRw<~g;RB_3+%I`8iB^V@iTCMTnoja;0 z}4}9GTHt?)O5aYpkB83()f!( z!K?43kd4-5kMQsUsb1O`waT{Q4}%hkhe^V!W0qxeA+y|1`Lm*r4PVzxR}+~$n0#x- z_hQsOsn$1)%*nl{SxaX{G(&9kN_-rB+COSv)9BHN)eZ8yo5o^$@NEx216aNE?@u^J zQ~5q0f^v9Gu(ClrB^n`pCeQeh!qnDqY>97d2I*l^>pmX2sf6PnG<`cE@%0}T%Ha=LtfIZQ?GR?M~M=X(*Y3*fD zqxAUvDa*$h=3J*EJ`+Q=Cum;=#FpK>mnZQFznRcCnbW6t?|JEJm~e~?5m$_>q{daz z{F6hAoM{~yu^4!K)1g@ZTVblI4k+E+t4jXLLA`~Nh!!Z-`6qB){*eMrPxZ6I^S~M#M$$;X5zx=1)PEPiqyMDOFeySh&wI? zC$dF-X6w8Tg~(lzFRDDh|6{{Eb6?MeAjXtPMy`R68>H7{$(i4VjMv~H3Bmav$-VhM zn@!}E$8O9(CMSCIvK2NeBj2KlOwU=``A@d9kl7~qqqm}(bT&6B=-(H&?IGShJS!yn zFD7tRjhcHv@^w}!hy@4aUh6K1dV9B8ZDVuNNb}|8m`-&xYgJoznq4}rcrr?q&XHZ& zp;}T*Rx9M1f4+U#hlGNZylyRF)&Au-dPP$Q$9#3fo{p}B8WYEA`0JQmrHzrOH2w0@ z{)mIdeqKbdI?B(k3F2PKfFNaycwlodL5N!IS)b6kudv%zFN8|G;px1`@2A?{HmTmZp7>hnmJXmAug9dE|9yDF)V}%iY#Qc^#yp-)Tm9xSTvcRy?{Filoror87kRe?AGL4GA z*z6vJ?%`9SP=65;iD)XjwL+5{7ReU>#e$FA-0Owi9c2`4u$NwR z$38y|wBi|WWqW1D+-}A^p8s7Tz+DdBWD<8FM!MBk=LoV$;r<8#pqdd(VQrl2ZRulw7hgJ%`XSv;{P z&l@^8M}5;h-5=y+sCbPp=2kwhWJXR$_Yugu=rnFOAri9>sYL`g4jtFqH@h)6!XKhe zSX+LaM|?jm9hI_H(JK%@DPAww}r`I<=mW za{aV?T=MLtQrVT@hB!%SB}9kc95&1`i|oiWcP@W-3zK+BJ8`Xb>NA}g(&h`@irVFlRp82+f?CF%bmH)!zIHtHMbY3IgsM%#9smF_Qz;=s+ zrE|=JInGb%b}0jm9j+m?;ifUJ)p(f^Pumb`k)A3{ZDD)lwog=8L;uaPT(h=zZvE9| zI_>utL-r8F&mQ}amv(>Q81Ce?{hFENiqQMGrbsmR=y@;y8gTCa-M@}jrX}S8!m6~a zsD}1K)2_((LGIi0e(6q(7E(dlSF8Lpr9|yX(wPtOuvL6~n4c7-peL>Eqj{FDV=Sbw zt&3+QLwL#mlu$5_diZtrSP?6yhknCYn)~+?wQ{3M!n5JgpUWMXxlOVRcs=5d2eb&m zM#ihT9_K!gP1P_!4?DcC(%) z|0NTND?u|sLSNgnDu@-bI5w=OeZrhEgkwPl7qsdu5o+Te1MT|=~YVA`?g 
zBJH|lnjEij^VilgFNoYH4}Fx9bPd9FHjL%6%*PC+$V;COYU@G=<~dKfM1-lw`gvY` zb&?wIE9|LpXa#i?!8r4ZLt#YbDu3K{d}}pfW%NsbX~x!z2(fx`q4x&*dKOQlIp!1E zgpu!kbXN$=&)Wm`qZhL_r=L+Xpsga~+9>Y5!Ut!iXq@$aC>KxDgJ|bjroGwBMs<2a zlBqtLpOGu!@hB}#<`om}?1(h%8{(;~TO-O>nYsFYyZGaS zx}l*%T=RZ=cmjao`tyYd<}Y4|h-jR0Baxx3={vR(XnlLsBKcUz#P$$nCj~i8_z8mk zyNrliP~mg6ld%Q1&7&xU|0r3=gj&`F-4yEk%rNo5!!iWDQUZcJZuVbO$Ooya?gId>69lRj0PegB+p;iriNZX;Jw z0ajQc=gKOUgjK_?@Y9Kt4d0yTR^9K%k9sRl_p%I*`}TBvYSK1sy)oo;0{iUUXl>~D z*NEcj>Ge$#6H|lGKNTEU(C`Nj5r2~pR*gADD+h$bYiUiacLGrj(~|Xft8EwS@5xbk zT-}_2x=2d)yB-!LSk6R59lOCf-cSa$csGFC?x>V*LCCz$646ceY=7C!;NFBp5LYO6 zyp+sppldgzWUx#=RM!CIHHuz8yQ~+#()Uo}b#!&PwTV<#L)aW;YYEXgzi-v;SCyOA zjRz-^h4?KPLq1*oz^jRs$UN{sd(Q3dwX)YYj=3^))xvA=k+ zfqH<}1hU{*U~PxBA|&s~)K6C6Uw zO`P?m{x;caep^4^LM^+IJXgjRE)3ZR!4GNKwFjzh<=ebK$Gt|E#kp6~)cf{n4Q<~p zsO?%Od*mg$wFgyuc$`SYUvL_ufu{x1EpuzkTPF@jz`W079qrkw?}->V%iU-gT%N#_ z@<}uP4$q{Pv`>m?TJ^|K_j&YmQn`q`Uu8ut>|R_$n007CquupNqEtyDxuJnO8gi`M z{ZDhXj+Ts&N?&{EMEBa_!PVpnp6M{C`B64unVx(FMi$3DfgDUdka&2LJpcN!R&HC; z!K>s%XK|cEGeb*>QfH0~5SWoq3ou;=sTdV_CD)_!o-xkqqn{udp0?2!o+qi5_y$+o zw+~n{N^E)SbVAugi*xBp8Tq-5EuvdW+DU)%2l=z3x|=%N**yf4HFLN6UQ^Z3kUx7& z-s-&BSr08UCwI{LxVaPLY|Y$&hZm+|_Pvh8V1Tq4uqj90%H6`v&V}IY21SDW*M$C) zh6Bd{tg)JzBd{4r_V=!L@Bf!19XN!l8NtnNXP1yD91hMQ!0$gbBM1ycP95011ZfQ?DuP9y^;xFW&b#mvGAiURV$E>^$>C1A@E zxX8o8+0oGqiUvFtA1gO!D9qUjpn~89)Cjh2RzTev*v6y^Yh0a;8ZU}OL>D9jZK1BL@6LV&_NpfFD$kqHcQ0a!Q)>|%nQv$G`F0vrIJ z?4S3F?Fi+^oBV%-{*%~t#P~nR2oS%6nTlz0K>^uKG@y@Yz&F65a8N7?3&8_9J~2STNMM3T;o&<41K2iv0n&Dc2t z$_BLm-xbJ@>nNZtSinAh6b!Hy0|5safI|ZX=m?L4iUBL(uz)^efWgJUf=Tc}1IJw$c-$cdB?bYHAQB)^ zV8J3#JNIG$!gu=s+5+S5)PW!vDTdoA7%^Z+A5g~w_8tqsu?QqU;O@9+U|bBK%P2f} z9~z1QjyQl>;Blt_6Boz~kO|-s9Dhe>AOis$0ER-qF)_eBJ52e39C+N(A8;fzcmfAm z1fmU+1|EU3BM)%6-2%Y?B<|b^^7o$#fZ%QyfC_@QVL+?^%)#41HTWT}A4dRNASe1$ z4`?x<5wHpvL?6_HUr+^sJIVy~1m5+ZUO=SaaLCTJ?;YTQ-H|^Es0h0{0P4Z7Wrx>* zUVu9BQ`yl6P$zzr9bNgZ6QDl)TtK{b&VfEa-S|;J%CT?^AVUCGPzUxXcpM-TP%i)* z{B!xC6F-h1QvkU`c8*~Er?MmK-RoeR{}fO!K&gVqe^yTa9P-Nm`nEF{0+a8~dS$>e 
zsQ@(Z!c?IFmcb}X1Vple_a0?f|9pgXW#a{vsvb7l#MaJRN&ICfyW z^Bcgh-LswfAIPBtXOJB$0SI{a(vIN(+^wBfz;wo%^oyBJ2eL=`um8E zz(fPWAb>>xgtc?k849d1+N}Y`Y3Ka{K#>ce%Dd1X<_vf^c3QbYf!(w_eYim}-$!r< zWVCzryA9iExoZ-E31jDs00mxF?$&lpALct%pf%iyjDS%+0J-h9@q}W(k9Pr((JpMq zgaR+%!A}0!tqWfKVVe{Y&>aWm=j&kmAJ=6NKU?$vY;A_g$pDRD7G_Wc*fC57`9C%T ztmQw=Q2%2yxF}%qfrj8wTKXq!e{!GxA>k*L`v>@Q#zPgHgLdeKsR9rEcdn}7w zUxnM=R`5Dz`!M7dli7|`L9j|O8xOWr25jkCm2oma})xFK*Kd3Pd>qDz||iUQGIAqaIaLqa0F6K zZYp!Qlv-3&Ky+US8J*!CnK96N`adQFh{Ay1KaLPq?BD&MA1DJjKShy~kQp%s8H!pO z(=yQYcD_Z*jPtmTsW_NJNjzMaVxG>XW_*ls9BloNwTrxbr9jLc$p3-+?pK3g%667u zh6{+)K=*IwBRU>HLT1NL{=O*{NKgGkxE$~y959b}6iB$aI5-m=?98E_C{Zy{gfP^W zKyYyvhrt~Gb4k?M%|;Lc0}^kR9u~lNb$-6f#nKvTZf0Q*jPkt)K_CSLc5mnGBoBU( z=%~Cn5{^W{QAj)-iNRvT^aX$VijTLOl{Ewi9I+4}2LJvCq%bf*^kohG4-B-`;2+fK zI}CxtfW-a^!(hRz;O{Unp7{-i1A6)c2CDe)?Z9~L7Z`8?GV(VV3XIZzg`qG&*!>#} zj)d>g4oH*y+74Jc0p>n`=??*r_6Kdi2b2D6Cx-u{ofsU9*M98}i0=PvhXFPp|K1J< zq?3Mw;lOm{?=Tb?mwm@hW!D2EMGFoORLh6j?-zrk>DFzWxc9q_fT zKlqHp!2fKA#r>`ua2y^Cb${&-zsIwHM}awn-`Zh;^yaTHz@UH`jbGXUDRm&<_d5(= z*6%O`;t!bv+OkJG%pdxP0M-=$PA>w91X5PNjf=#98KYnE12SNL;DW@1nWk9r&Mi1UEA~2P?OoZ^&uc`B(vSClsdX>+(S=0@0= SW)V0sBoe~UuW&{Y^1lFs0*>PV literal 0 HcmV?d00001 diff --git a/tests/triton_tests/plot3.png b/tests/triton_tests/plot3.png new file mode 100644 index 0000000000000000000000000000000000000000..e83178d7a65f7f2c78c9b9ad369b13a6c1a3a917 GIT binary patch literal 58335 zcmb4rby$?`w=SK6G}57fBHb;GAT=;_N;gP%iHdZ0=P)24(v5_4i!_qb2qF#VneW?s z|MoxUT-Wis6c~q@XP$RGYu)Q!_x*lUS5?5np~OK#Lc&vgDXW2mgjRrrge;DQ0Ui-@ z7;FJwgx%$I-8G%9+`V4AS|X{uc6V`bc6YEfqxZCQb+dJL;^h$J;AW$@ad&rd6XE1^ z{GTUqIJ;VNF5qB&055{=@>0(Y3F+Z$#9w6Ci&9%8q;+LQSt+P@*8Y;0zvkIh=kGDa zjC%FvykY$c2aX7OR@SG`_iUVNrm)mV62tR|2$WEIlmfZ;;ka_0zcr_jx zq-f9R)#LLk@bmaoGEcA9{U(_mC*zqN=f#{&wv)4N9m|AS#(dhxbf!Fv6tEb=2&Di0 z=LN5p2kQU)#31VuCzlphE?vRmS7(I$AM=IwXn2>#gmK@Ef2Ra0u-ZJ;eabXl?=Wv zi-c{onP@|uFoKYmmFzQ_h6dofKJwY*JUm$%Vb?J52k!KqK$_%qJ2%L6*=JkJgI88~ zz(Wp;@Y}|`hWJ1o_zuSEYJ@#62?g(v7w7|lKrH9`t!itmMksyH52Jg{ep{c05zJ%> 
z=925_={akaQG3M6jiuuQN6sJT=Jew zo9Ds2teo7yhh~BTS<=haxasTNf&ckc(%XtS`&VbsW&2l$ABp4QrM}EWFo|a$r#+Hh zU%OdSW{~CA(pu3A6(G=Bih9Hy8Wr_G;^z2kuviX(bog$m-}#~Yuknz}(`_9O%?%Nj zFapvfzcn<`{#b>EIUBT_b|;#QUg?(7+!WI!0>(VFfZlGu!zN6XsmANWmdH=nV?lRs zOKial7nsh~j!6uu2rg~>_Wt$U*?0^5WmXltJRI5`)j(4lKA-e zpPz5uK9n%+i=vQtk;~gOip?Oq`AyUfW68=*uVletBU@Ngjva*2d8h;ddR$yw=V!|y zA+JNjuFqLn)XRa_ju`@u`iJW(xf0oh3b7}rV=vt1EK|8Hz7e$F{#_tCzdRY@6!!h| z6l{v+{1@BdEa60^S1C!?i!-Cfri@&}-b)5`BZSgeCVH*j+~-SPbSwU+cnUGpWxHFA zS`AK1O663nYc~dMo1%+ave6-CBoRqWm7^Ti`W-U=o;@O*A-?QBb807Kji+g{kmIHxi z_nmIF7q=U4{K{tr)40uT#`BPW{`~2FFrQ2pc+q8&SP*4ya6}7QjUS+k? z9^e#v{`KJt^^U**tC0+?!`xZ>S0MtO4y#7%b$G}6oBUE(3^zZA(|OrU+Oxpb{Q$1Q z>AZavX@%cG-E5kGBjrN9`VCGr><0B{rsH{1HW~6}W@cRybOC<|Qw?3Q!6NwV zzem|lmBgo}>Qoy~@fio4X<^_|${x(u+0K1tq6@m=&f`e{e~`kXqdf-x94xnqJ`_8a z?0PdfHANxeFH}=gGhJzv*&cBH*~|H-;PR}oUvyg9qxOq6-1&OPSJ@&SzotqbD#X&V zpLr1_A-1`z(S&D&cf!J)Sou`hwc`oSMvjn_r#|JJt?dT2#J*={*4GcX^#kH>3K9&S zh7t5C(hqXEZoEw4Fw*DhuoxgZY+R4Qp%Z-a1TC^x+R4c&{rvKGKiyoj$8)fae;Y?G zey{u7-=6#2p3PqFl;+w@6rj%YbC` z-k;@F?FBkj#)#;_CT32$|L~{J*p$Q+^CYsB8MQ-4l{bu_n{GMa+z5o2%-;7;O(6F4 zK3gbC34URad4D`E^KhclNHm64e6Z^+=FWVbz0}tJWbrEym^M=5mS5lFJ)o1wx<2a2 zT?xF7PBjTkM{HC&uXS^l8+gSBv|?!o3k_3MCNB^Hqna%uCoPR~)NwC?)Z%rt^=qO~ z?0Ub}X+hkq2P^RAw}Qmog?an+p{To2sYyrRkE~{23u|j$>rv`u|5KIr%PlH=TCr~+ zdJHR9)YJ{P!;ar&lPq%DUQ#kIg{4#t+$JF5*h?3yrc3W8Ju6~2#FC4|w5PPiB`0&J zb-VG#md9p*XOTP?TnxNEP|gyHe{pw#8my8jP~$Mib9dfx?-MY*xM;XFn*GYkiWz(m z$Bzrf9y3~L%FM2AZZ&fbi$2@g;PWmX)LW0`Y!;^)uhq>O7do%BBeCg!39qxC{p|hQ ztS_3XAD6RL(%qeJD24qwEN}&`Rc_X=-{i{l^~1w$aD6Pl#W5J)S8Ud!9YZZlP{~l$ zw2a-tj?3Ci^(_jfwp7yu(oOlsKf|e*Sf|!{BN3Mb84+#i>FNI9>boB;XY1G5K}uC= zq@|^8reTBvnbFnNT)VSX4<(hAA4F5~vkHCxVmsx%5wCCoJ{F#2tiyl({SR;= zcXoEZILsw*oA)&g=;0m@a_Jce|l_}upIuJZKHparj@?n49!EUw^PxAWj0q%DNSy=*7(q8ZZ z2mN!5T09Afh}IH{p;gA>AoCfyI3mNqgHi-Id2iPHEJf7DjkaZ7#vxq=lPF3t=PXeH z{vL;XbTqHfk_(U8BMW8fnpA1JDf#FtUVrT#NPGA(0zQ7cFOBu{BU(XL6t!@|J&1p# z*5htmV>;`_#y5!KEaI^{Uj6zjPe3rslPBW^a=1t!^9_LwPUEq}iG6V=1k&uLlguam 
zI;Eu&pVO`Ba-CREp~gFc?uneY0{SqDknELy$YTzgJT?x8Z%=PRdGDTMH8P_bg^+1| zZ_eX$UKS%FCQcXi;yBxzwFap# z!U0qX8ww9+@Yp5TwPbupkgS3NQAM08Cnt`Kk~d2pvpxY?jg5_=nU2osBUUFTC*MX! z*gfAaxB9qA%gD@qu_XqeP9m>a2|%KU=yKKz4bEeGRUI)SR;S+UH00#{h$?-4xI`i5 z&2`j%T^Zy&*-*$M9mruRN^?P4jq>`=K88S|;n9vOrK!Z_265{Ccf{T$vZ(QZtu9ec z4;`>D1Lyw%jYtv&uN6{UQc};~gN1a70FiQ?DvopERCa^D#U{6%rIv=F6Dp_15U*vQ z4=qRSnd9T?C#R7?wO@4EdQ(hX2blJk+ghY&lfmOi4cqQE3-#bqOP@89H3xqf&$DeDw?@r;h}J{a(|b)o#Ph+KqAkZHQ9o~!nsx-UQi2# zKY?cpX}`t5y8zhdrIM1_kLuTK2K7;K1a$X$Kh#aybGeph<(8m&kh#Z{yH-jvq$_;L zxyY!&$&xfH0t#5U)rAo~PUwMrCC>32rO16oucE6!tEelQAbgUJ4}4n&Hw+r{<5nk>qXw9ww%xeG5Uwyx2@Nu>s4F;y&)_ zk^TP3Kp~aG2n!eYv4B8IKdoOBGAhRS7h6rI_Dd^eW@3<{%==?##D4b>TiQ)sZs*39 z*jDJ*MFienkxrMne*bLw3Im_|8>r(xmz${{0mzG2pkWJ z=7X|~vAjkkiLw~yr!ktH;KNgyopv4g zjAxhs=FJzstDVq_GytBox?HWBRo9lhh-v8_HiC-w^=nzc-Q~0Z%H8OGQxkw&5CRYf z+InhPk_~g;g#CZqz6`M9ef{PQ2Qpl}|N82To8qBD*!kiN=9Htr$!J!*WD<;QL$w0S zORz^S08v@@e@s~Fxw`UHwO_$!w-4rOS@KX!qp1W-02=9)mvv1^NC;hW+BUhrA}>|T z-I{#Gxf#jXPA?!prXErQeklSZa0bXvu+I`;abu{YOu++EU z++k!!UujsM|2Z#^dkYf<7g>R3xf~*>VGE<0vM_z&^7xp^g}s_GZcG!B>~2a|DOH*q zT;WPB)2xmb1(rfYN2dlFJ1PMOq&yIb91C9@7g7Lf!KW5dcQCRA?{ipxy#zX+o+-5# zBj9*9?UZC*tiF8+;*at1H&jbc;u{Epk7^cSU7NMk0b%)vZudPD9 z_5$uL(Gj9BqF@~PWHN88*w|7}Gqa``dmKEbRz>3O?B;aQUze8~Ibu5!FFp1jHIHa> zja79tA6D4v?tVYhOq({W^wJuBR?e73KSYH4XnNJFDyWo4Bfs+QM`USZsp0peBB zF36TzPn|bd3`#bA+Nmz)W#n3q{_8{YoWooAv(13+bOOs#R7~|QW`)%qmc>{@5`1?g z4s?wVGgd*7p*+_q0TQ6CC4QVBOAKI?zd{q#NbvZ8{Gh=@oCq=%-vTL%Fu``OB30Lxt+3His2v6`Nu zjU4*^o~0pe*+ZcqF)*QADk}c~eUm3Rp5OkP5|a<9t-aECefs%C(!t&1V4rKf_~c|c zL&GdX19;iVB9%pHalDA7Qq88L)$mg<`ay%{r5N_dgl(|Yci;wJB9#C zTL?4{m-UfWdk-3eQVD9%Dy928PEpI8z=OuL?!2jDNrP!kPwUBdjHV#u!$H%4C@TPM zvii4{=J-Z{CXCde9Yj=U{QjUhDDtFEi3A~xe7s3l3wu*1w<)r7kt;=ZF|asIPfe$6 z!-@%Q*TU&?{=QuiI!7i!W^((aNjxYd+ZU-)qffiF`6A18e~n8+Vsv1+ESJxJgfElOCnk?%-HRy0Ef_6PMGo735HWgSRBF8MfZ&g{A9H?-<^z;9kj#-mp@sg)E zdWxbhFOLy${kPb-&A0Xb?m8_hK9G8;m4}zth^Jt3N{jEa(^4~StT2cTa3Br1j=qKO zf&y$I`*Af9Z(bhn;5puu2=RRVBM`0^FP%gXZjEyyI3b@n>l=4L2c 
zkxeDGx&?(bvI?4|mcJR__*jo6`0YOuR+XEg>0&kT6}bctv;}os6W(=qFJi#DDFu2o z8l2wAUzG3rD*FaX)JHvV*xA`ADK!`Mh^WQHC$UGB6BYOe=;OEl#;D8SV|XG$`3g$H z0D?WYaZrh{W;Ddwhz66be;jW-Ct!Yt7Ni#{Ks)`B|A72s(O&>8QUEQ<==`FIz-typt zhHGpH(+MQ^lL_Ti0}NARbiVMjV#8Yf(B~3cn^c>$4DA(8X|4R$T+)-KwIE)rVR}x^ zsvjqoLn+*#Ev60l1F3BLtDnMYGMXp#7Svi)9m6;8Oy-Fw$v=?ML_2P)+GPg?)?J-H z{1{HZ5sC{7=tA%PzPY^-=XcY%pbTG3WT`Ruly!ZoHhOz$T@z3jN1(;lX5-ZXqPAS; z;U^-@>ThEMT~=tR)Kbjb!HMkRGzL~yzv&ZGKL_5WSLBL6(5ti!n=aFQR1^`z{Dc+0 z+BRelg?5srr-mNP7)YWQ4Dt*MfKD-LMWR(scoRhpg%_fyeg<}0=BXwbG#z8&BkFgE zZfEO@ZESc`7P z#vjSxF9da@Lbs;d(xMRn_o+gKhtjC~C9lF+3kx~sV>+jFB(35diPRrP_!iJo^Yi`v zom5>}i1CcDcpjIHC;i7_8F8jEv;6vP-rMd54NXm-pYO=iK$qaP63{TzT&7vd$bxiG6}IF@5$#p|P|w_#~zHZ%-98p}d9VmbO_d!m|KvN32BQKrqF zthh;1aYE9_GXvjjs7xrZ!h%UcZNEAw$jfU>m4_*GtY)BlJU(W$91Uf4B%mRW?_^!| zTos3UbtHC*8`m2_O@bg|zuc!sJguN!BHMYIesQ#PU7O2(N0u{it>J9ZR8Ud`oRKu@H5l7y{tP0mk{YrUL4C6EjNOhQnnC9RXntZ=CrXBZ1M}C?*_?w_`F*L$-{)5{!mS){ucxrq%|5@FV7;ZUtU`{kV}^u4ni|NZUJ*WzCIdnQ52SG;y)k@h;0>urzYHYc^Jsg2Mg z$sD};eVV4Da14=2!(ID*(i}g444x4lj&*W+y0h5C3l8gQhJ9sOw>Cz)Arxae5fr<& z;6Zg(^Ho32-IBqjRkHz@t9tLD%GY4a&Y+NYoo$`irmk~i&*ACjBuB>EhZ*K5cOmxU zJ7#&>n|zvRmHhVKkq|T)se#vOCotVhmV!Fr za1XMJ_uZXSe;M~=Aa6eip{S***NYAi!f>z6a5VE06CFZ)ssRbY-qh4YE#xA(yP%CSV!>Du7|^{t`vI5J zU9WV{w+kyMW*+*z<9=(g>#c5Vv)`azbKEZ*wwfCBkYcl=z3;t-l)exL;%y z%*V#8j}q0Cy8co!=AlBU1f48E^9~}k!Bh{>(2oF_3V7O$Z}E>onfj0;?ssxFYm$0V z+_l8DkcabPK7uUpw6n%Yuc0bF@O>PF!KMTLE%p{e^m1&ooDH^Sv@Wi1!Bh6FAu~xu zW!9(H$+~38C=JJ(OxZ&vEtl`BdEC91#NQ@Q(kOzN**D{Xlha?cWUlpOu4^w=K2+|E zWBgGSxngSeb*z#J%QV;ic@JOqH%2?Nj?678E>5;>+1Rj1NG#8b9XIA;=b5l|V^Pjj zei=jsDLYY@x~qeAPI-~sFYZgeUfr?FyFry^WmiYd2uIO_J*j=Gns|D{pwu2^3;U6C zqhBuX=_vrt;^qKLuDht@JM-~KKp3ZP#z^r1-c!q+T(Z3uD)0%=FniMs;E+<^M<80* zm8GZ%gAA)!2_@OnEhZ;(1rMe_@3%;VTy+v3&=N$OM7)_PYLdS)ipr^;UI0k@qVcK2r7?En^YZE@fym7}SKm$D?uu zy~xF%#*^&gUA@UETj-TluYZdVZ~anBz%s3uSRLrYqJMZERwM}|MGl~GbKALEHJE?a z=QjA1=IY8t?qhOnS7#>SHCnF9B?>zIR@l=M`x)2t6T6viTG~SaZpKGif4x$IAISPp zS}kOQ{Bd=DGzJ($oi*+N?qSby{F+X?o~|BSc6P 
zdiXQQx20sF@{Re`*`Cfx$d#1wtk`hj@Du+;N&ipYC#jn)V^M@!EK4dSQ^@D{e|=vq z@lTb#RrR(IijKO66FdLAO>5IH^my+Q?*zIz?+Un%|1kBo2==dm-4UK#Y9kSWD;Kw{ zCYaxa*BIHFK~T|H9#yGC4tE4csBO2*>ByY9k|ZAiqR#Yiu?Y~63B$^y8e-<|=^}1h z53EPCo_+Hl93C!$6n6VPQYe7}?stjS0K{k;03dUxBZ8Zi{nMY+=%n{%%0CNeds49) zfN92ci5lm5;A;wsFi_l0YwPJjYFR*Xf5$nnwIz_Z-oKx$(Bp0BPdQ@K8k>a0 ztL<@X$y1dV;3a?3c3=YxR6A1sobt_h{*Tj>YKlVgEy~A_=6QHJoj4y&R2PkTxE;*X z;tq?VXd|GnQ&Hz4G0V51Z=DG_s>|_@!+Ls>cAUSF=ROBji`6mB(F;NEfT=;_{^udq z{C83tEYR6FljzG#NbDx1H{hFBkyky;+woj*;affvQtqRe_<0!Lt_ieKTYDMw0Rn=}urY3kqa_O9?FqOR zVAFNGIzv!#>k(wWxbGi&KE5xXGua0j_&6%3IX8a({ynX#T#<#*1W6J2{-mAbN9Dr7 z4<24E`a!}X9%Eb-?NZg(4f+FPqs%WvGqYqqh}Em%Vn4%DB4t;1elu9j^A1fOi}}0< z+cSnP)~R*^>mp#OS&99S_E87A48UPQfRrE35zmx*ixCf)vG~Ud2!t7Qdqxgc2J23l zrPl^LfhGbTghx>N?BGu0O~0PLJ(lTopANSA`b;8RwcyRwPN!d|ym=;(=&ljEgQO+l zzU-1s6|%n1%3&UqOP=rDL{}+9>KmeEKAA?BhzBj;LgjXYg2S7Wpnl&j#YlYK#ngW-ljP`- zEYkRpLwGs&1Lbf7Pn0B6O`BOm?d2~>rOjfR=!g|I=ubfPL9judMqotPxRLM*$@}Ja zf73zzg3WTPFqoDg2CSemACCyGK+Wvm!G!%O=%~D^%W~Yemkm;^+vi*Jb+90cS_+=S+iB2f3vKp~(wpDRlZ)HmENpsk^?#a08OD zOmDc0{+831)MudjcBr2(+=hrkJcz_fYD};kP<#t{+?r8x^HCTJ+*SGQBbxfZN~k~( ztW6nlaZwdE15&c>g*)~h`KrW>lmdEQ1O_Y%RV(IP4|XS0Y^f8nn=wOh4pxq&{TLeb z7}6AjwD**Hsqzodse2^B2cIF>yzc!(ASiB2mxpnswW)<_zPBdrn0WrAIA3czI*RG z3}a1;jgc?!n+*-dkxjNtI~aR$zj%*_Eud|R5f?fb3XjE+)S}1TeLk&HvK~~$4b#xv zEPK3+O-kiQ;*H^r?s0#MQQ5^t{t-q3Z-+UVD|RaMl|LPgx$>+=Pm8^#Pm+1E8yVv^ zt|+fb)Gexd>*)Tex+S{AKIzAbyo{D$jGX1OlqXbUV7qpNc5Cm$xsyw1Er&iQT)Q1v z=F6jcScuKR$5Xo~dr&{gtw( zg_vWPM;ad<>5W(ADO3gLduoz{DA8Hm4jSvPB+d8h?;x(fkI*cR4H?3Y&<(^(b^(Wp z#jo=C2~PqzbVB$3u?usbv|V2fRStd&zPSmt+&|B{a;h{kKrdos6l(Eu{@A-T!I#=> za92UW6aq8d{&K>FxCIo8EiPhO+y_-|Bj@X|H z?e5ehLe{fCdhm{0X?w@H^t|PFmP2`A-=jC<>i2(gnuq!2l3Q?!INX5n3a+a1fFPs0 z8_W4e-0^!nai_wCgAGf|e(s^j62RE-AY30k4kw?S1`3bZ*?L=KDSabdoI8E+<4e^2y*7SDTOqggK2yS4B0;Kj2FOFOEvr1@zolCx7!XB2U2v#2T5N*5i-KY)Ig2kJ=SW!M*90(1zf z1N`=31O>2{=B$nz8qZnbKcA^@Vp-fqH#^brmN*)J8Hm|k*2Wx@D)Nu@yZ-Qgw?*RN znR=(%Y_b2e+aoWES8}C)TxnbIS@k}}f7lUUy?41U=g$wkaoQ9sMab-wtShx5^+AP? 
zh-E2KrNNf{$g0kVYxne++W6s4)?4@lAC-sm$|KUF9g}?d6eAfZ-bpd+(3dJC5+XN3 z^Ck&}($1#+3N~Hl0~0;bfuu`@aF8c2ZDbv6RF=PbG)_3%Ifb6Yd2NrzYHrH?&OFY@g4Xi_ij3iOGeR8w|#=hf>I@+ zD%R*F`F4pHcU0drrSx&mnTWk(04BWy)V;S;_lcBq7_IMoGH# znvL~XfQbLk@-ygcn||`dJkKP=L_h1T>Wjiqh+JI{U^cnicc#WsNE?Zb7ydUtrg`R( z@DBEU7zR_YmC983C+;rEcX#CUP+pD>sd%px`!W8oemkI&xlfr!WRz;D>H9swyseIU z#O05JJGnAzGaK2;nQ5wp0Z=R!?2@t{=D1x{zk2P8OhqK^q4-4QmLLl(V5G zEX1*r0#qTAn#Q9kUy8FA0^+~6v8emFC7E(&?Uyr@73 zl$Y+FdWZZ7ljdxkTUZahiUzr4YEvwbFQ=tf7&Bv6(?c-eaooR4h9?&{`g!|{)e#fl zE|X!)&d6#*JDN%%O&1d1Lg(MZ(KnwCaX<1#t}@q`u$Dh-Zoxn6ye5b4ln3R1t~1LL zY^T*VZ*`I-Exn+g=@`#-fZ^F&o4w|Y%CJL8Zu%3tJJyHD&XZSAlCM<~pR_O3H7)$dabbub^;9;WhuHHjR+Y5l1YzJ7{qx>cm#lwVDj z^cdHU!7f~w^FY>8%8`G|nG1hz{8~mkLfw>^3u%V<$ZPuL4|Dadvd{!|a@5RVV_37Y zQ7~uJ=4P_uZw}9Xka1ptuOSKA%+<)4zo~7m%2E0~dzjx5+|`jhnnC$e8o?qPEosED zgq{+Ckv*EUp7lmDkx{et_Z70;bYlR!sYd+EFBCH}V(N5-iX2MvNP$2wL6KpU81iRfTYAC@)o)+pmNrfurOkFP`p5I0^qh&`8w_ZbmPEJvE zTE;|Yu>$Um8AFS%^9tMeo&I;_D2>u@klytt;}qC9+TRpC%vS zi@?m743621Ek9tRVd&!@<3-8J#f zTBa{5?JhV!+t^meG?=~UMn)@7-wPg{ywjYW@Sstd8s|r>2@>j- zL~Ah!Ulmq{A+s+_2W;qU)j6W^*kvJ#O4wVp0(d61DJvIduBFt0Oz#(gN2hJUIuMiZ z_6g0MJkwRXw5tlx$w{`=j^G_fXcxZMEu2dr5Q#L{PVV*hq~kI+G0Y5nb>Ag2Vi(#R zR23EV`{6qtl(ToIq%GMa3bC$o>wSy10z>4ggkKzB95r&NuQ0QhQ=z|B^d7{m^6|6t zbP#xBj4hdS$|vuZ~G^By-JZo4N`!_zw?tA@^a+h*9trkG@Y^ z=8`?~@RtR>Hjn2(HTTu*ul282;qb)RU3>!#X3$kS^i9gHQbJt?lfP>ns+#6}@R$_; zCkHWdy`+bWd_$&{k^8t0%aF$n30bD&!Sj&0OXgmo(G9&~q^hg@W^FAGwHd#J-aDzCgnupRl)SO@8TDW4>C{mfZ(5&2u43yK zvOj$`nodJr(%k^P`j|WY+l05uAJjuOC({rUjZo#YvR35O%UDTdvUstO5LLG2R%=q@Geq-P@LZuN4^w?npj4q1wtONlc21Dl^hp$0HhX-XZ zQihU$Ooufug#6vihll08KfUW9tgb1?dC62FrC)%fe1LmmlO4kf*&>J-G-Rc;ac$MM z`umn(MoS8#;6H<$jF4Y3q@+R|&oax{>b^_jebLv5dq#V9-}29yA!rBTDd`IK!;V&B zDD(|BmEgZ2;xWkvn z%SA)Sqt4E2Ymx3ik^|-|!m_GU|`XeZifmM@|iTo^9($aJpM;R#xQvxMWI@%%? 
zZ(l0<;*ye!<`Lc>Oxx!ne(EMr z>ELHDW%9`L4-Y82B=)_k?{B(;{C%=v&xn3!5n5s#7Ft=G;{0QB z8+12k=>tSwfDy_^$eJa#SaSWP4Gj&uZV+z{)O_iJ&a|2(%3UCuJ$9i8u^AX8AR+W@ zK&h$>{0lSXI)2yd&`P&X3+SvUO)~l!4p-;3Opa{88=wBP4yRdW;s3i<_TE#`9$;1f8h0%aegD z+BvXjuon9auO}Y;%o!>PN6}Ny&BIB!f9IQUhXvniigMq`Hl_RL(*C5nYgU~m;h+d_ zj;Y)Orjbh9gb<+B2VrZ+SA_xrO3JNnTuuuhdzY1$Z?4GXvn2wfZ_AU70g21dnm6n6 zYwPPk1&4S7(ihuxku3wDpc@z&*#t~y?{{I{PudmQ_Fewk+GI!qnRuBa+9mZu_0RYI za(wRk(}+D)%gkpa<6Mw@FDHaO=+W8tM2u@lqp93V=h*%${vd@uq~Y@&K`(%Gy0_gB33zwn zv$BSJJ`z7)Yy={Z&mh)-;2l@e*H;7y$i6uidi&t+1)8O*K4RZB)x>*&uq1-rq}B?r z&OX%{X%C3pQ|zysTXrXZH^oc~YhVVyou#$t|3>qpIwZ9XJIHQ4m|VntwWv3=WbeVv zt(TLE%4XWq;Q_-t0tuowPr#!YlRLUB_-#JnZ^iLl>j&mE&WA=vE5p8P#SAv9LcQ37 z(>+DY8=Q)Pd>WCSI0hL+EDon} zlZudaf=@;v>gn*$CSzk`|2bUBguy;Xb6V8S>9%;X1HEUr*)y5TBL^-O{@41eN$6)9 z64BQJCEU?%SHhozv8{L`quF3&&7*5K{*}tgVLh@{G0(z|X%0N_PJ>4BLY!tNlquPD z>Yv;v`~UgL6rHG@R;EV`?eq9Iw_!nUy-CA;uXai)z>pjv0!78dch_rh-BdCTytD+O zkVB6`9v+@wAehf~XV!1eSNbi`jHuJqfV!X0ZW?Cl_wOu5oR*i3ZP@YlX=JZPeV&HT z`CK?09T;Hfkmq?R8Mv1pLK{bgHl)n&0#_AP3O!%G!+b#^ebH^a%F$VA#n1Om=Sb%8 zL5Ne5?oDZHu3X6Q%ItsH)MU>#w9wX@wsQ)BDsItxJ(|yUGK`FUTCjX2=sx?gLd-Z| z=gHV~J44ZNasE#5MaN2jR@*#R)GG#|^-E7506Gi>AhhlhQMfIgo`xdy=+$EU{0KF0 z2EYB*NbI`7TxVftEi6BF4f1&H?DQ_qIvBe%vhN1HQ)AjopFx&%2(unVG~^xZ{%C1J zeWvVf5mD=lD&l{)Fno8YTyh#H7r}c288=E?N=gYn8Id&bCjr**=d4MP$DG2{7me7= z)YN_;ZDCDPo`~oEOI_VZ2qk^Pk|!1LQLsFF_7MnXYJo!mVRq^GQl|VlM6nYEr#6oaIpDeH z%k3U`3CtoQht$4|prB$IxjvsK+Y<#d>$^uD7Xc~2_`5Nw*f+{jvtr~xyPjKu!5C!3 z(puxTttxT*M^nV>Flx!E--+|?A#gAJze-G~Hts*w3&M*IeUtzHC2^Pw9`v8L9*#xZ z(#U6n9RX31U(}@q)J$tm369G!_yqH1G1qv~2PHxGG>mB+6{6Yk1 zM@tLutjUrOf!KlbTLBHX85;2DnAbR#sAdm>tey@`EJ&bGIjpp2?@X1}`qo9svHp{v z!>Yis6Y)Ma-Jz0_N{#x=x<~xypV9vFORA;+HQINff~zM!#@z)fP$=DzT0E*U`c@th zYX(x4Mt%<~Ls;!)=PVdINEq^_GKT2EhM9sHI^RH72e@ZAH!s*BWsw7U_9D%dg|?I# zm1UbuB&=1oHy+u?@69`L)&3r>i~)T^8ozzY&M2mAvhk zry_N~!W<9pe(==94*74oXMN?eQb*t|m%(tLkfLe1>g>lqf83FPYtJyHzYnf}3}#fW zEx9vxy>@kE!`7CT$fmV(z{q1M(P6%Dvi~>?O9Ej_wT=XRgd}3wJ>i73vnl|;jAn}_ 
zrl;>XW}*xMU;smcnjxOqE*27&R#uZRNMYz^ zjB;7C$6oW2^IFfxC>qgvQwiW%2n$047G4!7bnZhc*qsWamK2~WeO;rgv?7H?nACb( z;bLKrz<43uoP}F+2cCCVdz?Yk%_mH2myw^yzt$Z6%h+Hd<38ILej*#4{E#}&`@JTs zsT`@2Ta-L`B#!iqyhr}?jLi+hmF70?k9JIPmBcYuYG4QsWTefeJ|LceXBOBihYtc@ zzkcl$nP$L7R)wOpe8>`p#jyztz+|QTYCz z0sfa_u&a!f$fZI>y88U&FK0i3@t}5;@NGna-@m(V!;!a|#Nkf_!o^p+HBv_!Zbw&OrVqN*s^{J~E%p^VdCfO{j zf6OKDYfXVk9subG(G!4{K|w86X*iIqW32bgkpy_l{6N$28mJXp|4hDm!o<|% z-huz{;dY9F(@)UZBp@93P~fycM|^wou^BlKG^I zXdXGX<|@9X-(@ss&G2NGe>|cRmUt6WOYdjR0x1fZ08~T zD(1l1x4B#;@=E5+^lC6oO`0kFi8jP9>}ekuLvTD);*HAmF_TFRQGy?q%DghPMzc?LiKMNyc zck00d8eV`ESkR|Ra|0uK&0bFyDFW0T z1-c*ClMO}SK-}xej zc7c6{X)kq~uZq!*#uM!Il$v2OY%th0kbNC1p3!>%0h?v>5CPn@3omcn%CHAmUV4g; z*8f-vaTJZbmgI6J0=z+#mp7+_1-iFJBrgI4`sZLF6(I&|Z0zR#zP`4WChW!>aP{XW zun(W=Jx1*^UWFunY|0k-4IAj8UFBbuUd5e;{W(<^&)&FhhVhMtlm;QIK!Sn`gDHY% zY!^syP}QiS&cxN9`QtwhUg{Jwt^VD-pvy5>@oT$x8SbP`iq*FVu2y&_5hdjy@ac^L z7qsp73JL%PyX@*F+XMW7FNPmsyBpp*dAQx;du9jnw>+TcuyJtB%(}mRwFOZ1RVv5XOe{0FBSVHx@1UUb@-Pr&?7s;B857uzy;_~1d3q5LheeC- z4$YC@}jnFLZ(CENM5<* zZjhR8Laq7FgQO0KAUH1nUnb+}m-X!B&5#*z&jf@c8u3>`bPF&LzkpKKye}#r)Rz4p zHE;Xrg0euXT@7v^fh6Fx_zkGqAMx9PTTqN`g8*d|7Eb^6?Hf^5EVw!kK>nHeWRN;0 zaUBKB`~3%3fEyqtCf1DiM2~8NaVC2rsi3MnUa`HEfyok61gc-WYgt=cLEfg=cG zcjM`(sIcn>7Px`NYj8UT_rt~Xcevyz2p+**N%a^3uLHK>1(B^pGcaKJ;7q~-+y{c0kjP-A}j7G8%Ml@ z1mX1QXgP30i5&a3w^!79%EwPK(WekWeABms{lV&n{EM1U>leTTq_F9AgIgCUs;Mc= zdZeCZl+*yRE>3PK*i)p9tu3=UeZ7uAakUo*zSdko&JJ$nK{Giy`BGWA;#l|{K|%8g z>w6wxe6G-b4tNbKkVPvoA7K9k_PW;yw!zj^>y->sabqFX-n-)rd}>Mb1WJmb_AZYw z5ibfgd`O^kzRNoGF__{XNB(YJ(mhnH4lqXhccCRkECpSE!@lBzF2o#oeZxL{m;$+b z_xq>4srBvcZS&f>3!r{W1Gdono9!2h%E}uJD}i6FY9~sf?SID9J1(HaWn{2|8=c6i zpNgbggDqbKaw-HEd5ifF#pUjDdkXjxQ@g*+7%?)L+!H_l#8e>fGCOlfU>DZ%u}@$r zm_w3ia)tvb3(Yg{PLJA2n&}fMDawYAkSbT8G1|~S+sR0!`;dd5Aq#zh7@rfN)Z#E| zJ_gcS0veiw({;dFT~bT&Tj5%&Y!B}6v zQ@`5kVh|ma+S=NB!hN`cNOL@Ux!JrO zRgei;Bf`1bV~RqXudD)v{^wZDL*(EBWp6g1DHd~S#%dcWqh*WN2A}8MTv}6m^86~e 
zzBCu8=CDDs!lIrt6oK2@3q}ZYK&Vp!ED@CUv-t^<0Bi~`ji8zHZT}?duvM=gLjwb_-~<{rT8Q7eM9ZI$#0)#qXkW03cxwz z?Bvv2Y|M*z7eFfa4z#c}01u)Np8z<#?6+5^RCj(noK{|j5@82e^9cX{W*v9kxU~(@ zztu+K>n<9LjJQkp%7>TS)q_&RI2C5jy;PcOFr!jX^?OZx)yST&1FcFfcN%S+`coVq z`p~DSzdW2}`%%k-%h$2IaZ6sGyldH6^n5Ko>;0zqv5wyJ&w?nyIepxYaUtKQxbi2f z9k=w+YSsYaGi26c!A*JdoLW1uE+9Kmf?WPnrjif4lK0b|YdT8TCK}XO6QU6q1{Mbo z*nmK?4IUrX%_kxKMhU5_(t}V}T>Sa4H-JLRxepZnQ1C~Z0;M|mSB=Bkm0~@fz$oCZ zt^u}P5q22*xbHX!Q57bFW&g^%Z-5jc=wNT%7(^s$k~Ke=?vA#;fA7`Ga-rG~r2Jp^ zM4SiFX|7f4y@RFDSemZ2eF)0kYn)f=mnG~+j}ds=Ma-!7`OQ-Dsl8;&{J57}>qo{wA@e zLe1lL{L(os&x^m{CQXeAS-TW=3d8MRfP+r4S$T8#x;UGV{+TnAWp&Hn;Lyv1G#UJR zz!70uW(Pm|74l=Ke-M)wKzKp!5&-$M)tiezC52NqReKte`dM>bLM(46a?I zw3$Q3?R$0c^1!-;%D%tG#|7k4&7jf!UbTLQuab==%6mxA_gVj+9FvIS;m1vkPYA94 zP&mm)o;pIZ%H7NM!S`ip3`TL;i`0vmC5|bz1do7oj9%Gf7VImdxsk^2h&Xfzo~)8q zXFwOU19k-=yG?%(vyGj#X{ZJ4szxDfIQ^dMQw1*rsK?N~AU0Iro+K7F&h7$F6W1@V zL_sf&0Uwcso-dHH52Rj*zyZ3p4mo9@RU_c|kpdwNpiftbq31mZMk+{v_nGNJS`>&0 zidEorD$mx#JdoAOUspGdn3%P9!q{Ct0l`TjIrn!E1H!cF9r!g)`FGDnfteMs(& zD`z8nzh{$TUDh+*Pt|RjMXlPnvVOX`+(oPBi@v>nlArwqME7Q;i8gTQVcB?2DtQyL z>Awp$h0a~eS;I(kAd#e+Twc2^3+E8ItXt2Zo|=(i0N+34>mss7k>D)qR~L4$p0l@} zbNG9{2I9EzDuzMcjUXayn1)z4%=ZajGcS<$>^$s=Qni*=g&$GN-Y+P7M8Z8 zv}FLY!8QQtv0{k4XKVO(gkVxHnq4dUwCnZ8GQwC_G8*=(Mg`JxT5L+!s zkVa~Jjzfn-P*|8O96VuCp)hy=P%M(FdwWkO9tXb!EqE*mEX3maG1@FYiTYiHU-ihn zqUB64m7eUIcRF5;e&s1Yg#J!u)PFVj^k0p{9ojOEpH1o_AH2Qmh3Az$f0WqXRA0@x z$KY!u`1OOAx08}zuOAbaoMy?D#y@}G4$0@giGr@)KzuN10NcRhqh?dh%{azYozy4F zei-mtnSzr}2+~tPksVI5m6uQV^!6#HIJ(TjPlYR!pPVIp;ST;&(3ekd z8r?Wq1+H*pXs2jR0tyR2DgL@{~-pzzv33q+#`KXW|0@8aiHP2 zv(wb78Is$0;-$>yGWU%*ha*V+_sjsik!_?N`K~|Qv z$?QDf+k9eTCApo}pawf*WaPa0bA*OdE)FVpu)%8qM0^LLon@uN>9et{@~Paz@l>iI zZuS1SB>nsJ3Nuyc!u_+!nk)>~hW5vVpX5?%OPr=XI(Ibo7U#wyyS=*)FJ`c6;?u#D zCQmf4ICaET!lWk4ZUy7OqjKcT!pA%Q1g zGQyjC3+&QbW90n0%5597p$NGPA`)Sv^!Ly+y%QHVi7zR0m}-f(3f-B8;F@2{MG|b~ zG~x^-^JrxWsl}~v)+SRH&r5!5XvO6NG!Bbj|CdDXJJ*BrW_+IB zV<~{;1ELO%CK2I>ci>+h*MiZA;S-rcrb_^kv@}CM2#hKaDWS3Y&M~lR5q|@f-DdFS 
zhzpBHCT>c9f&3)RHYZE)F0=g4^mns3fn3xf5zl-1^=FYCmjpjP{;A2@O00Vk$7K}z zq^VY5RFMtcQz%{in@))Qv5bk(OU=|?JaOi)iqhoW5R}!$jvfJ$4{*9*lZLTuLxN3> zNzT?L&BdMJeYjV)8XvWtxehNn#p{@JteQ=I*x|o#!V|W4L(a$%9yS6#KD(Apk`J`B zHR!$Wf-MmOZK#&rjf=x!-`%6u@Ctl?;r*FOo5{?68EQQ1sj#NNic;l<>lPQ?j-|G) zG}*_y2xg&QL=0uV;ijp-`^T2F$&2c?feMz=Zds?m;|ljheKnKDY|GrTUkiKcZLYTM zMnBz~#;Z?oVnmMXTtqMx#?|C}$-Qzkxd~f4KQ1zsb5)H+QZ+m2+P#m-wV!TOvkx!Z zPXtW5jrl(zABxq{t{hzw>V}>V1&s}9KBwH(CRh^lH})a^LZp5+l*FWEN15%pY>4AC zbIj$%Y!i+&7PmZR<}%QA=E-^*U?<9*dgn*xD|s>z(IH31=w_jWex++z)vHc^Up74m zB7LW;?i6hwc73eAVw!s87-nuJ*{0i(k{=)m9nB8%!F)H35W}hd$5&x>zhtKKKCjQv z5g#;{a_3qLAq&Son1lO6;_|DSEI&VKh@G!Fty|XS2$nuqzavgxyar$^xXK!a!S8#L z!V9k{ZQVq6E#FtQQOYs8*3h^vf4j*E@Y$BN*y&FXbwA$S-p$u-ku|rtO(4b^%h{&5$2VA-W>QCE$Ai63h(@u|2bDcx?#bM9lAsB_5aS31t2Lf_LFi?})q$)ZK*Yan2v zKO5TN`}fRoANnM4peGfHeA6lSOO^j-q+4zz+!$XWFQ0u+&46pB*Bd@hkoAsJX5JGG_x!XHS5^K={5 zpEB|OGHTyw#vSK$vum&uq4*Lxxvdl(_V-3vel~r)ZFd=cA|Y>+fB2@X>TB(d-9sD; z{-hZ$sQUNo*zZ@XuPD=>rX9R*;)K`x#n=B@eg6&>FVp`PM5W#@)-(m=-8T3a>5UsV zQpW@!?SzMoavQ+60!I8qHSmShsG!@6BbvN+B_}3`%>(Nm>_$xIc4+MpSNB>DQl}Ok z8h_tctL}+^@Y|oi!?+CM3%PCLnUVrGNhKGlLgKsVA4EAV^yN*3Ik!iu({ga+Ir6d7 z%2(4Ftz>k+ynR3X3H$Kb+>hXg!Ny9X3)lV^)#3aWr3Li39l^~YoJxEAUx=JX)qr0X z7|oeo0n=7(lL*a}cYs8nF)$$JJ8|MZcv_nMzL#I*{oA=Dww7!ExLbgd#wKxpgeO*X zqZ3PddIinOxmb=mkVz+Hh$oz3Y8m>W41Tv_opG;}7dZEiuPV4DGR4M*8=F|qWReqF z6$cghUeQ<3DRXry2sO_ZOl@iPidYRe8aBOFomS1+^A+et;u%tm%++>7qvQ^WbUI*yl?*O^t8fn1iLVS-3PQW!_s^03ci93sr@Vs> zw}SZN^>^?i44%)oOf@Rx%vR@Vt@K`AqmuJU9}=O_jzI4EG2C^Oh^r2GPDE&*%a1g4 zNu1`yB#fM~`j&R%Xsqx__D#yYKyJA>lY*gWI@Ftp4La$hNXgRN1l~{{#_iqBB?GV? 
zB~rY#DGqx`0g7ZvXDN99sUIPB0rcNgfL7 z_1$(R`Cb5*OKK_MU>if~2Pj=7>sPhG{Z7-#>1k)}*@wb?P)|WEC2lJnnaZ8&qd{sO zJmq7?7(4vn=Gr;+RhCtrRaHC=N47^jOhZHNp+i~#UsexRtQq}HK4KO*4)4Ui<~(BZI^B(JT!c+IltLWc;CtySgZb)5tu}Rp zj)SL#4CBw%%5fo~!r{l%;W(TK^k1GR-f)QKI*!jPz=`xKJBCWH$}oH%?ca;39og5I ztf+KHdA&&yXZr$ayB}$vZ%DjRqm&P5zhP!!8hYQ+$&`I) zgf4h$M%5mdSZLUU-B{8k^qBF+$=Xcl2uZT2k>hBlj}eB2Je-M2O-77O6r+~1aMTU0 z??T}QqbRPz)uV5!tUS)@Zxj&zcAUQe0x zEr5UA6@bH#tweKlE~vBeKlnip$>QC<$>Ayb&rUknarCxT6>lH$qN(6Q|CT*BlGEZr z1WgDN?(!pm2_cP;U~rxSHku4Lb7cp7qXZ{GgNch2^!j3hArvewO373LN*7+$aBSyB zIcn(JUpUDXttE2AE&S5#3{Z)-5w ze>{^m{5c>lC4M1jtM=Eaqa^9<k99Qilb>Bs1-VmiX{cNqjs%&Qi4fS3kWtP+Ta1o$l@R<}qe1=^|EfB|D%AL2M3 zO)42&&`UM--BgYCqb*1p^0In)|LDqk=H@M$G`r^sg+C&clv7IQ)ny$-_iix{XPs|p z)1bX1gQ0lc?D~B@d_t5)N4{+A4bnJoIfMXqeWC$K=T`xvKml!Y!H%C_?)sM$ux$?A z-?;D?zK62GfQa%!<@KY{=0=Biu0{8suU^DUQ_EL;#X4quX*z;aWW!xL0`KX}^@6P~ z-UCWB67F)zjbtfifiuo1f8tK;ckvaS(|0@ao|aMRMVL$qLoV(hI0uQ0jbGh4B#(w8 z`6V9(sWeb%mK*+f3_id3E{;(UUv?uuxL)GJ567A&(zKVfX)#Gpu`1cJ!bvFa1B)XfPP ztcQ@=^ck|+Z^|#CbM_MB1FtL0sI}V1vn-Z%`K^ls*U&-Z)i zxo2sFsjt+p(*bWRr1=FLoxi$dd;WU3W1ax!1T72i)k&?|n@g7}+<_zm#wBcEXAQ=u z)l)HC!2`OC6l)R8cKBZn91D~)(_gtIlOFKido+DuPQh4>AN3PZQ_D(AOLORSbaYhm zTSyrVftHO$$!mnKL5<4z6!Ynd*m@23=pz^2schH#ln<|H`>f+mf!AR9aAl7#Tj3h5 z*y_`d@B4+1?4;|eS+JU-JH_AveZXivP|A9q)XMha^@i9W#Aov$1&pI_;%~7V{1GM!VFZu{v0P@VfHEJ zkYl|HEtMBzzS+V<*e@yWY=}wh7T57(qZzvMpC)gdZbOn`#owTD3US{GW^fl8hW|YqU zTzJ`ZPC z_U-{i`bGR@acI-DLjnPZT@TdMw$g`CU3)E$L62s*nUu{+6FS#vvM%RdA4KwJMiG@p z@P@NRgkggQ%;?q0WAgR&Qsm-IlIHPzI`{A%LwAOUEy1QNLBiPB*#8;VwEO?f{1My$ zyJvX9pGz)CiL)B)a7poz8`3|KUJ+!WwQgch2J@%`0W;c&KJNM*VnA=>E?;0Iz za=T0D*F@2tD4WZ%gKaDL*U4XX+X-KE8pD2z)#q;cgw^i#26y?V_sNe735f$Eww88f zW+mrM+v?`amiwg~;z&$DtW)4{fwG={|IZc}k0?RFP(%Ea!7=R@;Jtkn_NUZ$=L8Y| zjR7N8!0v(Jg|Ktvu9{e~Lvm@3m@NW0OuT@WcK!U+n|BMIIC3HPuk}x$Gu>Nc#08HW z>s|+Dr?OBhflc@sa9(->oB(1nn$;OS#_S&`7m=eaEH+m^Wd83{Ol_KfGKGLnT@ z5>3kPU^9(v>!zG0wR4Le#nPQM6kx#f;2QSS29YW`fow^Y*06eM(-QBL%e8>ciP67x 
zF6JI9(dsRHOvj&#jEOm7z`8QbR(+zfTzwlXV4-5*;tAchM*!v?0U`$^sVxu4_TwOw z+vfZ=IENt%U=-HnH)MUGLLKaF0n;+ByY=c{q|i&Rt~P5N{GzWKMVx8j8_U830`)&X zWy|W<^r5_X56}JW4mL`qG0#ndR%`r2Bxh}RC!g6#FD`(}$L-`+S@1>k9(G4DBJ$m7 zyi?cfq8W(Z*$z>LA$)y+H-$JjI7mAIWR3Xf^dRpou{roEP~xU$ZO8}nbl*vHbY|Le z8!M&f#o%v|`n^>l3QptQ$0=>XN}k37e}4%scZ_n99WD`mQq#1)-qH6nT3&ThrL9Eu z22(2VgtK8sCjhH;71FTJ(E0}GGV7nHsaDQ&dgJHFL&#YW_ib>+*7$j|&Y*7uvA2ie z8?f}Cujw+`6oEo4vVqK1IeZex^3IUV0Rf#j9L~`_IzK4-8JB2WGP>Qjf zv@0l<+KZ|bW8s`yA|#gKLF`M_J^o{lsrcPkfu*}Q4fcdHK+FCf!{uRzP;LNU1skIq z;N9B~CK{jV`7lFtW){Z&b`6!!pPY!`Yo%E+SUi|2Voe zIZDuDBa?9VR}XYpu7)s~*FVd5-Nwa!HA{a?kt-NK)y=ZR`JQQJffXh;pD*~mU>%|o9r#hzwb@<>)= zN|g!j8unLMJ%YJMvdXr=x#Xo@25R?ni!tmE8MHhA)Eo<5|5OiEJ*s|V@G>&-0i4M1 zZEpvFp)kAwIFAvHUBHe`b>UVQy~Dp%Xz@pdW2@|6NcDa5NCx$IFe15<#T)w!?$K1gKW_ZI+RJoU!Rc4MAH!sM-##qMM$JC1 z^*>S5^iKFl)q?VcIs7GedPi&8-BL;{2K=(z2rdbEFUpVPlYPcs)S@>Te(SD|i8DR> zM6EIZ0SFNa3wPxuUNjdrxiGV8PDC!~Z{<~f?Z3^`E2ZK9yuOFg{VBqDot*8fM%V5xUAPKX1pCOc)Zd$TdAODw5W=L|VamVfp_{A6@uh4dBcn;dlu1CFgVKJ{S89W^qphGrEANSR$n zd~NSaFy^N#Cu49SWU2PQm7`zh?#;M(4%H?toFUdT zQ0ybceY=w^UQC@ixc@FKiotmh-~mL#5zS@;lp-2P^qvBH_tmwb?xjrt&P#u=La~Fh zh!D}6F@=sz5+@6*l(&;UP`8G@Hyk;B&IvX<>H0t8fuN@F1RwMc-_VRtZD1rf50tLg zJX9AhKg-)MwQn31J@cFQYUCDaV3ix^p?rqYQ&*iuWdm{6-OYK^-DGWo2^j7v%E(7E zoa01MW5?UuzctLlbdAK0#gj@8iA;C=lt5KaVWMQ(UxYKhDCGQ zNqtJQC8a5u)6&bQdRSM$bsU5?h7ZBO02^(``$D0+T)*-Z)a1BfSY^`bxTj=Q%Jp1Cq20^wVKSE0~Y`ZYnH@GQ{*j!K@aK@^HyrF<`x>+;LC33Z2&079Wt=F0)g!re17VL!RE>Rm z(RBjk06~Z?I+R);y1RvigrcEH1ToPk1djvCP{^+Wg1GX+jb(4g2B`&6AL5ClU!(9G zDok@L8K8i36S1{Eu0#8ZCG}EWXIkTtv;+*whSElhG*42q_0EGmzkFz~9Vk~Lc$Llu zc8cT?GM)xv5aLROs6P>+<3ytd7z9gfRf2J6pWpZr0q!z#39zh@I)5q955n*oS-Yn2 z_yczNhh)BF;<)0j5}X&q!YRT-$MLlC+a!5W3_|Nb6Cg<6G9i8IPn2ZB(qh}TqJ>i} z-jYInH;MkZ&l-g^<)Ro!J2Dylc7MAH2qqlI_vC2gAk!6~gm17Lf%qOU^TZ`!F3&sZ zgkUgGcny6cPlRdUhp%0`dn@G*cN&qZN0&djN69b<>Esiwj_cIC)^=XbCbmI;+q+fk zOLYt1=qP?IZuZi*zn|5z1s^s5GdozNY^;i|%Nw2Qmn-KIs{c*iOFMUmGmMka=tX+l zq&NZ$FSN(WY0Cay_a0pL*-u<;beNLFEtUYb9J0Jk^bzcdvO8^i$2C>V=r^XD${SRm 
zrw89hBv!^n=PgjX5k#NOHuyf|OuaxwLjfP~9iS2&@S3$@mq0B)4X)HbfV=nmazQkg zm`TWtOoRx5`we}^l_g>_(1Do{wqjK&ep{^ z_Mt4MN)2u;QiePq(xuZdD+k=wreLZ|E-)jIbQ2^FFs$L$=Fh1PvGS+zm&ha=xS5@o zA+I5e7-7$4u9bIP*t&0SX?ugljQG`Gxm{=a{guYwR4I3&%W8 zTk_q__a$r?QSf8j00E&ns(24+>Vx|yVrKY2$e`D6NM#O6%zjXh0DHS;`&+`r*KVm{ zoKEAk^F7~oj`Y$M&y*~I!T^<`#2^og0GXo!w61n&q#}d60%4OY25(EFAbmxMnvl8u zf>O30=4dp^E6)v^5aF_$9^&rD1kai0s;nyD*;+>f$|_;CWq4gO3^Zx+0$Mb#)fMp7RW z6ekhIFfBB6hwn9UHMjW@Dyu*Ke?Gp%Ek34;LT}||554DHtSdkpyW*AuKAvIFrpPBT3!Ayt)J39~sxr$AIu`0M7nn9urmN z!A<7FGG<}7B{b#mlZ#@SI<1hP{LYuAkSDf`Ko0q@4;OXreI%bjWx)mu0^3LJ!<=2t z?|#x1omX|A_wPciX@UY1`O$_SN29#wpwR{C6CzCq&;yaClhO11{3}gI${3510skbW z*Wuo&`}=>iK76Ivbf0(){F9DnT{)K$SS9kI@A)hX3@42D+$Iu2m%-mtiq|0IW* zEirGh+*2?bXElo)xRG1xNhc9S;q9G8A#gn02q*W&#Ux~6Z&l8bcA=)L9Je^gK=m>s z{m+3f6iTW8LzQ03`L=A%Px_?Av`4nGU&r@K5L6npnftpT0TKnj32!!xO3$G(0^w_4sy;&3q#YS~1hxRyY3a zobOO)!7i`rLZcdZuq^78BHHAe*nn6y)bQ<|ct08uH_3itpm3M|@ zUus@^d0cM}b@AqUnB5=$+ehMTkL&3x-K2Atm|aIIesmWs41imAzH(c*)4( zO~X#)uItzAGv@k7gd&x4j7qUX;DEZiOCc7zHjFeD`IyLZvM*enF@Fx^+Oc<<60qDx z;!R<&v}+qQAiSHM`Cl_D4Dkzh1lL7@aIOorg8Ryc6#V6(_2XauPqn?s-9PikH}hR2 zHGTSxjN7Ald^E;`k~XUwt*h*N_E}83x2`Hp6kM#z9oam8{I`go_T2{z8?TigvBvOx zX_mi+Id-3^z)m|Z-1$Ex`J3vNrBN0aIeEE4r~< z7{GMld}b3j{`Wu#)#3*+(b0^otQPJytq!6)1`fBMdF>ii(a0+r(O1UNgA&oBPZm#=it*LU5*BEt9!DtEZi|96 zt%T5%CNOG|0%k?}k_(J0{cXW21d%|t-DT#Se3`r#CY7bBxaK6JAUGcq-7&a-M839_ z#18Lie4jp;(#oYvfy9PHQgdS$QUlGD|6c_>h$16MA>Fq(56iQ)DNd`?;#<@SaRD0a>;A+(M7(t38#Jp7_}_#99OG-Q?d4U-4&CrRO~>pS zPE63GG7hv>74I;a#|3j?l&Xs>_iP)~k{GnN$Dq-Gm>eU%yKqthlv`tG4dfP(OMQpI zm$Pq1Jqm!23BtijsG;7tkF)@j{d;4KQlwEj+0})C$JR4)wAz>PqPsU=cV!a~<|IuL zuCQdjDAZc^bnCEk`+51Yv;Kg$laU{pL4CU^CBNHAaMg;Q#cJU!&G-|w8Z=JJnlUW! z5d0}MRQ(a|%8|EP_6!S>c64sOs=On2FggcMPVgozWPyBppcR}AP>`R;#+RyiO`>3? 
z?EmHY`arC_{>6)oMcX1ToL^!3(X|7~t~BK;Qtt-xKdy>fuhWlecZ<>J57eStR3bjc z8qYvNXMo0d8h~ z|G4Cdc_n{U_--+iwwW>PB^nL16j3?qN-pcdxv|)T*xEc>+DuRJhky6d`6Z&Ff--~G zCPUAr8`Pqc@x-q`r8P2|$JG~l8bb&WU~s)+E{J)~Eqz|%Wllt9DMrN4BLNea?a%zY z5qBrrO~q6$%61t(^H`Z%C>?)Pk6GR78jPfaVZ5zHAXa#a&#=hk9w?mBCwTew>>M!H z-Wbkq@@)FvpaCFKT>2Su89-B8fkUN0BCG1ol{m{5651{jq;i03+V`RmwIKD^Ymj*PF81ggm z*L4ZN@LZNLbnENG;Kyun9<NA?A>uY_=!E^{-y`HMYxOsnen39RVDxVP9xIMPhVj z8eNoJy`Fay2)hc$xdC;w6{lBa4iL1she5v5@Gu5E>>zCrK zf-9?jj_@}LsgX4q!WCnU<%1dczmO)9TX3wen{0F&;T9iS0k1pW$MM^upP{GY{=-f_|6ZEq|nmYI#TixR7@kAY@$&I2~_8@_RWmw@~&IIENYP?DNQMVx5H6UPC9JNYU_e0 z|FjjzJIN#vhdcyGy+HUw0`EW_$wr|xOZxDTMGDvLBv5$+A~5=#V! zA-#F|zEa$m=)==2*Vme@o4n0Fuedsf@?i{y$bU(|f6jMGC(HXZz#LL?=}HyeE~S8C zw?+bEnObLH^rW2Q&tS~zxw3-a7rTAb4(@Q+8T*HM+)~cWB%+#fdfQZ z_WBn-Z4H5lx&kwyzLR%vg&=hj#c%nW$V%KD=7zeQ!f9_=?JM;CFJsinl7H%)YA zuPAceVzb)5aJ&XOVaGP zSz<}u3S>y|CUZo)*qzH|rFA7bRXm&wuZwnNzi?@;p>^}7a79n)`ZwHbyLTY&woU8* z)5}d==WdL2;GTd^pjn+J0UA8*enTubbr#3#6p*3KPh7I6Kr&tkxunO|t%YaBpl`fX zIsK2$?8~&FD%nwBK0ek8{;_%v-zIwI!&p1fe)VKSBYw;=wluP* zxjf&p$#u;3D(bM&A2f8GXEsUGxfW~9_K+Z^Gu>y%Flb?vXwv!9g@5+AxywRVtk?3W z%ux~g?ohH~jhGK%yX9{I`;ix5FlU9n9L(Wix$FUU42KV@nb}Eei~8SxCAnS@Q^X~B zER+6=3g?X2r?g+!lkoK6Q@CsvHWUDmv#GzROBn>DQ|4vcTrW>uhBDR}#@E6y&1k3` zKuEIJeqCW1nB~Y|DTI}FE=TD^_grgt8l=$l-nWw=>+qGLjcVUV#!ip-Q-`xaGcUu<@cEZ*QjDB+Q+XP2AzZIaF)}-{OQE_m}$No)4&% zbb;8f2-Ju6B?S6`G3RMu820PeuPzum&7&4n0iDMDf|8F{6{QOF?SJG;vP%&w1cp@~ zHKSQP+jPIgUk%R?I7%gF3aAQ2fZs~&PpOW|+htTwFNZ~R6}HJhkmt~PW{}8dt|$Fx z@^V($Et_K3>>)}1)C5q7t`tgS^eRkct{I;Jhnf1_g_3%x>j6n%3P;8j zK)+@Ua9-Kay>0z#p3!mu$&!JdfXaqR)`VWC8_f+1r@Rz_&tQUm-Z|U2IbxSa9a^sG zvlN77(j?o+%pCKfhRmGiUD+Sp7JfKy|9A7w=j!An9w_r2{De@YZ$ za{K68n!VAgST=}>4l=1W>bJnbWO(VO2hbWLok1d62o?}9Axqhd7^}QWe(!1*J$jIM ze02|1$lQo#<5=M$GA?k-n22E(fL0Zw6K7W}@sj??nONQRX77e{0Z$RDRR1mteU9d5 z@I^LNesD+0l23K%MUSR0mQ0!r3q@;T3}nqa@wg+hP1(y0N}IllO2nL9q7ywB9ib99 zSdnq6@axUmxwN=huArnADHMw@9nQJFk$8i!6FhPB5^Oh0a6xc|>C-TKU9uCXIGrFA 
zE5Vr;Zga8I1}DVxD7DScEssxaqVB@M>c@OVef)BSFcKIBE?HiSd&~QI_Irax3grTf z(4D9`@)@;YRV*C4XISCk(~-GL`Mi$rP%+cGpeXX4p0;Idi${q4W8Cv+zg^T`k-y#} z1jq4|ev#4%N5X)<0@KEsVQ%g-Wy4c`c9N@16&yHE@xN%hSu*(6K%pBiaW^SRqOvUI zUzRm0x@xaR+=t!Y$mC2^Ucb22r%sTbxTdR;mlwP6!?R1TF)!*8qmUwzs;f+8uakgr zzx2HF<)1gp-uww=2Hdk1W*X+t0drC6y-`Hhkecu9gyX~pY(GAhC5 zZCueJIB`8m`#4Dc##{NfT*|tQ?2wbg1}WRB$Ks$<#z0(o*1>K9W8=PV%Y6T|8uV?M zQa3MR&HgfL1augrk-{T;=uf_M^)psNt)UTJE$#DXx^^AdTMOj9J*iXSfnJ-k+kg0R z{q)-1#0K?=lgH8VB>5ck)YRX^-i{+oELRv(K8$&eSvJD)(#* zFiR9H-_hrfz`giiE=e}LZh~$)OOBdH-6r2ItGhmhbKD&|DGmf2R$c81e7|krw1T^U z|Fa{h>FJjoq{YO=K@3~r&?&)cP}6Yik!vG;_-#$u-mksV{wExZf`|CCS{}su6V*jL zX=;?C^@$s!qrh6HY{|A|^a%3npzwcCPUWJ0FVXhgT0IfIJOXVdv@VzolIkT)PHsUr()9Osa%!n8s5dwV-!6oHM3RpgNyTaZc#JX#kHcS2hKG1HZ0{Gj#@;G%?!wcb zqg8kRcFJH%9!EHW^+M1%0j^yjn}+y>$j5c`^n5^q9jN+d;s=p2QHF`Jlp$`iziW+q zhFO|8fM{d&`(*;|c^VCnde*OVHO(*`<-@q!IXc zt+y?9!ftyY)j^>6Mdz4%!X9Yz?iH>Q$+)kICy0N}^MoeM$ZoXBI?aLUkgC&MMMcNcvltnU@;^S)Cm8k zuP-z-{SBL8f}!u1F-qVW{w66xO>iw2aT=8{?^f|}|-_MZvgm6KNR<$dm5 z=Il)6KmMtf>vK-4PpP;{W;%Ue!N7_E$uUZuoe>MO3-x!J-nP9UGarQHfkIMHH_;AU zyb3+oV0JsG;L~J_nR2ovNOdF5%Jv~?)5D_LNuu|;`Dgb#{XZNMeQXMKJP@#>>m4Qk z#DXjpR9pu7N%^h0N37HvdzKcO)X}{<@`3rYT3;u!>xgE!(WCZFWIJKc?(nMd_8gu4 z!m|2P*=)L#iHoY8$9E}P1UX&y=B^XiB9BgRLJ!&OoN5vjRi{*}MC z=h@Plak$@QCC>W@5dHf!!+!agK^;Z^sLBAenJlj>H_$%5J+ zsr(FnDDyw_Cmj#YPpVBScUbpZA$_KS)8^zE7=2l-GsN=Iw%XRAcVL}qC|j;|N=@_3kW zUAyV?tAOy1Uth;^cgYXX_ek=CMaVat?G#L?%?}*Uq`o-=??hAoy^n6nmFNC~Yyp&M z$($_oxonst5Od)2x+VYG+SqqAx0uk3Zzt=G&8yT$iH1bMfsww+LXj$0<Zz5?TuVYWEPO%6y(s4wa zBOjxc-yE#iTb7dKdsWfaStgor!L+n<-2G9xOkq4JE>UzvdT!)iovD$_cxH#$P)K-( zH<4xc<6Scz*lLkKamAl+wC%lX$_mdR-vg`^oAsK6mH2;n{z+IVXb|mj`5pE%3{3x_ z`yqhj*l~PE200y8Vt?=8JiniQlDkaI@CEbOslWVdo-+76{j<;3v(n2)l*wvtDENdrF@{ULU8L8a)cim7SkD`@iLH9dW$Dx8z{DQZE)H2gR<(PApg<*Jx!0e#}jM2RGPeSHxx z&T;FUIkvslD96Q|I67s1y#_9BafBsFBJAujOG?e;L*7U}=0jV?R(9l557zOGAJ2(> zH+|b{7mRU{zk8rW6hud=RvX`uOnQD$Yu_|jIxz}LX>0`A3o@S!pM-}5ddYiVX&ipZV8@pq9#Ak;Dx^Q~b@8`8YrQt6UD=*x! 
zXe+L@UHzzP-JLW`Bt=I&A4_z4S}wu+v>68z3QLb`nQQcgAAPRMWa8~k6#gF5L5_Sm z_L2yUjJo+3o(Ti56T6PT>GIvlFb#q zVzPQ_o=HpZnfUfr<9u1_3v=W_$Re+bu z(j`t%%s}Z*g!<6gi8wi~?#$*C0KxIn#Ec0u&6$R?&vcy5GhG_}^6iMVbpsZbqJ{mb z=E`W~s)^?OuKSUC%jXBbHEt=eGt3PT zV(X8ViG{c73#qGw@g$_`|GnPkXT>-Fyd;eIwUfDyW$H%lM=LU|=P%-JuwaMEZS}yD zQ>VYHtIO`%wFG#S7fj3kKJUN3+1#kT5`J06#P*KWtot4(S3LY9x8a3dZJBGKOJ18= z`r`)=(6X{G9r{0g`t%Drz)X+9a2b62-mI!gNYJA0Wn?4(x{&myG~Kv=hr`jcEIx(J z>Yp1r(T2&=IX`E6wz=`agRk??L*&Z@@cEAbsrXW23uJSWsYOs{#DE`qpXz~T)& z@0zs~HX!_>Q28&&7@W!l25y+j`(DK;C?J0es>Rz44WZ6XNhdZraaS+qVIC^JAsDjU zkvMvR-7y`1CNM9l`uulHe%Gf@_aUiu57O6z9|Nc05C;Ig#Sg?Q{U1qHzvhz6l8$3w|r5r{Q|l7#>MnF{ae4hL7{!&KGO?46uQ zR8&+*$jH6{>HX5UR8a_Wilt9hw{+ZwgUIlRywV%whQZAOu|u0&ScnHV6nvHoD4Hl} zXmr5g+gG=)0VH@Sm=%_h66~(CoRoD%_`0?)AzJqr(mXm8K!FM;xej-qmUPEB4Fae^!s|M_jf!PWJf+{>fq zP17+g&T;JpHYTInH4OC)Z!(uSY{Vj}s&--*L7r~_^wKg43LN^zLD@n!$6n?~{&{^H z4w?J?c}5vqTCXT(3FU1=B&DOsyL76_zj^p@#m=2OqY@INy$THbJJZx-{qkYVRUjDJ(=yS|6WCb!PzQwSh~kg+;eNdN6A{(%JD^xe3af??c&G) zYFz|^{#tGs3I+Ph)rJF!V#dUy$r}Zva!Ct!&m~w+rHB0fMT9IZ37hQhI-`jkgHDYj zwtHtjtvL13mZ*nPR0DQE3RN8R!}oQqAKR~Azkbp;^F`qe6e;4%E8`7(_-xkW$oMtt zdqSQWc~3ib>}c{fmCT5YjC2VWYQkCf`trfzy0Q;Ld3m||00iZ8(3A*>i!)%VUpaI| zH+N8wE%!{w`JHp~nu)4r?#bovKHS{g4KH6>$OpA2iV;2GdS2cbM5>r>OeI`I{)8^B z;q~hcvUC=vf}L1TX_r~f|a+}J2( zw;+h8UvNi5#@fcFt)U^5OgTd(wZ3qxfI!-#&(J-VJ(}@&Q=LRHIC=7<$c&(qp4Em8 z8+25jjgF2+erp$(?xj&fsK|7&gm6g`OK4!Rq9%70J7Yh%K!)f2b=JoHBX5t~pGc0U zI7aYgR&?qHE8*Yn5&x1=b(+JIw*O21{kH_`sloowY@E+({h#&8pQN9+^8YnA|IZ&a z=h?u|-$i`vv&knQnUUK`@0C*dz0?SajDG?N^Q6xF*|joeY|Sk#>OvPyPVK;-)omeo zh`VP;yyOQ{R1D`$LrugBXSd;h3y8xv1jcaRVAY==v*~ZA4|~v_h{PGx)7J;+qD2&a z)ZCnVqoVZaJ!!l8^KkxDOE+j37~loN#e@0j3rR_lV#Zl%;B;l=<*A5s>tQjwDgd7n z5$7Mjd^$9MD!ckEW{zC??()wL*2MxY$sg}u4+9}8lqMIiUd^5N2lhJ$v-b=V>oBnf zUOWZ^4*P(fkohsH^9TTHZfV}Bz`BYRcxpe!DZITa52VPp*ls1T^}H)zE9aL1ffjCf zIbNo8LeD9kx*OV0S~aN0x%5wYdK&CvqAbIKGGkKGvQ!|Nst&uc-C{{P4&PxBOi$t8 zd6Js{0}r?0yVpNId|19b&x(6M*wAtE!OVV+|e$QP&k4$99!!?cQJ7z30X 
zOq~7Y?pBlrWKxNgV|WZj;_(tdUZrtPi~%;})lyCj)iWWtOaR7Yy&4#3LJMH@Gj12h zwIe4^+})wCungC#u_uzGROD~v!0#AnDUPg@suYTml@#Y?m)G8$(cZoef6N;{%79bm zKHA@u$A>eu8ljN~zQxqw;CdH-Oc((5nyFrOM4{yLuipn}W(6|bIe5i#y{how(zmM* z^ivp%l5H{9NAQ(xNht?}C}OOedP;FB3Im7|6c%22=FAzTkY6gi3T_c|bDrP{SEIF^ zH~RrcGT-5!XUwNjHG_L!qd*|E6@-B%2??u-_)&EbmBIP*=ks<2;EA|#|I35Xiqho& z=Y4T9CFccS{68S;|3I(*^M?G_UzKh7zrI}mD?gfk=nap3fl3s;pP%33XU~de3rJb% zA#Y=2Gch^2m)uP=?f^-gym#;3{kf-l_V)G^{u{|3pFQF*ENFS$*0vWf5c_HJc9HYL zy#f*6;n*?86)VE0o-#NUC8796P~4YmVuoMkq?dO4_6m%@B**_?6%vvhi!Ro=jXHN> zTp#DGslA;Z^Z7bAld>!HU#Ot1fza5~^y4rg(^F>>0<>BX_U%5>8|3{f*cG-)>WYhZ?3+}vx<)H zVNVYm9uode3vBE1-2n6l)J;BRuO3-le#z8y^Ei@bXMx52^K}J3k2dH_2$*Fr zz0cyfaLZEc+8el3BEQ<7I1!J?7Tb;I^>%hPzO=N|sL!xle;$=|G0ei;kA5q&KYH{W z-Vch=%If<@VqxLo4o8ld9k0H9do$+v@;}ccnebryh4M-|x?@2B9^!YS91B>NQ3ttn z-&uqAv`tcyFg?RrYNq+SLw*ICTUazcdbD!y-n|(y?3qT<_|goj9I&&G0SM42je{L; z1}ZTFQT800njO2b;YKC+r@2>6%y%>H;mvXcf@CkxqxD1O z>C%(Mw3dmftJ2HqIt;6Nmk(-dH(~=8zUI0rYa}oKw;0Z#gqD`dKRzr+0B5P5i;L*_ zxHw3?X(*K#!U7qEueTbLD@)(LMf*GC1h=B9B9vDSc5X#~+X;Q+VDJxjZeCt8cQMhE zov`0;)vrDD8&yuP#~VIpEv1QbkKcC&^9zk#T`UCmfEjb=A|6>+r;3tm+W?SkMOU$tCU8GoTG(L= zklPz{1En{7RN=8{m9LL0v8(*Vn;5`_Mx-i|d|{>*hGfpLaD7k6#&6%gnU}h-;I?&Q zVhXl{t1$Jcj;z??CswGpG8AL24Ja6Kd0e1QARA4(KEZ}o1ODlRvcP8-l9xB}F3y>C z^udE^ZrQ%{xz7%Z=@{_{oA58fDPPmp_T~jCa1hw-1P+ewWYvZfCr-Ko5Mm5&)c{5!dteCpg7L>22KQ6I3fXs>4hto zE`?#ifCId?$rWp}wj!D!MaUgf}_xGzEJXlnH|C8l*1qD+p!4~ne z5*g-k$(Ju1k7UoqX;>_9)?GDHJm?3pt^-8J#HbPV#QKdJ<1rSj=6#=!lXJ_R<)zVt zevH+PKf5J;N7|;E9U!0;kt=p$!bLAtehq>&r+6s9!&w}~bU<^ZQ?naLkxc*t3d<&C z!g#@Kb_8r=Ge{e(ag2w#PU0~;hN5YVB|XG39v&N81^lSu_vy}~Z|BzW$DsUi09B*( z*ZA5fdewk0RD>5i=}x$3V4gqf3Nf>V4>uDo4+`R?rKJX(LjoeG)RpWEs}S3ylsAGO zj*5$GMv$U%sYWg%sNq&HW5O#smP=Y_+qSjvmF^ZT5(nB?!qSd=$NHo;>^!*=(*)%2 z5&aILG1(Z_k9%9(-X#6tCo6W@HRA)r+HEcNFw)=;dnOc%rksAt(Kp1peNX(lHyqcC zuy{CY&4`Hiuu=l5Fb%C|fB)Uw6v^oI?UOTM2v9AlJ72zWgCEAKD=7G!8c5+d_;eJq z6XTDh>;BPEh&kxjtl2v=P^R(f7lLSn&W2N`PGQFBp1Fl-N*-&;or7rPv{!v~*45WX 
zX&yPZj3K_T`1p0?>VQ|moQNYnBs@`KTZ(2S^Y!4|x)nRo!to`D6)@rBaig*3Ur9XabYlody7wm$FUjFHrAB;>8(AR@Qy1_jxmtOiY-*b zXkk-AVKN3kV_y_=I6ZqjfBp2s__?)Gyz@+~GX6!e&q@;P47e3;wYPF}SE+7f*raJX z2!I#~DW?-YoqFaT3YM3w1$y}Uap>H+b996v3<(K&T^Dd@bVjZv&%$@xg_3olB@sM{ z4LxP67F@d!P#oOcVlXN?3gT}+zBCIP+X~7ORO#tOMRrG|2fBIvV|MQtpO}c!h~-?v z3PeEUxxL`R?`_fu_)N%6gdVnfU13%% zFQ8HNUkugT5`O>w{RvBB8|Its#~)#tg9)a+30)$bo24v+ulH9BRPLQ1h9w%Jh1a9KV$iy!Rd}6McmbR38&Y0i&Xb z-8tAcw#=crfMkjRI2_fGU)R^JT|0qAIjr5cRQ%LY)5C~GKyQT_mSkmoPDxE=c%Y3r zE1EQT-K?1eQ)yCUt7=&vST(R*<2AaPCJnFIVF0#++FLHK`G%!BB9ZEXKUc%>>adH8 z);Z=)TZA%jJ`Au&MmyJpVVX=zU)9?WcD!b87r@Ci7=R z361c>sE)?Ts_ST6Na-!i&l9F&%@F^F4LP9pGG;!*l?%%ev0dutkNn_i1Txl(U*f$? zwYCbG^lYu$@C{4$RQ*n$j{0>s>)@>e8#itgky=1;%SIVCzoq0AMlryFlsRZ=$*Y}d zG*Ihzf50aveg@$t!W9W^WhfGW&g(BovJ(I#BEdT1PBBKwM}5GLYlDHoTG1_Ci>T;I z%FDMP$Iz9Gk!xl80JnmJ*g4(B$kZOPh!ATBq%5- z&+fr)6zmajAKC|cEen`8*}UZn=Pe#54&x8CM&i+*dK7-Vo>XVQt;rs4wA-J0s@k=@ zUcIsd6a(TYkF%Enq1)7(nhMdup=zI4TxSR?8?i5%Z(^z(pAVuig04Vmo(Q^e4bXC#w8qsjG47u2XmC_$d~a zmfiRV#dhqt0t%l+G>m#^BgWn?E%N`9PC& zUb5I+eb1YQkVPp9Z)*K6qps3Hug)9vuaJ;3j6-SCa7~SoTggvFG0g5HwCiUw^$pG( z`JF!lK!$;yUI6_!hv4U!75-Se4-kNzpI;yGHWP6!amgwt-WfY+Cm?iB$LHz70ch** zPuaL9c$;O>6z>Nl)jPLu@52X5!%H1No3HJ#TvuHo*fA`F>B1`%lah*$-?ZldSTFJN zdRQj7p(Ma5)%Q-YOer7P{dLHx;1WT?IGy`3fB$xLG#o)89X%Ksv{(M#oZJ!l5tmKt z;K5dOIWOM484rgjye8AqK~(oJ+1Nnw1hfEMOv@0N(fUw57$4g|=FzM@+bfqqMK zg7|DK>(_^#SwS~~t~dPL*1;E+^bcGY#fQ4vAE8hckdj)5vVFa~qO>$KA`^un*zP%H zKwRUQq`q~1APzQV8BAQ2-#^ZN@){=VKtAgl;Yp=BO-Mb+3le;psE4* zjwE0mWgO|2At4Db7Qj~UA^Cu))12@F8A}-@Fgsn#ugs1uJv0?ZQ6EjaM(@soHE#aNX_?hmP!(X|K-3pkVU%ES-#JrYbT*^eb*E{T2N5TfyG_( zgPDT-yPqseqHI6%q=9@IV34v94{_2C#Tpy>v^34Zi<>pcUw0sb#N#$~{t9GJ2}xoF z1>JKm1lquFW67ja29a?Fqxow9UD;-f8Ggs*sRS;qCw-i2*Z9=c)!S<;{4GFexz(cE zb|oe2_U7R4rKPP(F@svzsXCnbDKkHxobX{(?@g$*iec3fIN!__6%)f+9e8E`bCJ&| z+$4NA=kfL{12fjw*Q2TAS3KtK?(T(q6!%H|q82q_Ev;x@r8I}Ly?rR``h~#e!i`T& zON$;(f6^+WHxYW^#Ks5)Z(DhRxE5uVVzl$cw!_jlhB6dyutrm7f7ZwLV-r0~I}QszU4 
z4y8NR;BSzYML}Vqom)U8-&&J|MnxDO_Lm}4k|TzML>W7pZC2MLIOBPi3G5U*lQM~)so%ziDaT1)I;^D$RvNARtl zKfbub+(T=o1o$1a@3|a~;)yC}+#7UnM^PFzX?S!oen$Tq-;ER5llOg%D2i7&;WD9E zX%7C41e{zYJ%tu}(tzVyiVu3kvT;$wIT556JKznk!s(n~zUPs}Q(?+@J~@y4zb}db z&#a}Rqto{dmo?)N{uFjo+u6^Hz0xke;_3^gRp+a(>`u7jsILl01DmzA@a{Z=Uc8{R z^ubCguT!V+w%7xMf+R{$aQHCo%2fDr2bT|72P9))u}_j|o86@n-t=~zzB@FTwR!hh zde9m~2qK~=fND>ua!&o45UMa5P$-1v$8TTo_UU^t=*TdTGYlPkXSZF zz6ZOpdF@kI2Zu8Ry8(%NdU`g%MC{`A>)W>FPhI+jHBb(yh|KTVpP6$Rgv4MS&gQV; za7m)+7D7A$bezl06(fBNunj0W)?r6Mii)2cait}B4|=ahh(QQMZVI1LP^!7-f4qj+Ki`OX9hgO`F3MuP-e zgg#P1;)d0Jxb26OtAX@LY{tRAGKM#&rlzWIxhx2@CnzY0dT=XR4PImI9E2|gmB(gz z?9H1u;GE8ZP_qXM13*cX6J2S0S2$1{TFezLXvzDltZNCP+@5WS9i`|5uUVaT$biQA zwUsX1dBIUnfE*_g6Zj*>wjE*?3Ae{$dN%tK%1d3kXo0#V4(-EjqQePZzka=dUTjp< zdR}qBey@*%syOod+>1m__qhxwvv8P|$|^zSfytTXaa5rMT2hB}V;&^Ej6$&*1@Rg^ z(`JN2+!B`0S-d+>3S$jx6mDo^bF)V8v78*lUUU{6f#sXZrVs8}7EWz#Xi%FO3a_5z z*x54mf-&)zVh(TZZXqFcmz*{LUugU!{d@0aw7*)~#KIz&%E`h)16C8cB?I>l=onl7 z?##K|R3?E{JzWj@7?8$ylUkOH>U!yo&zoLMRpOq4neyY-RsgO7E%BsFfSrYQ4{Ws# zF`ZU~muX39sc@wFUT}4ZNl7A4(B5>#gEx5)9Ms*{w^c$y;uDw^kkSkbco^YG=+OHd z5x(4(gDWPF#g-2CEeN$e&@^gGB^%>MU4dqsR|nRH{N9a{BM#L%GG`n(;yb?gK3%x5 z9^@ql4>z{~uyW)XUfoZi2ficEwDtBTqSf?mX{i%Zx|qKe5j<$saOA|UTuDr%LLfkj z;EN{WB1exNWmvoR4)jJ{A9JFS2@R0dA49f+->k7IaPCZtpQ$4KCtUMf^d2)@H5NwS zzE#7s&qmdD(OPONumAwS4Huq7IrE(tedA?+ga9aKgC9gMaXnt;1xW9Lu;w+Rv16aQ z`s1#yi%7nx!n#12#_}h_2<`$@-rB;fq#Di$+pEhZX-P%?zEf0m5Osc2L;qB~JJj4X zX(&}p27Up`B^%5+?(Nd_3C>?Va_Fnr2kaca10vxvx=1JzyHJxz^B%S-3lW*owQv?< z>B5gh{?+^X^>ySYO++h>?q?umkf=pgtXSdpqTs4bj{ESCyu5rOG&|td3~@2hJJWCe zGKP}77k)>V+72IfbW{gaVP$%1yL8<|%z6tei%ostRkF7kycsBnTog*|TwK}n^)6Q0 ztp-W&fi~R$6fDfmFW^Z7tfeaZ=ys=~>SX1j zlDllp6IW!-tE3El3%+dsbqh=Z>Z03=i(!YNA#l5pnJGf_MZ~unR0}$<*}gCFtu9i9 z(t*Q3(_NgqwB9EidZu?ZE-!85XKnF0aU9~X3W4x5V6BL*stYk#_jPh-rvVOUwgh)x0Wf!rUh(0{& zS&D2p8cY-D>)rtsW`G`=cDeyOUqLNef{Vy%hM+(HhJ~58x3{RWvgJ(h;DeA`xI#JL z8N8wA!$agGJQ*aKIyf3sDmyznIRosRocl4g#$1Q?)Exs+dESCr)W47`ijYlNZ6I~9 
zcTgZzde5F)V6bs))fO{0`#=IjrMTes;B9E)<7ea7$E+TWM=)B=eDJ^l&edJBK$7H_ zdLu3kf9-w8ovff1h-MG?6T~+b)xaeKZkY!Y3cn)Ou;rk`Fht74mA)Juy#`3l(9`qYK*e|v zeC&Fm6uqRA`EFuD)I=z;^j1<=RIXd-GIlVujpLqg9BVd^fL;v63Rd> zyj)mV*txzIf9}Ykgbw*(S66LH#1K;A_IHz$+D5r1y>Crk$i8@|9|@SW1rqJD z2W*D(0Toa!eTR@LvgITo0Z6%m3%}$z3^NsezTZuaJ7t|8hJvIy{n-aXDAOm~1SFt@Ub9b`9tS$RA! zxX3B+NI`YoqF>VQwNZbqN<0{oJ!wuW5oa_t47ZeMEUNEk6a(`K?A*x% z_(C*0l09ax-^>0{*nR1)|M{t7>M1QmNt8~76|XPp9+3W&<nKR%vU5_6A$mi&p;y%t~eQ2?0?;dx2-T)Z}X9G>G>Fv&T(xiSQCF>*uNG5Z+R*UR^U}7pYPgoj3h)4<0m_6|jVPpWGpJ z3|M%l5{yD=o&KItFtnuq-~>o&Q8aLE-n=~R{zsx^dKub4 z0jx=e)<#_5ziDj?K4HqEbmBb#XvQ~Eh&6`OJ?!pg1#*QaUL(qZyps?r zCkectccLD=c6R#yQN~*nu07&qa&7=2q+~@7B@3K%bhWfXfjOXX*0Islp7%kd1+$f2 ze*dQ>{@On9%?R7pZ!B`zHTX`N;?NN*lR}T8J27>@4}wZc=F_hNj3-S`t%e@!8uX_C z%m!~>c1EeEhpMONK%}Fn22C*BpalYxJI0#FdC;AQ)YcQ4e@EAl<$Tj5Ku{wD1CYIfp_dcxjl9FlvMmf|8aNq#%r656F8U51^kO@e$Qg({Iq@ zDB0skS&DQ>WWH##IGi|P>kNHk{szCv$+XwjUGnXJYuNfLuXTIg{rw9{#q%Jxgp`!H zpFVv`7CCq38gCWckD6sQU>jSK<`~M~^uIsUo1s#6aB?!no1VFvnre26(^lg3X4Ghl z(2J)5@R2TSX(<(FhiKs`zlZAn0%e0w5rznc0h$Y!FK^yf1c8$Gq!FZ#d$(Lp2Lgfv zw;lCWrT=GfPE*+70xPoJ*CFGQ~##=0EcSo!Wj z$O!0Z!cf8tRCpwywL(8hjldOa930$el*^8qy%AMP|6Q{~RsW6`+zHcCeEf%?qc#OCvOsyXGVN_3 z85f|CFO55l%eKnE^C`Nnv|tv{y}ega(TX;X#w$F&9bkLu(fuU@KZyNF^$+l^z^684 zPsBA4z8Ksnv@;enpU9SEv|gjoKOw?;wk>*OlX5XM!>koo@svp|3r!rZ9#Pa(`{ZA+ zeNlDe)~y7{zQ_Z`kN}L2cxh>Asc%M}mz~}5U;Xnv2=9>Cu(!WN(i%q7K)rfO<#n+W z9WEIW0H9QRi;Ul-;>_Py@(4E%-QHq*u)<^LEM61%&jg~`LuZ$Yz69kh;4W#EU_*8| z)M=K1lHdDNpgsdHBWyjYtfXWQ2@6I8$cPg}DY&q)pXV?-A`i<>Mg%2fN%Wy?D|bk@ zq1izMKtRD*0J3S9?}zo7Feg_=H=L{~sS%FASg-467^$;3pU3awE zxOsRQaW+Bf?}hp!G;HHZR51J8-10i_pZfntML_7e(V_1?$Y=#ELPA92`Yn6zZ#P%; zQe`~fszEuj0SK3#nYk7F3j{28Z`-D)^#kosoF!XW>)2`8p=*Kl9q}FymN`N1rC>)F zY+JXPP*6p!1qY)J-GvwS`w(#tpr$mWbdf%KHsoe!B3q=`!afg88jsg521^pH*kxIi{AVfMWlJ{Ym1Z=>E zD*+Y`Y&T(*a7ZJe77%4;0K~O1ae!~aS@Z~@ZqZ|3iY0hE`1|D64N9IA^sgHM0A-In z;TBxR38tr~C(IWM56>DXZ^5t_R|>wtCJr(pf{-$TLaV>RV-xytLfiJh{(`8iK@>HC 
z>*z#phx~Infj#j0+`68N>m1&^J>ES1x^)DpV$eteuPLaq$_CI=vQ4tm*P$9|j@bh>E_I$sHciaou}4G<9eGoB;Q zu%@=}F42Ai<#M|Jz%r%=OecAwM8H6s6ZB=6)RmAe%42N1v^$PhLHw@}h-9Nt{iY-O zrtJJH%b@Fr90omg-P6sMSA!=antjrEvHV@ubnN7bB8-VNHZ~IXYSLZ9Jz@ulAA5Qf zj`XA-2-PGvvYsZ!dOEBAI!Z>gQnrGE0Tg_QJnsd#4DBT=Nlx9myY8l5G1f@()#G5$ zC~n=_Vbnbj?Ew!|2xLVgl}e53CJj*NA~K+^Bnn-ep<)E*lkQL_Pl>8*F?q54!|i=r zb;tvwmh-3z_d!yNnl38);H?p870+g8Zza?o#)N);4EZhb7k_5+As@WD6@b7<4=l)- z8;OQ=EsxwvT<(@gMlM7El6>g<20_ddhG+NBxqa4ECs$!es0Ak>_Or0$Hqf_-r48ha zqNh)~dZ4Vi4_sHGNlQy7`c(A(8z6xQ#oKVe&=f-MX+=sINLJh&;*SMm)u+biwb6)Qf4uUL%pjoVoJ$Y*L| z;%$CWy*dWyaCorPd}#%$^c55kf2;BDMj~_&oR^yuXGwB(x= zs^c|8{zSa8%n=}6ddgd*IxQU?@!~Fb)H4_lVo#f(2TM#RB3~D^lde1O*f)sSR)#`B+?z$WawSn>)SaGJ zPbtUb97@APtUU)wA)MSXi5S2tsR)LdSIJX=g+QKy~YfQ^3C}g6m&B|iLxv3x2Oj-5sP1eNkfN z6VQ12lnDeXD%E@B(HdGc(63D@fJ(wiI2_%x}VT1V=3Oj~Tsn3%=nlrT;zqja?V4tg;j7%Di`4y#vq9SRO ziF)p-Ti%x!7D^VH7prR>O4^AYRg*kEx#j&MaV~GbBxM4{v$AxdjrjP{qll7{66ykm z)jH8Yiq5gWHFk1mW@hG^Nrd@KvdAH-efQne53$Ob{5_~!yCKlB8Y$7xpi(cO&FK<= z=2A4Nojd4LM(+I^hQN|Eyko)V(izv*+0PIKI^JAZ#FHSaI zgsvzlG5M;+)Yj2K)oK}U(HomomVWQq6kW66Djl3mS1GtcLnu0;qb8uc*~-t)4@r?}%}*3c zZLeO{JX)5Nh-cxj&Zb+nDta92wNANU`hmLmyN?TU31Oj`H=Ecq+sYOJx-|NXkp~39 zu})5=+T|>xZ3<36EKu5ixgcY7k$Kv}*^9p;2x25WbQhS3~5cwH&o_j)?i zQRqLC4iNTA#>>2EYRY)h+NU!zHYW1Z4tiba#kUnJK;-s+rHA1J->5PD6fA_b_4TjLzsK5UT5oX0C}4I&=WYOL*S0*l`EwK8ns3La!zg23nsq?284Y!3x5s3*(-o zk%G|N+0*k74U%x&YGlvY@~by)FzM=c$Sq`b=NN8I=U}2BRm}s}Q9@soH-?f%VzpPm zzba~^&>~4d!XgOykE|wlB?XGx42VSC`gqvb!ckVBK~&Z29^miK@Vla{Y!yYtf03Sv zSKeS($nb-gF-z|*AB^~o4)#h+_)cK0*vQC;_#*ua016pg1Y-T%J7^B<)I!n+LUkHq z$^=|Y5x`fZefVn^xDIV??Tk-Ym-(b9jnD+IUIWV-vmx$Dq?NI)c zjjO7w-D)>wR-vSWMhhb})@a}9r`SFS4l=YDLZycBLedaYKUQXxOGN@HKHSDiS(PA& zZWISLy-A{0D800&rN6`Y$k9<0)Ze(aS=w@bJK*ORbyfCuuc`c-TAu2&XL;T4f<7bz z5_8M{y@&r3j6J4g_98)bBJp*l_;461-OTgP<($pEoB9PH1ETu3c%bO}$4of${f{=j zXgqBwy6K(9(^v&f{342vZBHUXA%F;u*AIUmTH>|-Tz9A9o$2`UUW>kN)LPBxMXu-J ziN-73jA9xUBJ5Vgy>7H7Y-+_XAw6bH(4)!9jAVroDT7^E&aP0cME>>_W?e(32`Y=C 
zakepz)%5C>RZ13yt3dX0)F_HiRzDjYthfI0rFu9<)XZsedYWjfpO@H>Ij<4Y_`@C> z&&3JD0=owr$%1w83lT>XbZ`6b`2HV45%DCQNI&#+;pP3wu>%8$BlwuSW2QnE*=wpg zB;^{Mon>TX)G$85!pa(k%1Z$Fj|f(7+V1Vb;xHl<(Vcl!1%^4|cQWVe_RRtUl!G@_ z8gV<`zI{u9Hdb0%dU&1}r4UH%5r69Q_pdfMV^}G)My+bKu=Mw@ zA{VtHbk@7O-|l`)bLY@#dD%N)W`ea75}18wU#~%L2gkgrwKcm}v+%Io)&0}bf#*$= z9P*qBh@m>BrE^(((70UB&tKS8PD9Ui6W$i$r!;y|cMeZalLdqJ5LsnDeRW%tnKV@) z5o|(LOE4wr0)tlXU3x8rly)H3`4DO*oZ;v3x6PKAR5EXzLd_X7t84N81Io;K$k4AS zKLZ>lKO2aR9x31;a~EjUaK1bNVTWr93*C!jOhHZ!cX4AfG6#2(5Gbhs8UQCS{B{n= z`;Sft%}Pv23?&botv+pOd0rnfK;#ALXh5DxsZm-1x1pK8g;#-4sZWRj3pky*&!3y| zs)paai@->Gvho$z|AVf&( zPG`=HbtGdDS8QeOHArI~04V=>Q%zgM)X26)~9#JuiJ*ynNnBm{MxI)qE0n zgF0mD6gG*g16l`F;Tlv2q=u)6?�W`yi&$5)NGmkRkSt67QD6F&_jesP(qV+Vr%0 zahNjZ<(VEZCgUZT8@eKGXItP?DGX^53OzM7H40dPSilxXFnqi>ULoe0l16C|1$pe;eS zuzcnfp8{w{3m7IVNpo^>ki%)qC!mcx@PN*zN8;r9=Y?FN%N_;_r-P`}i1phc+9j)a zVpl&uDl!UgTU;9=6g+=O=@mijZ-yNO>=`mnJOCcx#z+9<&5C z40obNh@?=s+E#tai?68($n+T1zX**0LMIXWR00P>EFoPh+X4y`-5c`smbpk@62RMott^O^$tx6St5o>e_5lKb`&5CT}&~1q>M-j0V!=n!o zC!h|1gg8E@!X7mBTL@G!8Z}6J0LX3$WEBIAhpD5EdeE{%9n@3}A|3n(UYdgbIg55!O4K$Q&8yR;mZH z<9X9P(Nv@HQ7dm{sSw1h`m=VJnsC#Ltgx#rH_K9!1{=)Ark2jjqC(NYlD zC_wpvSilNk(}>b;9%FF%HsyDg@|Ymi(Zs%GVx3Biv!QIR#Lt2xHXE^@B4--%b@J}f z#sz(|f=mHI!oH3|d^qs=&w*uFkMnjGW>0tR9;`!i42>}YG01ol$>C5631)qn?qdee zExtWI_a~|ZJR%V?QUuLB8~JKvE@;QIf$BX5*#*a3DI}AFJphFo0YO4y^=@?Q0c{*H zNrRk*!i-ux*nm+c=-@TY%~NCwqOGOQfpdj!^C0oXK?WjD1D&!`?U0v&T?VVR7t(!u z7Z-YzyIm^ymfV=bz)&Is<|WA(6ka26)#==0^@{Df31-da968b9>9#e#ae;w>kbcKN zCQ$`XCPe*6@8_YXpxzMF)zt-sz|PGr%4>2c*Q6VdBNw$5kxC$-zZ)Ih2kgQ=UWH*H z6tw5io7IAv0fulcSP$Ox_7dEOvQfWXwwB7qkM_iHNW?HLo&NnR2*bC0C}N=K;X{i| zYsP*I5~5y^(xei``R>LsF`V+4`~RdbG3R_P@MI`(f$4d!@j22?M)#i#v^d-(10G0P zZJnJLVBR$W(G@3C1Ue)-sF~Ig4GB6BWhdTP%2OQR3qX_v$nd054HT4%3*ZfvlTi2Z zd!I6X$l+GG4b|7Rys~#bh7&_p*CQB^svPX~BnnzCrpLUjtm=51U<_br6AQ}C?)cmG ze^;W=%>H*JN-1!^kkOtc=YNP%+Urrb8G-yHV`1cVRabXo6M5u9_MAsp*eT|i-H8kvRcJQpbi5)K^sD0C!TA+gLHe-A5VK`4-jI0{Ov 
zf%>2tQQ|B@2oEBq!bf#`T`9}2a@s%(&n+8G9ipT`uO2!omeo#u!tpd|V9`kM(Sa5L zhr$*3A#*l+$1HP9_d82VXN3K1bou*@$91{l6FKo=st!#)VTZHytY zQidxqqp;S`-~W<{cxKacUT1zM{nb%JNWtLrF5;|s78)-qt74}WdGvx`0s;bh@U0@f zeO7$516J`4GG%j>nAMGHmz~<`^jsLx74P2j62$LM!IqAg zNg)XwI%Oum1I!_Wj!x+IJioz_5B}+w;kFV3a}+fMN`?1-T#Ur?DhpzIz-VRX=QC2^Hj+DjU8W@^ zCC}O@$(W;o^7P^a3k4F(eK0Ngi@04DE3&Z&qE(8@E9~mkm%4Va5BYWL z=FN@hVPWnTgSpK7{Ik0nxbuU7z5vzW6w=ufP?aYfd~9RBY!p57rrnRwnD{Sa6MI-7 zK 0.7 +# 256 * 128 * 8192 -> 10 +if __name__ == '__main__': + torch.manual_seed(0) + + # hparams + repeat = 16 + dim=8192 + layers = 4 + + batch_size = 256 * 128 + + # simulate forward pass + x = torch.randn(batch_size, dim, dtype=torch.float16).cuda() + + for _ in range(repeat // 2): + quantize_rowwise_nogroup(x) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_rowwise_nogroup(x) + torch.cuda.synchronize() + end = time.time() + + print(f"time: {(end - start) / repeat * 1000:.3f} ms") + + + + + + \ No newline at end of file