# Please cite arXiv:2001.08361v1 when using this tool.
"""Interactive Tkinter explorer for three isolated power-law loss curves.

The user enters a compute budget, a dataset size and a model size.  For each
input the script evaluates ``(value / base) ** exponent`` and shows:

* the isolated loss for that input next to its entry field,
* a small log-x plot of the curve with the current value marked,
* the largest of the three isolated losses as the "estimated" loss, and the
  input responsible for it as the bottleneck.

NOTE(review): the bases and exponents look like the fitted power laws from
the paper cited above -- confirm against the paper before changing them.
"""
import tkinter as tk
from tkinter import ttk

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

# Constants for base values and exponents (order: compute, dataset, params).
COMPUTE_BASE = 2.3e8
DATASET_BASE = 5.4e13
PARAMETERS_BASE = 8.8e13
BASES = [COMPUTE_BASE, DATASET_BASE, PARAMETERS_BASE]
EXPS = [-0.05, -0.095, -0.076]

# Default input values and labels.
DEFAULTS = ["1e6", "1e9", "1e7"]
LABELS = [
    "Enter your compute (petaflop/days)",
    "Enter your dataset size (tokens)",
    "Enter your model size (number of non-embedding model parameters)",
]
# Short names used in result labels and in the bottleneck message; kept in
# one place so both always agree.
SHORT_LABELS = ["Compute", "Dataset size", "Model size"]

# Default x-axis extent of each mini plot as (log10(min), log10(max)).
DEFAULT_LOG_RANGES = [(5, 9), (8, 15), (6, 14)]

# Mathtext rendering of each curve, drawn in the corner of its mini plot.
EQUATIONS = [
    r"$f(x) = \left(\frac{x}{2.3 \times 10^8}\right)^{-0.05}$",
    r"$g(x) = \left(\frac{x}{5.4 \times 10^{13}}\right)^{-0.095}$",
    r"$j(x) = \left(\frac{x}{8.8 \times 10^{13}}\right)^{-0.076}$",
]

# Apply a dark theme to plots (affects figures created after this call).
plt.style.use('dark_background')


def _parse_positive(text):
    """Convert *text* to a positive, finite float.

    Raises:
        ValueError: if *text* is empty, not a number, non-finite, or <= 0.
            Non-positive values are rejected because the exponents are
            negative and fractional: zero would divide by zero and a negative
            value would produce a complex result.
    """
    value = float(text)
    if not np.isfinite(value) or value <= 0:
        raise ValueError(f"expected a positive finite number, got {text!r}")
    return value


def _power_law_loss(value, base, exp):
    """Return ``(value / base) ** exp``; works on scalars and NumPy arrays."""
    return (value / base) ** exp


def _read_inputs():
    """Return the three entry values as floats, raising ValueError if any is invalid."""
    return [_parse_positive(entry.get()) for entry in entries]


def update_function(entry, result_label, base, exp, label):
    """Calculate and update the isolated loss for a given input.

    Args:
        entry: Entry widget holding the user's value.
        result_label: Label widget that receives the formatted result.
        base: Base constant of the power law.
        exp: Exponent of the power law.
        label: Short name shown in front of the result.

    On invalid input the label shows "Invalid input" and nothing else is
    redrawn; otherwise all graphs and the summary are refreshed.
    """
    try:
        value = _parse_positive(entry.get())
    except ValueError:
        result_label.config(text=f"{label}: Invalid input")
        return

    loss = _power_law_loss(value, base, exp)
    result_label.config(
        text=f"{label}: {int(value):,} → Isolated Loss = {loss:.4f}"
    )
    update_all_graphs()


def update_all_graphs():
    """Refresh all graphs and the estimated loss display."""
    update_main_graph()
    update_estimated_loss()
    update_mini_plots()


def update_main_graph():
    """Adjust the layout of the current pyplot figure, if one exists.

    This script itself only builds embedded ``Figure`` objects, so normally
    there is nothing to do.  The guard avoids ``plt.tight_layout()``
    implicitly creating a brand-new pyplot figure on every update.
    """
    if not plt.get_fignums():
        return
    try:
        plt.tight_layout()
        plt.subplots_adjust(right=0.45)
    except Exception as e:  # best-effort layout; never break the UI update
        print(f"Error updating main graph: {e}")


def get_bottleneck_description(label):
    """Return the short name for a full input label, or "Unknown"."""
    return dict(zip(LABELS, SHORT_LABELS)).get(label, "Unknown")


