Sphere Function with Various Algorithms

  1"""Runs various QD algorithms on the Sphere function.
  2
  3Install the following dependencies before running this example:
  4    pip install ribs[visualize] tqdm fire
  5
The sphere function in this example is adapted from Section 4 of Fontaine 2020
(https://arxiv.org/abs/1912.02400). Namely, the optimum is moved from [0,..] to
[0.4 * 5.12 = 2.048,..], and the objectives are normalized so that 100 (the
maximum) corresponds to 0 on the original sphere function and 0 corresponds to
a solution whose values are all -5.12. Solutions whose values all lie in
[-5.12, 5.12] therefore score in the range [0, 100]. For the measures, each
solution value is clipped to the range [-5.12, 5.12]: a value x_i outside that
range is replaced with 5.12 / x_i.
 12
 13There are two measures in this example. The first is the sum of the first n/2
 14clipped values of the solution, and the second is the sum of the last n/2
 15clipped values of the solution. Having each measure depend equally on several
 16values in the solution space makes the problem more difficult (refer to
 17Fontaine 2020 for more info).
 18
 19The supported algorithms are:
 20- `map_elites`: GridArchive with GaussianEmitter.
 21- `line_map_elites`: GridArchive with IsoLineEmitter.
 22- `cvt_map_elites`: CVTArchive with GaussianEmitter.
 23- `line_cvt_map_elites`: CVTArchive with IsoLineEmitter.
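- `me_map_elites`: GridArchive with BanditScheduler (num_active=12) choosing
  among EvolutionStrategyEmitters (using ObjectiveRanker,
  TwoStageRandomDirectionRanker, and TwoStageImprovementRanker) and
  IsoLineEmitters.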
- `cma_me_imp`: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker.
- `cma_me_imp_mu`: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker and mu selection rule.
- `cma_me_rd`: GridArchive with EvolutionStrategyEmitter using
  TwoStageRandomDirectionRanker.
 31- `cma_me_rd_mu`: GridArchive with EvolutionStrategyEmitter using
 32  TwoStageRandomDirectionRanker and mu selection rule.
 33- `cma_me_opt`: GridArchive with EvolutionStrategyEmitter using ObjectiveRanker
 34  with mu selection rule.
- `cma_me_mixed`: GridArchive with EvolutionStrategyEmitter, where 7 of the 15
  emitters use TwoStageRandomDirectionRanker and the other 8 use
  TwoStageImprovementRanker.
 38- `og_map_elites`: GridArchive with GradientOperatorEmitter, does not use
 39  measure gradients.
 40- `omg_mega`: GridArchive with GradientOperatorEmitter, uses measure gradients.
 41- `cma_mega`: GridArchive with GradientArborescenceEmitter.
 42- `cma_mega_adam`: GridArchive with GradientArborescenceEmitter using Adam
 43  Optimizer.
 44- `cma_mae`: GridArchive (learning_rate = 0.01) with EvolutionStrategyEmitter
 45  using ImprovementRanker.
 46- `cma_maega`: GridArchive (learning_rate = 0.01) with
 47  GradientArborescenceEmitter using ImprovementRanker.
 48
 49The parameters for each algorithm are stored in CONFIG. The parameters
 50reproduce the experiments presented in the paper in which each algorithm is
 51introduced.
 52
 53Outputs are saved in the `sphere_output/` directory by default. The archive is
 54saved as a CSV named `{algorithm}_{dim}_archive.csv`, while snapshots of the
 55heatmap are saved as `{algorithm}_{dim}_heatmap_{iteration}.png`. Metrics about
the run are also saved in `{algorithm}_{dim}_metrics.json`, and plots of the
metrics are saved as PNGs named `{algorithm}_{dim}_{metric_name}.png` (e.g.
`{algorithm}_{dim}_qd_score.png`).
 58
 59To generate a video of the heatmap from the heatmap images, use a tool like
 60ffmpeg. For example, the following will generate a 6FPS video showing the
 61heatmap for cma_me_imp with 20 dims.
 62
 63    ffmpeg -r 6 -i "sphere_output/cma_me_imp_20_heatmap_%*.png" \
 64        sphere_output/cma_me_imp_20_heatmap_video.mp4
 65
 66Usage (see sphere_main function for all args or run `python sphere.py --help`):
 67    python sphere.py ALGORITHM
 68Example:
 69    python sphere.py map_elites
 70
 71    # To make numpy and sklearn run single-threaded, set env variables for BLAS
 72    # and OpenMP:
 73    OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 python sphere.py map_elites 20
 74Help:
 75    python sphere.py --help
 76"""
 77import copy
 78import json
 79import time
 80from pathlib import Path
 81
 82import fire
 83import matplotlib.pyplot as plt
 84import numpy as np
 85import tqdm
 86
 87from ribs.archives import CVTArchive, GridArchive
 88from ribs.emitters import (EvolutionStrategyEmitter, GaussianEmitter,
 89                           GradientArborescenceEmitter, GradientOperatorEmitter,
 90                           IsoLineEmitter)
 91from ribs.schedulers import BanditScheduler, Scheduler
 92from ribs.visualize import cvt_archive_heatmap, grid_archive_heatmap
 93
 94CONFIG = {
 95    "map_elites": {
 96        "dim": 20,
 97        "iters": 4500,
 98        "archive_dims": (500, 500),
 99        "use_result_archive": False,
100        "is_dqd": False,
101        "batch_size": 37,
102        "archive": {
103            "class": GridArchive,
104            "kwargs": {
105                "threshold_min": -np.inf
106            }
107        },
108        "emitters": [{
109            "class": GaussianEmitter,
110            "kwargs": {
111                "sigma": 0.5
112            },
113            "num_emitters": 15
114        }],
115        "scheduler": {
116            "class": Scheduler,
117            "kwargs": {}
118        }
119    },
120    "line_map_elites": {
121        "dim": 20,
122        "iters": 4500,
123        "archive_dims": (500, 500),
124        "use_result_archive": False,
125        "is_dqd": False,
126        "batch_size": 37,
127        "archive": {
128            "class": GridArchive,
129            "kwargs": {
130                "threshold_min": -np.inf
131            }
132        },
133        "emitters": [{
134            "class": IsoLineEmitter,
135            "kwargs": {
136                "iso_sigma": 0.1,
137                "line_sigma": 0.2
138            },
139            "num_emitters": 15
140        }],
141        "scheduler": {
142            "class": Scheduler,
143            "kwargs": {}
144        }
145    },
146    "cvt_map_elites": {
147        "dim": 20,
148        "iters": 4500,
149        "archive_dims": (500, 500),
150        "use_result_archive": False,
151        "is_dqd": False,
152        "batch_size": 37,
153        "archive": {
154            "class": CVTArchive,
155            "kwargs": {
156                "cells": 10_000,
157                "samples": 100_000,
158                "use_kd_tree": True
159            }
160        },
161        "emitters": [{
162            "class": GaussianEmitter,
163            "kwargs": {
164                "sigma": 0.5
165            },
166            "num_emitters": 15
167        }],
168        "scheduler": {
169            "class": Scheduler,
170            "kwargs": {}
171        }
172    },
173    "line_cvt_map_elites": {
174        "dim": 20,
175        "iters": 4500,
176        "archive_dims": (500, 500),
177        "use_result_archive": False,
178        "is_dqd": False,
179        "batch_size": 37,
180        "archive": {
181            "class": CVTArchive,
182            "kwargs": {
183                "cells": 10_000,
184                "samples": 100_000,
185                "use_kd_tree": True
186            }
187        },
188        "emitters": [{
189            "class": IsoLineEmitter,
190            "kwargs": {
191                "iso_sigma": 0.1,
192                "line_sigma": 0.2
193            },
194            "num_emitters": 15
195        }],
196        "scheduler": {
197            "class": Scheduler,
198            "kwargs": {}
199        }
200    },
201    "me_map_elites": {
202        "dim": 100,
203        "iters": 20_000,
204        "archive_dims": (100, 100),
205        "use_result_archive": False,
206        "is_dqd": False,
207        "batch_size": 50,
208        "archive": {
209            "class": GridArchive,
210            "kwargs": {
211                "threshold_min": -np.inf
212            }
213        },
214        "emitters": [{
215            "class": EvolutionStrategyEmitter,
216            "kwargs": {
217                "sigma0": 0.5,
218                "ranker": "obj"
219            },
220            "num_emitters": 12
221        }, {
222            "class": EvolutionStrategyEmitter,
223            "kwargs": {
224                "sigma0": 0.5,
225                "ranker": "2rd"
226            },
227            "num_emitters": 12
228        }, {
229            "class": EvolutionStrategyEmitter,
230            "kwargs": {
231                "sigma0": 0.5,
232                "ranker": "2imp"
233            },
234            "num_emitters": 12
235        }, {
236            "class": IsoLineEmitter,
237            "kwargs": {
238                "iso_sigma": 0.01,
239                "line_sigma": 0.1
240            },
241            "num_emitters": 12
242        }],
243        "scheduler": {
244            "class": BanditScheduler,
245            "kwargs": {
246                "num_active": 12,
247                "reselect": "terminated"
248            }
249        }
250    },
251    "cma_me_mixed": {
252        "dim": 20,
253        "iters": 4500,
254        "archive_dims": (500, 500),
255        "use_result_archive": False,
256        "is_dqd": False,
257        "batch_size": 37,
258        "archive": {
259            "class": GridArchive,
260            "kwargs": {
261                "threshold_min": -np.inf
262            }
263        },
264        "emitters": [{
265            "class": EvolutionStrategyEmitter,
266            "kwargs": {
267                "sigma0": 0.5,
268                "ranker": "2rd"
269            },
270            "num_emitters": 7
271        }, {
272            "class": EvolutionStrategyEmitter,
273            "kwargs": {
274                "sigma0": 0.5,
275                "ranker": "2imp"
276            },
277            "num_emitters": 8
278        }],
279        "scheduler": {
280            "class": Scheduler,
281            "kwargs": {}
282        }
283    },
284    "cma_me_imp": {
285        "dim": 20,
286        "iters": 4500,
287        "archive_dims": (500, 500),
288        "use_result_archive": False,
289        "is_dqd": False,
290        "batch_size": 37,
291        "archive": {
292            "class": GridArchive,
293            "kwargs": {
294                "threshold_min": -np.inf
295            }
296        },
297        "emitters": [{
298            "class": EvolutionStrategyEmitter,
299            "kwargs": {
300                "sigma0": 0.5,
301                "ranker": "2imp",
302                "selection_rule": "filter",
303                "restart_rule": "no_improvement"
304            },
305            "num_emitters": 15
306        }],
307        "scheduler": {
308            "class": Scheduler,
309            "kwargs": {}
310        }
311    },
312    "cma_me_imp_mu": {
313        "dim": 20,
314        "iters": 4500,
315        "archive_dims": (500, 500),
316        "use_result_archive": False,
317        "is_dqd": False,
318        "batch_size": 37,
319        "archive": {
320            "class": GridArchive,
321            "kwargs": {
322                "threshold_min": -np.inf
323            }
324        },
325        "emitters": [{
326            "class": EvolutionStrategyEmitter,
327            "kwargs": {
328                "sigma0": 0.5,
329                "ranker": "2imp",
330                "selection_rule": "mu",
331                "restart_rule": "no_improvement"
332            },
333            "num_emitters": 15
334        }],
335        "scheduler": {
336            "class": Scheduler,
337            "kwargs": {}
338        }
339    },
340    "cma_me_rd": {
341        "dim": 20,
342        "iters": 4500,
343        "archive_dims": (500, 500),
344        "use_result_archive": False,
345        "is_dqd": False,
346        "batch_size": 37,
347        "archive": {
348            "class": GridArchive,
349            "kwargs": {
350                "threshold_min": -np.inf
351            }
352        },
353        "emitters": [{
354            "class": EvolutionStrategyEmitter,
355            "kwargs": {
356                "sigma0": 0.5,
357                "ranker": "2rd",
358                "selection_rule": "filter",
359                "restart_rule": "no_improvement"
360            },
361            "num_emitters": 15
362        }],
363        "scheduler": {
364            "class": Scheduler,
365            "kwargs": {}
366        }
367    },
368    "cma_me_rd_mu": {
369        "dim": 20,
370        "iters": 4500,
371        "archive_dims": (500, 500),
372        "use_result_archive": False,
373        "is_dqd": False,
374        "batch_size": 37,
375        "archive": {
376            "class": GridArchive,
377            "kwargs": {
378                "threshold_min": -np.inf
379            }
380        },
381        "emitters": [{
382            "class": EvolutionStrategyEmitter,
383            "kwargs": {
384                "sigma0": 0.5,
385                "ranker": "2rd",
386                "selection_rule": "mu",
387                "restart_rule": "no_improvement"
388            },
389            "num_emitters": 15
390        }],
391        "scheduler": {
392            "class": Scheduler,
393            "kwargs": {}
394        }
395    },
396    "cma_me_opt": {
397        "dim": 20,
398        "iters": 4500,
399        "archive_dims": (500, 500),
400        "use_result_archive": False,
401        "is_dqd": False,
402        "batch_size": 37,
403        "archive": {
404            "class": GridArchive,
405            "kwargs": {
406                "threshold_min": -np.inf
407            }
408        },
409        "emitters": [{
410            "class": EvolutionStrategyEmitter,
411            "kwargs": {
412                "sigma0": 0.5,
413                "ranker": "obj",
414                "selection_rule": "mu",
415                "restart_rule": "basic"
416            },
417            "num_emitters": 15
418        }],
419        "scheduler": {
420            "class": Scheduler,
421            "kwargs": {}
422        }
423    },
424    "og_map_elites": {
425        "dim": 1_000,
426        "iters": 10_000,
427        "archive_dims": (100, 100),
428        "use_result_archive": False,
429        "is_dqd": True,
430        # Divide by 2 since half of the 36 solutions are used in ask_dqd(), and
431        # the other half are used in ask().
432        "batch_size": 36 // 2,
433        "archive": {
434            "class": GridArchive,
435            "kwargs": {
436                "threshold_min": -np.inf
437            }
438        },
439        "emitters": [{
440            "class": GradientOperatorEmitter,
441            "kwargs": {
442                "sigma": 0.5,
443                "sigma_g": 0.5,
444                "measure_gradients": False,
445                "normalize_grad": False,
446            },
447            "num_emitters": 1
448        }],
449        "scheduler": {
450            "class": Scheduler,
451            "kwargs": {}
452        }
453    },
454    "omg_mega": {
455        "dim": 1_000,
456        "iters": 10_000,
457        "archive_dims": (100, 100),
458        "use_result_archive": False,
459        "is_dqd": True,
460        # Divide by 2 since half of the 36 solutions are used in ask_dqd(), and
461        # the other half are used in ask().
462        "batch_size": 36 // 2,
463        "archive": {
464            "class": GridArchive,
465            "kwargs": {
466                "threshold_min": -np.inf
467            }
468        },
469        "emitters": [{
470            "class": GradientOperatorEmitter,
471            "kwargs": {
472                "sigma": 0.0,
473                "sigma_g": 10.0,
474                "measure_gradients": True,
475                "normalize_grad": True,
476            },
477            "num_emitters": 1
478        }],
479        "scheduler": {
480            "class": Scheduler,
481            "kwargs": {}
482        }
483    },
484    "cma_mega": {
485        "dim": 1_000,
486        "iters": 10_000,
487        "archive_dims": (100, 100),
488        "use_result_archive": False,
489        "is_dqd": True,
490        "batch_size": 35,
491        "archive": {
492            "class": GridArchive,
493            "kwargs": {
494                "threshold_min": -np.inf
495            }
496        },
497        "emitters": [{
498            "class": GradientArborescenceEmitter,
499            "kwargs": {
500                "sigma0": 10.0,
501                "lr": 1.0,
502                "grad_opt": "gradient_ascent",
503                "selection_rule": "mu"
504            },
505            "num_emitters": 1
506        }],
507        "scheduler": {
508            "class": Scheduler,
509            "kwargs": {}
510        }
511    },
512    "cma_mega_adam": {
513        "dim": 1_000,
514        "iters": 10_000,
515        "archive_dims": (100, 100),
516        "use_result_archive": False,
517        "is_dqd": True,
518        "batch_size": 35,
519        "archive": {
520            "class": GridArchive,
521            "kwargs": {
522                "threshold_min": -np.inf
523            }
524        },
525        "emitters": [{
526            "class": GradientArborescenceEmitter,
527            "kwargs": {
528                "sigma0": 10.0,
529                "lr": 0.002,
530                "grad_opt": "adam",
531                "selection_rule": "mu"
532            },
533            "num_emitters": 1
534        }],
535        "scheduler": {
536            "class": Scheduler,
537            "kwargs": {}
538        }
539    },
540    "cma_mae": {
541        "dim": 100,
542        "iters": 10_000,
543        "archive_dims": (100, 100),
544        "use_result_archive": True,
545        "is_dqd": False,
546        "batch_size": 36,
547        "archive": {
548            "class": GridArchive,
549            "kwargs": {
550                "threshold_min": 0,
551                "learning_rate": 0.01
552            }
553        },
554        "emitters": [{
555            "class": EvolutionStrategyEmitter,
556            "kwargs": {
557                "sigma0": 0.5,
558                "ranker": "imp",
559                "selection_rule": "mu",
560                "restart_rule": "basic"
561            },
562            "num_emitters": 15
563        }],
564        "scheduler": {
565            "class": Scheduler,
566            "kwargs": {}
567        }
568    },
569    "cma_maega": {
570        "dim": 1_000,
571        "iters": 10_000,
572        "archive_dims": (100, 100),
573        "use_result_archive": True,
574        "is_dqd": True,
575        "batch_size": 35,
576        "archive": {
577            "class": GridArchive,
578            "kwargs": {
579                "threshold_min": 0,
580                "learning_rate": 0.01
581            }
582        },
583        "emitters": [{
584            "class": GradientArborescenceEmitter,
585            "kwargs": {
586                "sigma0": 10.0,
587                "lr": 1.0,
588                "ranker": "imp",
589                "grad_opt": "gradient_ascent",
590                "restart_rule": "basic"
591            },
592            "num_emitters": 15
593        }],
594        "scheduler": {
595            "class": Scheduler,
596            "kwargs": {}
597        }
598    }
599}
600
601
def sphere(solution_batch):
    """Evaluates the shifted Sphere benchmark on a batch of solutions.

    The sphere is centered at x_i = 0.4 * 5.12 = 2.048. Its value is rescaled
    so that the optimum scores 100 and a solution whose entries are all -5.12
    scores 0; solutions with entries outside [-5.12, 5.12] can score below 0.
    Before the two measures are computed, every entry outside [-5.12, 5.12]
    is replaced by 5.12 / x_i. The first measure then sums the first dim // 2
    entries and the second measure sums the remaining entries.

    Args:
        solution_batch (np.ndarray): (batch_size, dim) batch of solutions. It
            is not modified.
    Returns:
        objective_batch (np.ndarray): (batch_size,) batch of objectives.
        objective_grad_batch (np.ndarray): (batch_size, solution_dim) batch of
            -2 * (x - 2.048), the gradient of the negated, un-normalized
            sphere. It equals the gradient of objective_batch multiplied by
            the positive constant worst_obj / 100, so it points the same way.
        measures_batch (np.ndarray): (batch_size, 2) batch of measures.
        measures_grad_batch (np.ndarray): (batch_size, 2, solution_dim) batch of
            measure gradients.
    """
    num_dims = solution_batch.shape[1]
    half = num_dims // 2
    center = 5.12 * 0.4

    # Objective: linearly map the sphere value so that worst -> 0, best -> 100.
    worst_obj = (-5.12 - center)**2 * num_dims
    offsets = solution_batch - center
    sphere_values = (offsets**2).sum(axis=1)
    objective_batch = (sphere_values - worst_obj) / (0.0 - worst_obj) * 100
    objective_grad_batch = -2 * offsets

    # Measures: remap out-of-range entries to 5.12 / x, then sum each half.
    remapped = solution_batch.copy()
    out_of_range = (remapped < -5.12) | (remapped > 5.12)
    remapped[out_of_range] = 5.12 / remapped[out_of_range]
    measures_batch = np.stack(
        (remapped[:, :half].sum(axis=1), remapped[:, half:].sum(axis=1)),
        axis=1,
    )

    # The remapped entry has derivative 1 where it was kept and -5.12 / x^2
    # where it was remapped; each measure depends only on its own half.
    slopes = np.ones(solution_batch.shape)
    slopes[out_of_range] = -5.12 / np.square(solution_batch[out_of_range])
    first_half = (np.arange(num_dims) < half).astype(float)
    second_half = 1.0 - first_half
    measures_grad_batch = np.stack(
        (slopes * first_half, slopes * second_half), axis=1)

    return (objective_batch, objective_grad_batch, measures_batch,
            measures_grad_batch)
659
660
def create_scheduler(config, algorithm, seed=None):
    """Creates a scheduler based on the algorithm.

    Args:
        config (dict): Configuration dictionary with parameters for the various
            components, in the format of the entries of CONFIG.
        algorithm (string): Name of the algorithm; only used in the log
            message.
        seed (int): Main seed for the various components. It seeds the
            archive, the optional result archive, and the rng that draws a
            distinct seed for every emitter. None means no fixed seed.
    Returns:
        ribs.schedulers.Scheduler: A ribs scheduler for running the algorithm.
    """
    solution_dim = config["dim"]
    archive_dims = config["archive_dims"]
    archive_kwargs = config["archive"]["kwargs"]
    # Only reported in the log message; archives configured without a learning
    # rate are reported as 1.0.
    learning_rate = archive_kwargs.get("learning_rate", 1.0)
    use_result_archive = config["use_result_archive"]

    # Measure 0 sums the first solution_dim // 2 values and measure 1 sums the
    # rest. sphere() keeps every summed value within [-5.12, 5.12], so these
    # per-measure ranges cover all reachable measures, including for odd
    # solution_dim. For even solution_dim both ranges equal
    # +/- solution_dim / 2 * 5.12.
    first_count = solution_dim // 2
    first_bound = first_count * 5.12
    second_bound = (solution_dim - first_count) * 5.12
    bounds = [(-first_bound, first_bound), (-second_bound, second_bound)]
    initial_sol = np.zeros(solution_dim)
    mode = "batch"

    # Create archive. Only GridArchive takes `dims`, but both GridArchive and
    # CVTArchive accept `seed`. Passing it to CVTArchive as well keeps CVT runs
    # reproducible when a seed is given.
    archive_class = config["archive"]["class"]
    shape_kwargs = {"dims": archive_dims} if archive_class == GridArchive else {}
    archive = archive_class(solution_dim=solution_dim,
                            ranges=bounds,
                            seed=seed,
                            **shape_kwargs,
                            **archive_kwargs)

    # Create result archive.
    result_archive = None
    if use_result_archive:
        result_archive = GridArchive(solution_dim=solution_dim,
                                     dims=archive_dims,
                                     ranges=bounds,
                                     seed=seed)

    # Create emitters. Each emitter needs a different seed so that they do not
    # all do the same thing, hence we create an rng here to generate seeds. The
    # rng may be seeded with None or with a user-provided seed.
    rng = np.random.default_rng(seed)
    emitters = []
    for e in config["emitters"]:
        emitter_class = e["class"]
        emitters += [
            emitter_class(archive,
                          x0=initial_sol,
                          **e["kwargs"],
                          batch_size=config["batch_size"],
                          seed=s)
            for s in rng.integers(0, 1_000_000, e["num_emitters"])
        ]

    # Create Scheduler
    scheduler_class = config["scheduler"]["class"]
    scheduler = scheduler_class(archive,
                                emitters,
                                result_archive=result_archive,
                                add_mode=mode,
                                **config["scheduler"]["kwargs"])
    scheduler_name = scheduler.__class__.__name__

    print(f"Create {scheduler_name} for {algorithm} with learning rate "
          f"{learning_rate} and add mode {mode}, using solution dim "
          f"{solution_dim}, archive dims {archive_dims}, and "
          f"{len(emitters)} emitters.")
    return scheduler
733
734
def save_heatmap(archive, heatmap_path):
    """Writes a heatmap image of the archive to disk.

    Grid and CVT archives are drawn with their matching ribs heatmap helper,
    with the color scale limited to [0, 100]. For any other archive type, no
    image is written.

    Args:
        archive (GridArchive or CVTArchive): Archive to visualize.
        heatmap_path: Destination path of the image.
    """
    plot_spec = None
    if isinstance(archive, GridArchive):
        plot_spec = ((8, 6), grid_archive_heatmap)
    elif isinstance(archive, CVTArchive):
        plot_spec = ((16, 12), cvt_archive_heatmap)

    if plot_spec is not None:
        figsize, draw_heatmap = plot_spec
        plt.figure(figsize=figsize)
        draw_heatmap(archive, vmin=0, vmax=100)
        plt.tight_layout()
        plt.savefig(heatmap_path)
    plt.close(plt.gcf())
753
754
def _save_metric_plots(metrics, outdir, name):
    """Saves a line plot of each metric to `{name}_{metric_name}.png`.

    Args:
        metrics (dict): Maps a metric name to a dict with "x" (iterations) and
            "y" (values) lists.
        outdir (Path): Directory in which the plots are saved.
        name (str): Prefix of every file name, e.g. `map_elites_20`.
    """
    for metric, values in metrics.items():
        fig, ax = plt.subplots()
        ax.plot(values["x"], values["y"])
        ax.set_title(metric)
        ax.set_xlabel("Iteration")
        fig.savefig(
            str(outdir / f"{name}_{metric.lower().replace(' ', '_')}.png"))
        # Close each figure explicitly so none stays open in pyplot.
        plt.close(fig)


def sphere_main(algorithm,
                dim=None,
                itrs=None,
                archive_dims=None,
                learning_rate=None,
                outdir="sphere_output",
                log_freq=250,
                seed=None):
    """Demo on the Sphere function.

    Args:
        algorithm (str): Name of the algorithm (a key of CONFIG).
        dim (int): Dimensionality of the sphere function. Defaults to the
            algorithm's value in CONFIG.
        itrs (int): Iterations to run. Defaults to the algorithm's value in
            CONFIG.
        archive_dims (tuple): Dimensionality of the archive. Defaults to the
            algorithm's value in CONFIG.
        learning_rate (float): The archive learning rate. Defaults to the
            algorithm's value in CONFIG, if it has one.
        outdir (str): Directory to save output; created (including parents)
            if it does not exist.
        log_freq (int): Number of iterations to wait before recording metrics
            and saving heatmap. Must be positive.
        seed (int): Seed for the algorithm. By default, there is no seed.
    Raises:
        ValueError: If log_freq is not positive.
    """
    if log_freq < 1:
        raise ValueError(
            f"log_freq must be a positive integer, got {log_freq}")

    config = copy.deepcopy(CONFIG[algorithm])

    # Override the algorithm's defaults only where the caller gave a value.
    if dim is not None:
        config["dim"] = dim
    if itrs is not None:
        config["iters"] = itrs
    if archive_dims is not None:
        config["archive_dims"] = archive_dims
    if learning_rate is not None:
        config["archive"]["kwargs"]["learning_rate"] = learning_rate

    name = f"{algorithm}_{config['dim']}"
    outdir = Path(outdir)
    # parents=True allows nested output paths; exist_ok=True keeps reruns
    # working without a separate existence check.
    outdir.mkdir(parents=True, exist_ok=True)

    scheduler = create_scheduler(config, algorithm, seed=seed)
    result_archive = scheduler.result_archive
    is_dqd = config["is_dqd"]
    num_iters = config["iters"]
    metrics = {
        "QD Score": {
            "x": [0],
            "y": [0.0],
        },
        "Archive Coverage": {
            "x": [0],
            "y": [0.0],
        },
    }

    non_logging_time = 0.0
    save_heatmap(result_archive, str(outdir / f"{name}_heatmap_{0:05d}.png"))

    for itr in tqdm.trange(1, num_iters + 1):
        # perf_counter is monotonic, which makes it the right clock for
        # measuring durations (time.time can jump).
        itr_start = time.perf_counter()

        if is_dqd:
            solution_batch = scheduler.ask_dqd()
            (objective_batch, objective_grad_batch, measures_batch,
             measures_grad_batch) = sphere(solution_batch)
            # Put the objective gradient in front of the two measure gradients
            # to form a (batch_size, 3, dim) jacobian for tell_dqd().
            objective_grad_batch = np.expand_dims(objective_grad_batch, axis=1)
            jacobian_batch = np.concatenate(
                (objective_grad_batch, measures_grad_batch), axis=1)
            scheduler.tell_dqd(objective_batch, measures_batch, jacobian_batch)

        solution_batch = scheduler.ask()
        objective_batch, _, measures_batch, _ = sphere(solution_batch)
        scheduler.tell(objective_batch, measures_batch)
        non_logging_time += time.perf_counter() - itr_start

        # Logging and output.
        if itr % log_freq == 0 or itr == num_iters:
            # float() turns possible numpy scalars into plain floats so that
            # json.dump below always accepts them.
            metrics["QD Score"]["x"].append(itr)
            metrics["QD Score"]["y"].append(
                float(result_archive.stats.qd_score))
            metrics["Archive Coverage"]["x"].append(itr)
            metrics["Archive Coverage"]["y"].append(
                float(result_archive.stats.coverage))
            tqdm.tqdm.write(
                f"Iteration {itr} | Archive Coverage: "
                f"{metrics['Archive Coverage']['y'][-1] * 100:.3f}% "
                f"QD Score: {metrics['QD Score']['y'][-1]:.3f}")

            save_heatmap(result_archive,
                         str(outdir / f"{name}_heatmap_{itr:05d}.png"))

    # Save the final archive with solutions. This runs after the loop so the
    # CSV is written even when zero iterations are requested.
    result_archive.as_pandas(include_solutions=True).to_csv(
        outdir / f"{name}_archive.csv")

    # Plot metrics.
    print(f"Algorithm Time (Excludes Logging and Setup): {non_logging_time}s")
    _save_metric_plots(metrics, outdir, name)
    with (outdir / f"{name}_metrics.json").open("w") as file:
        json.dump(metrics, file, indent=2)
866
867
if __name__ == "__main__":
    # Expose sphere_main's parameters as command-line arguments via Fire.
    fire.Fire(sphere_main)