Sphere Function with Various Algorithms
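
The evaluation budgets in the docstring below come from multiplying three numbers: how many emitters sample per iteration, the batch size, and the number of iterations. The sketch below recomputes them from CONFIG-style dictionaries. The helper `estimate_evaluations` is hypothetical and not part of the script or of pyribs. It assumes that every active emitter evaluates exactly `batch_size` solutions per iteration. It ignores any solutions evaluated in the `ask_dqd` gradient step of cma_mega, cma_mega_adam, and cma_maega. For a BanditScheduler entry, it counts only the `num_active` emitters.

```python
def estimate_evaluations(config):
    """Estimates total evaluations for a CONFIG-style entry.

    Hypothetical helper. Assumes each active emitter evaluates `batch_size`
    solutions per iteration and ignores extra DQD gradient-step evaluations.
    """
    total_emitters = sum(e["num_emitters"] for e in config["emitters"])
    # Assumption: a BanditScheduler only asks its `num_active` emitters.
    active = config["scheduler"]["kwargs"].get("num_active", total_emitters)
    return active * config["batch_size"] * config["iters"]


# Trimmed copies of two CONFIG entries, keeping only the keys used above.
map_elites = {
    "iters": 4500,
    "batch_size": 37,
    "emitters": [{"num_emitters": 15}],
    "scheduler": {"kwargs": {}},
}
me_map_elites = {
    "iters": 20_000,
    "batch_size": 50,
    "emitters": [{"num_emitters": 12}] * 4,  # Four emitter types, 12 each.
    "scheduler": {"kwargs": {"num_active": 12, "reselect": "terminated"}},
}

print(estimate_evaluations(map_elites))     # 2497500, i.e. ~2.5M
print(estimate_evaluations(me_map_elites))  # 12000000
```

Under these assumptions, me_map_elites evaluates 12 * 50 * 20,000 = 12M solutions, well above the ~2.5M budget of the 15-emitter configurations. Keep this in mind before comparing it with the other algorithms.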

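As a worked check of the objective normalization used by `sphere` below, the standalone NumPy sketch that follows re-derives it for a single solution. The function name `normalized_objective` is illustrative only. It mirrors the arithmetic in `sphere` rather than replacing it. With the shift 0.4 * 5.12 = 2.048, the worst raw value over [-5.12, 5.12]^n is (-5.12 - 2.048)^2 * n. The normalized objective 100 * (worst - raw) / worst is therefore 100 at x_i = 2.048 and 0 at x_i = -5.12. Its gradient is -2 * (x - 2.048) * 100 / worst, which is a positive multiple of the unscaled gradient that `sphere` returns.

```python
import numpy as np

SHIFT = 0.4 * 5.12  # 2.048, the optimum in every coordinate.


def normalized_objective(x):
    """Shifted sphere value of a 1D solution, rescaled so 100 is optimal."""
    best = 0.0
    worst = (-5.12 - SHIFT) ** 2 * x.size  # Raw value when every x_i = -5.12.
    raw = np.sum(np.square(x - SHIFT))
    return (raw - worst) / (best - worst) * 100


dim = 4
worst = (-5.12 - SHIFT) ** 2 * dim
print(normalized_objective(np.full(dim, SHIFT)))  # 100.0 at the optimum
print(normalized_objective(np.full(dim, -5.12)))  # approximately 0

# Central finite differences versus the analytic normalized gradient.
x = np.array([1.0, -3.0, 4.5, 0.0])
eps = 1e-6
numeric = np.array([
    (normalized_objective(x + eps * e) - normalized_objective(x - eps * e))
    / (2 * eps) for e in np.eye(dim)
])
analytic = -2 * (x - SHIFT) * 100 / worst
print(np.allclose(numeric, analytic, atol=1e-6))  # True
```

Because 100 / worst is positive, rescaling does not change the gradient's direction. This is why returning the unscaled -2 * (x - 2.048) still points toward the optimum.
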
"""Runs various QD algorithms on the Sphere function.

Install the following dependencies before running this example:
    pip install ribs[visualize] tqdm fire

The sphere function in this example is adapted from Section 4 of Fontaine 2020
(https://arxiv.org/abs/1912.02400). Namely, each solution value is clipped to
the range [-5.12, 5.12] (in the code, a value x_i outside this range is
replaced by 5.12 / x_i when computing the measures), and the optimum is moved
from [0,..] to [0.4 * 5.12 = 2.048,..]. Furthermore, the objectives are
normalized to the range [0, 100], where 100 is the maximum and corresponds to
0 on the original sphere function.

There are two measures in this example. The first is the sum of the first n/2
clipped values of the solution, and the second is the sum of the last n/2
clipped values of the solution. Having each measure depend equally on several
values in the solution space makes the problem more difficult (refer to
Fontaine 2020 for more info).

The supported algorithms are:
- map_elites: GridArchive with GaussianEmitter.
- line_map_elites: GridArchive with IsoLineEmitter.
- cvt_map_elites: CVTArchive with GaussianEmitter.
- line_cvt_map_elites: CVTArchive with IsoLineEmitter.
- me_map_elites: MAP-Elites with BanditScheduler.
- cma_me_imp: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker.
- cma_me_imp_mu: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker and mu selection rule.
- cma_me_rd: GridArchive with EvolutionStrategyEmitter using
  TwoStageRandomDirectionRanker.
- cma_me_rd_mu: GridArchive with EvolutionStrategyEmitter using
  TwoStageRandomDirectionRanker and mu selection rule.
- cma_me_opt: GridArchive with EvolutionStrategyEmitter using ObjectiveRanker
  with mu selection rule.
- cma_me_mixed: GridArchive with EvolutionStrategyEmitter, where 7 of the 15
  emitters use TwoStageRandomDirectionRanker and the other 8 use
  TwoStageImprovementRanker.
- cma_mega: GridArchive with GradientArborescenceEmitter.
- cma_mega_adam: GridArchive with GradientArborescenceEmitter using the Adam
  optimizer.
- cma_mae: GridArchive (learning_rate = 0.01) with EvolutionStrategyEmitter
  using ImprovementRanker.
- cma_maega: GridArchive (learning_rate = 0.01) with
  GradientArborescenceEmitter using ImprovementRanker.

Unless noted otherwise, the algorithms use 15 emitters, each with a batch size
of 37, and run for 4500 iterations, for a total of 15 * 37 * 4500 ~= 2.5M
evaluations.

Notes:
- cma_mega and cma_mega_adam use only one emitter and run for 10,000
  iterations. This is to be consistent with the paper (Fontaine 2021,
  https://arxiv.org/abs/2106.03894) in which these algorithms were proposed.
- cma_mae and cma_maega run for 10,000 iterations as well.
- me_map_elites runs for 20,000 iterations with a pool of 48 emitters (12 of
  each of four types), of which the BanditScheduler keeps 12 active; each
  emitter has a batch size of 50.
- CVTArchive in this example uses 10,000 cells, as opposed to the 250,000
  cells (500x500) in the GridArchive, so it is not fair to directly compare
  cvt_map_elites and line_cvt_map_elites to the other algorithms.

Outputs are saved in the sphere_output/ directory by default. The archive is
saved as a CSV named {algorithm}_{dim}_archive.csv, while snapshots of the
heatmap are saved as {algorithm}_{dim}_heatmap_{iteration}.png, with the
iteration zero-padded to five digits. Metrics about the run are also saved in
{algorithm}_{dim}_metrics.json, and plots of the metrics are saved as PNGs
named {algorithm}_{dim}_{metric_name}.png.

To generate a video of the heatmap from the heatmap images, use a tool like
ffmpeg. For example, the following will generate a 6 FPS video showing the
heatmap for cma_me_imp with 20 dims. A glob pattern is used because the
snapshot numbers are not consecutive.

    ffmpeg -r 6 -pattern_type glob \
        -i "sphere_output/cma_me_imp_20_heatmap_*.png" \
        sphere_output/cma_me_imp_20_heatmap_video.mp4

Usage (see sphere_main function for all args or run python sphere.py --help):
    python sphere.py ALGORITHM
Example:
    python sphere.py map_elites

    # To make numpy and sklearn run single-threaded, set env variables for BLAS
    # and OpenMP:
    OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 python sphere.py map_elites
Help:
    python sphere.py --help
"""
import copy
import json
import time
from pathlib import Path

import fire
import matplotlib.pyplot as plt
import numpy as np
import tqdm

from ribs.archives import CVTArchive, GridArchive
from ribs.emitters import (EvolutionStrategyEmitter, GaussianEmitter,
                           GradientArborescenceEmitter, IsoLineEmitter)
from ribs.schedulers import BanditScheduler, Scheduler
from ribs.visualize import cvt_archive_heatmap, grid_archive_heatmap

CONFIG = {
    "map_elites": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": GaussianEmitter,
            "kwargs": {
                "sigma": 0.5
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "line_map_elites": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": IsoLineEmitter,
            "kwargs": {
                "iso_sigma": 0.1,
                "line_sigma": 0.2
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cvt_map_elites": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": CVTArchive,
            "kwargs": {
                "cells": 10_000,
                "samples": 100_000,
                "use_kd_tree": True
            }
        },
        "emitters": [{
            "class": GaussianEmitter,
            "kwargs": {
                "sigma": 0.5
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "line_cvt_map_elites": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": CVTArchive,
            "kwargs": {
                "cells": 10_000,
                "samples": 100_000,
                "use_kd_tree": True
            }
        },
        "emitters": [{
            "class": IsoLineEmitter,
            "kwargs": {
                "iso_sigma": 0.1,
                "line_sigma": 0.2
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "me_map_elites": {
        "dim": 100,
        "iters": 20_000,
        "archive_dims": (100, 100),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 50,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "obj"
            },
            "num_emitters": 12
        }, {
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "2rd"
            },
            "num_emitters": 12
        }, {
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "2imp"
            },
            "num_emitters": 12
        }, {
            "class": IsoLineEmitter,
            "kwargs": {
                "iso_sigma": 0.01,
                "line_sigma": 0.1
            },
            "num_emitters": 12
        }],
        "scheduler": {
            "class": BanditScheduler,
            "kwargs": {
                "num_active": 12,
                "reselect": "terminated"
            }
        }
    },
    "cma_me_mixed": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "2rd"
            },
            "num_emitters": 7
        }, {
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "2imp"
            },
            "num_emitters": 8
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_me_imp": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "2imp",
                "selection_rule": "filter",
                "restart_rule": "no_improvement"
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_me_imp_mu": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "2imp",
                "selection_rule": "mu",
                "restart_rule": "no_improvement"
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_me_rd": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "2rd",
                "selection_rule": "filter",
                "restart_rule": "no_improvement"
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_me_rd_mu": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "2rd",
                "selection_rule": "mu",
                "restart_rule": "no_improvement"
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_me_opt": {
        "dim": 20,
        "iters": 4500,
        "archive_dims": (500, 500),
        "use_result_archive": False,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "obj",
                "selection_rule": "mu",
                "restart_rule": "basic"
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_mega": {
        "dim": 1_000,
        "iters": 10_000,
        "archive_dims": (100, 100),
        "use_result_archive": False,
        "is_dqd": True,
        "batch_size": 36,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": GradientArborescenceEmitter,
            "kwargs": {
                "sigma0": 10.0,
                "lr": 1.0,
                "grad_opt": "gradient_ascent",
                "selection_rule": "mu"
            },
            "num_emitters": 1
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_mega_adam": {
        "dim": 1_000,
        "iters": 10_000,
        "archive_dims": (100, 100),
        "use_result_archive": False,
        "is_dqd": True,
        "batch_size": 36,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": -np.inf
            }
        },
        "emitters": [{
            "class": GradientArborescenceEmitter,
            "kwargs": {
                "sigma0": 10.0,
                "lr": 0.002,
                "grad_opt": "adam",
                "selection_rule": "mu"
            },
            "num_emitters": 1
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_mae": {
        "dim": 100,
        "iters": 10_000,
        "archive_dims": (100, 100),
        "use_result_archive": True,
        "is_dqd": False,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": 0,
                "learning_rate": 0.01
            }
        },
        "emitters": [{
            "class": EvolutionStrategyEmitter,
            "kwargs": {
                "sigma0": 0.5,
                "ranker": "imp",
                "selection_rule": "mu",
                "restart_rule": "basic"
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    },
    "cma_maega": {
        "dim": 1_000,
        "iters": 10_000,
        "archive_dims": (100, 100),
        "use_result_archive": True,
        "is_dqd": True,
        "batch_size": 37,
        "archive": {
            "class": GridArchive,
            "kwargs": {
                "threshold_min": 0,
                "learning_rate": 0.01
            }
        },
        "emitters": [{
            "class": GradientArborescenceEmitter,
            "kwargs": {
                "sigma0": 10.0,
                "lr": 1.0,
                "ranker": "imp",
                "grad_opt": "gradient_ascent",
                "restart_rule": "basic"
            },
            "num_emitters": 15
        }],
        "scheduler": {
            "class": Scheduler,
            "kwargs": {}
        }
    }
}


def sphere(solution_batch):
    """Sphere function evaluation and measures for a batch of solutions.

    Args:
        solution_batch (np.ndarray): (batch_size, dim) batch of solutions.
    Returns:
        objective_batch (np.ndarray): (batch_size,) batch of objectives.
        objective_grad_batch (np.ndarray): (batch_size, solution_dim) batch of
            objective gradients.
        measures_batch (np.ndarray): (batch_size, 2) batch of measures.
        measures_grad_batch (np.ndarray): (batch_size, 2, solution_dim) batch of
            measure gradients.
    """
    dim = solution_batch.shape[1]

    # Shift the Sphere function so that the optimal value is at x_i = 2.048.
    sphere_shift = 5.12 * 0.4

    # Normalize the objective to the range [0, 100] where 100 is optimal.
    best_obj = 0.0
    worst_obj = (-5.12 - sphere_shift)**2 * dim
    raw_obj = np.sum(np.square(solution_batch - sphere_shift), axis=1)
    objective_batch = (raw_obj - worst_obj) / (best_obj - worst_obj) * 100

    # Compute gradient of the objective. This is the gradient of -raw_obj; the
    # gradient of the normalized objective is this times the positive constant
    # 100 / worst_obj, so both point in the same direction.
    objective_grad_batch = -2 * (solution_batch - sphere_shift)

    # Calculate measures. Values outside [-5.12, 5.12] are mapped to 5.12 / x.
    clipped = solution_batch.copy()
    clip_mask = (clipped < -5.12) | (clipped > 5.12)
    clipped[clip_mask] = 5.12 / clipped[clip_mask]
    measures_batch = np.concatenate(
        (
            np.sum(clipped[:, :dim // 2], axis=1, keepdims=True),
            np.sum(clipped[:, dim // 2:], axis=1, keepdims=True),
        ),
        axis=1,
    )

    # Compute gradient of the measures. The derivative is 1 inside the range
    # and -5.12 / x**2 where the 5.12 / x mapping was applied.
    derivatives = np.ones(solution_batch.shape)
    derivatives[clip_mask] = -5.12 / np.square(solution_batch[clip_mask])

    mask_0 = np.concatenate((np.ones(dim // 2), np.zeros(dim - dim // 2)))
    mask_1 = np.concatenate((np.zeros(dim // 2), np.ones(dim - dim // 2)))

    d_measure0 = derivatives * mask_0
    d_measure1 = derivatives * mask_1

    measures_grad_batch = np.stack((d_measure0, d_measure1), axis=1)

    return (
        objective_batch,
        objective_grad_batch,
        measures_batch,
        measures_grad_batch,
    )


def create_scheduler(config, algorithm, seed=None):
    """Creates a scheduler based on the algorithm.

    Args:
        config (dict): Configuration dictionary with parameters for the various
            components.
        algorithm (str): Name of the algorithm.
        seed (int): Main seed for the various components.
    Returns:
        ribs.schedulers.Scheduler: A ribs scheduler for running the algorithm.
    """
    solution_dim = config["dim"]
    archive_dims = config["archive_dims"]
    learning_rate = config["archive"]["kwargs"].get("learning_rate", 1.0)
    use_result_archive = config["use_result_archive"]
    max_bound = solution_dim / 2 * 5.12
    bounds = [(-max_bound, max_bound), (-max_bound, max_bound)]
    initial_sol = np.zeros(solution_dim)
    mode = "batch"

    # Create archive.
    archive_class = config["archive"]["class"]
    if archive_class == GridArchive:
        archive = archive_class(solution_dim=solution_dim,
                                ranges=bounds,
                                dims=archive_dims,
                                seed=seed,
                                **config["archive"]["kwargs"])
    else:
        archive = archive_class(solution_dim=solution_dim,
                                ranges=bounds,
                                **config["archive"]["kwargs"])

    # Create result archive.
    result_archive = None
    if use_result_archive:
        result_archive = GridArchive(solution_dim=solution_dim,
                                     dims=archive_dims,
                                     ranges=bounds,
                                     seed=seed)

    # Create emitters. Each emitter needs a different seed so that they do not
    # all do the same thing, hence we create an rng here to generate seeds. The
    # rng may be seeded with None or with a user-provided seed.
    rng = np.random.default_rng(seed)
    emitters = []
    for e in config["emitters"]:
        emitter_class = e["class"]
        emitters += [
            emitter_class(archive,
                          x0=initial_sol,
                          **e["kwargs"],
                          batch_size=config["batch_size"],
                          seed=s)
            for s in rng.integers(0, 1_000_000, e["num_emitters"])
        ]

    # Create scheduler.
    scheduler_class = config["scheduler"]["class"]
    scheduler = scheduler_class(archive,
                                emitters,
                                result_archive=result_archive,
                                add_mode=mode,
                                **config["scheduler"]["kwargs"])
    scheduler_name = scheduler.__class__.__name__

    print(f"Created {scheduler_name} for {algorithm} with learning rate "
          f"{learning_rate} and add mode {mode}, using solution dim "
          f"{solution_dim}, archive dims {archive_dims}, and "
          f"{len(emitters)} emitters.")
    return scheduler


def save_heatmap(archive, heatmap_path):
    """Saves a heatmap of the archive to the given path.

    Args:
        archive (GridArchive or CVTArchive): The archive to save.
        heatmap_path: Image path for the heatmap.
    """
    if isinstance(archive, GridArchive):
        plt.figure(figsize=(8, 6))
        grid_archive_heatmap(archive, vmin=0, vmax=100)
        plt.tight_layout()
        plt.savefig(heatmap_path)
    elif isinstance(archive, CVTArchive):
        plt.figure(figsize=(16, 12))
        cvt_archive_heatmap(archive, vmin=0, vmax=100)
        plt.tight_layout()
        plt.savefig(heatmap_path)
    plt.close(plt.gcf())


def sphere_main(algorithm,
                dim=None,
                itrs=None,
                archive_dims=None,
                learning_rate=None,
                outdir="sphere_output",
                log_freq=250,
                seed=None):
    """Demo on the Sphere function.

    Args:
        algorithm (str): Name of the algorithm.
        dim (int): Dimensionality of the sphere function.
        itrs (int): Iterations to run.
        archive_dims (tuple): Number of cells along each archive dimension.
        learning_rate (float): The archive learning rate.
        outdir (str): Directory to save output.
        log_freq (int): Number of iterations to wait before recording metrics
            and saving heatmap.
        seed (int): Seed for the algorithm. By default, there is no seed.
    """
    config = copy.deepcopy(CONFIG[algorithm])

    # Override the algorithm's default dim if one is given.
    if dim is not None:
        config["dim"] = dim

    # Override the algorithm's default number of iterations if one is given.
    if itrs is not None:
        config["iters"] = itrs

    # Override the algorithm's default archive dims if they are given.
    if archive_dims is not None:
        config["archive_dims"] = archive_dims

    # Override the algorithm's default archive learning rate if one is given.
    if learning_rate is not None:
        config["archive"]["kwargs"]["learning_rate"] = learning_rate

    name = f"{algorithm}_{config['dim']}"
    outdir = Path(outdir)
    if not outdir.is_dir():
        outdir.mkdir()

    scheduler = create_scheduler(config, algorithm, seed=seed)
    result_archive = scheduler.result_archive
    is_dqd = config["is_dqd"]
    itrs = config["iters"]
    metrics = {
        "QD Score": {
            "x": [0],
            "y": [0.0],
        },
        "Archive Coverage": {
            "x": [0],
            "y": [0.0],
        },
    }

    non_logging_time = 0.0
    save_heatmap(result_archive, str(outdir / f"{name}_heatmap_{0:05d}.png"))

    for itr in tqdm.trange(1, itrs + 1):
        itr_start = time.time()

        if is_dqd:
            # DQD step: evaluate solutions together with their objective and
            # measure gradients.
            solution_batch = scheduler.ask_dqd()
            (objective_batch, objective_grad_batch, measures_batch,
             measures_grad_batch) = sphere(solution_batch)
            # Stack the gradients into a (batch_size, 1 + 2, dim) Jacobian
            # whose first row is the objective gradient.
            objective_grad_batch = np.expand_dims(objective_grad_batch, axis=1)
            jacobian_batch = np.concatenate(
                (objective_grad_batch, measures_grad_batch), axis=1)
            scheduler.tell_dqd(objective_batch, measures_batch, jacobian_batch)

        solution_batch = scheduler.ask()
        objective_batch, _, measure_batch, _ = sphere(solution_batch)
        scheduler.tell(objective_batch, measure_batch)
        non_logging_time += time.time() - itr_start

        # Logging and output.
        final_itr = itr == itrs
        if itr % log_freq == 0 or final_itr:
            if final_itr:
                result_archive.as_pandas(include_solutions=final_itr).to_csv(
                    outdir / f"{name}_archive.csv")

            # Record and display metrics.
            metrics["QD Score"]["x"].append(itr)
            metrics["QD Score"]["y"].append(result_archive.stats.qd_score)
            metrics["Archive Coverage"]["x"].append(itr)
            metrics["Archive Coverage"]["y"].append(
                result_archive.stats.coverage)
            tqdm.tqdm.write(
                f"Iteration {itr} | Archive Coverage: "
                f"{metrics['Archive Coverage']['y'][-1] * 100:.3f}% | "
                f"QD Score: {metrics['QD Score']['y'][-1]:.3f}")

            save_heatmap(result_archive,
                         str(outdir / f"{name}_heatmap_{itr:05d}.png"))

    # Plot metrics.
    print(f"Algorithm Time (Excludes Logging and Setup): {non_logging_time}s")
    for metric in metrics:
        plt.plot(metrics[metric]["x"], metrics[metric]["y"])
        plt.title(metric)
        plt.xlabel("Iteration")
        plt.savefig(
            str(outdir / f"{name}_{metric.lower().replace(' ', '_')}.png"))
        plt.clf()
    with (outdir / f"{name}_metrics.json").open("w") as file:
        json.dump(metrics, file, indent=2)


if __name__ == '__main__':
    fire.Fire(sphere_main)
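
In the DQD branch of `sphere_main` above, the (batch_size, dim) objective gradient is expanded to (batch_size, 1, dim). It is then concatenated with the (batch_size, 2, dim) measure gradients, giving a (batch_size, 3, dim) Jacobian whose first row is the objective gradient. The standalone NumPy sketch below reproduces that assembly. It also checks the measure derivative for the 5.12 / x mapping (-5.12 / x^2) against central finite differences. The helpers `transform` and `transform_grad` are hypothetical names that mirror the arithmetic in `sphere`. The test points are fixed values, not outputs of pyribs.

```python
import numpy as np

SHIFT = 0.4 * 5.12  # Optimum location used by `sphere`.


def transform(x):
    """Maps values outside [-5.12, 5.12] to 5.12 / x, elementwise."""
    out = x.copy()
    mask = np.abs(out) > 5.12
    out[mask] = 5.12 / out[mask]
    return out


def transform_grad(x):
    """Elementwise derivative of `transform`: 1 inside, -5.12 / x**2 outside."""
    grad = np.ones_like(x)
    mask = np.abs(x) > 5.12
    grad[mask] = -5.12 / np.square(x[mask])
    return grad


# Fixed test points, none within 1e-6 of the discontinuity at |x| = 5.12.
solutions = np.array([
    [-12.0, -3.0, 0.5, 4.0, 7.5, 20.0],
    [6.0, -6.0, 1.0, -1.0, 10.24, -10.24],
    [0.0, 2.048, -5.0, 5.0, 15.0, -8.0],
])
batch_size, dim = solutions.shape

# Central finite differences agree with the analytic derivative.
eps = 1e-6
numeric = (transform(solutions + eps) - transform(solutions - eps)) / (2 * eps)
print(np.allclose(numeric, transform_grad(solutions), atol=1e-5))  # True

# Measure 0 depends on the first half of x, measure 1 on the rest.
half = dim // 2
mask_0 = np.concatenate((np.ones(half), np.zeros(dim - half)))
mask_1 = 1.0 - mask_0
d = transform_grad(solutions)
measures_grad = np.stack((d * mask_0, d * mask_1), axis=1)  # (3, 2, 6)

# Unscaled objective gradient, then the Jacobian as in the DQD branch.
objective_grad = -2 * (solutions - SHIFT)  # (3, 6)
jacobian = np.concatenate(
    (np.expand_dims(objective_grad, axis=1), measures_grad), axis=1)
print(jacobian.shape)  # (3, 3, 6)
```

Placing the objective row first matches the order in which `sphere_main` concatenates the gradients before calling `tell_dqd`.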