def update_estimated_loss():
    """Calculate and display the estimated loss and identify the bottleneck.

    The estimated loss is the largest of the three isolated losses; the input
    that produces it is reported as the bottleneck.  If any input is invalid
    the display is left unchanged and the problem is printed.
    """
    try:
        input_values = _read_inputs()
        losses = [
            _power_law_loss(value, base, exp)
            for value, base, exp in zip(input_values, BASES, EXPS)
        ]
        max_loss = max(losses)
        estimated_loss_label.config(
            text=f"Estimated LM Configuration Loss: {max_loss:.4f}"
        )

        # Determine the bottleneck description (first input reaching the max).
        bottleneck_index = losses.index(max_loss)
        bottleneck_label_text = get_bottleneck_description(LABELS[bottleneck_index])
        bottleneck_label.config(
            text=f"Your model bottleneck is: {bottleneck_label_text}"
        )
    except (ValueError, tk.TclError) as e:
        print(f"Error updating estimated loss: {e}")


def update_mini_plots():
    """Update the mini plots for each input parameter.

    Each plot shows the curve over its default range, widened when necessary
    so that two decades on either side of the current value are visible, with
    the current value marked by a red dot.
    """
    try:
        input_values = _read_inputs()
        plot_specs = zip(
            mini_figs, mini_canvases, input_values, BASES, EXPS,
            DEFAULT_LOG_RANGES, EQUATIONS,
        )
        for fig, canvas, value, base, exp, (log_lo, log_hi), equation in plot_specs:
            ax = fig.gca()
            ax.clear()

            x_min = min(10.0 ** log_lo, value * 0.01)
            x_max = max(10.0 ** log_hi, value * 100)
            x_points = np.logspace(np.log10(x_min), np.log10(x_max), 100)
            y_points = _power_law_loss(x_points, base, exp)

            ax.semilogx(x_points, y_points, 'b-', alpha=0.5)
            ax.plot(value, _power_law_loss(value, base, exp), 'ro')
            ax.grid(True, alpha=0.3)

            ticks = [x_min, value, x_max]
            ax.set_xticks(ticks)
            ax.set_xticklabels([f"{tick:.1e}" for tick in ticks], rotation=45)
            ax.text(
                0.95, 0.95, equation,
                transform=ax.transAxes, fontsize=10,
                verticalalignment='top', horizontalalignment='right',
                color='white',
            )

            fig.tight_layout()
            canvas.draw()
    except Exception as e:  # GUI redraw boundary: report, keep the app alive
        print(f"Error updating mini plots: {e}")


# Create the main application window
root = tk.Tk()
root.title("Power Law Functions")
root.geometry("900x600")

# Create a frame for controls with a fixed height
controls_frame = ttk.Frame(root, height=300)
controls_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
controls_frame.pack_propagate(False)

# Initialize frames, figures, and canvases for each input
frames = []
left_frames = []
mini_figs = []
mini_canvases = []
for _ in range(3):
    frame = ttk.Frame(controls_frame)
    frame.pack(fill=tk.X, pady=5)
    frames.append(frame)

    left_frame = ttk.Frame(frame)
    left_frame.pack(side=tk.LEFT, padx=(0, 10))
    left_frames.append(left_frame)

    right_frame = ttk.Frame(frame)
    right_frame.pack(side=tk.RIGHT)

    mini_fig = Figure(figsize=(3.33, 1.5))
    mini_figs.append(mini_fig)

    mini_canvas = FigureCanvasTkAgg(mini_fig, master=right_frame)
    mini_canvas.get_tk_widget().pack(side=tk.RIGHT)
    mini_canvases.append(mini_canvas)

# Setup input sections with labels and entry fields
entries = []
results = []
for left_frame, label_text, default in zip(left_frames, LABELS, DEFAULTS):
    tk.Label(left_frame, text=label_text).pack(anchor='w')

    entry = ttk.Entry(left_frame, width=20)
    entry.pack(anchor='w')
    entry.insert(0, default)
    entries.append(entry)

    result_label = tk.Label(left_frame, text="f(x) = 0")
    result_label.pack(anchor='w')
    results.append(result_label)

# Results section for estimated loss and bottleneck
estimated_loss_label = tk.Label(
    root, text="Estimated LM Configuration Loss: 0", font=('Arial', 20, 'bold')
)
estimated_loss_label.pack(pady=(20, 5))
bottleneck_label = tk.Label(
    root, text="Your model bottleneck is: None", font=('Arial', 20, 'bold')
)
bottleneck_label.pack(pady=(5, 20))

# Bind entry fields to update functions
entry1, entry2, entry3 = entries
result1, result2, result3 = results
for _entry, _result, _base, _exp, _short in zip(
    entries, results, BASES, EXPS, SHORT_LABELS
):
    # Tk rejects an empty event sequence; recompute after every keystroke.
    # The default argument freezes this iteration's values (lambdas in a loop
    # otherwise all see the last iteration's variables).
    _entry.bind(
        '<KeyRelease>',
        lambda _event, _args=(_entry, _result, _base, _exp, _short):
            update_function(*_args),
    )

# Initial update for each function (same labels as the bound callbacks)
for _entry, _result, _base, _exp, _short in zip(
    entries, results, BASES, EXPS, SHORT_LABELS
):
    update_function(_entry, _result, _base, _exp, _short)

# Start the Tkinter main loop
root.mainloop()