Sphere Function with Various Algorithms

  1"""Runs various QD algorithms on the Sphere function.
  2
  3Install the following dependencies before running this example:
  4    pip install ribs[visualize] tqdm fire
  5
  6The sphere function in this example is adapted from Section 4 of Fontaine 2020
  7(https://arxiv.org/abs/1912.02400). Namely, each solution value is clipped to
  8the range [-5.12, 5.12], and the optimum is moved from [0,..] to [0.4 * 5.12 =
  92.048,..]. Furthermore, the objectives are normalized to the range [0,
 10100] where 100 is the maximum and corresponds to 0 on the original sphere
 11function.
 12
 13There are two measures in this example. The first is the sum of the first n/2
 14clipped values of the solution, and the second is the sum of the last n/2
 15clipped values of the solution. Having each measure depend equally on several
 16values in the solution space makes the problem more difficult (refer to
 17Fontaine 2020 for more info).
 18
 19The supported algorithms are:
 20- `map_elites`: GridArchive with GaussianEmitter.
 21- `line_map_elites`: GridArchive with IsoLineEmitter.
 22- `cvt_map_elites`: CVTArchive with GaussianEmitter.
 23- `line_cvt_map_elites`: CVTArchive with IsoLineEmitter.
 24- `me_map_elites`: MAP-Elites with Bandit Scheduler.
- `cma_me_imp`: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker.
- `cma_me_imp_mu`: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker and mu selection rule.
- `cma_me_rd`: GridArchive with EvolutionStrategyEmitter using
  TwoStageRandomDirectionRanker.
 31- `cma_me_rd_mu`: GridArchive with EvolutionStrategyEmitter using
 32  TwoStageRandomDirectionRanker and mu selection rule.
 33- `cma_me_opt`: GridArchive with EvolutionStrategyEmitter using ObjectiveRanker
 34  with mu selection rule.
- `cma_me_mixed`: GridArchive with EvolutionStrategyEmitter, where 7 of the 15
  emitters use TwoStageRandomDirectionRanker and the other 8 use
  TwoStageImprovementRanker.
 38- `og_map_elites`: GridArchive with GradientOperatorEmitter, does not use
 39  measure gradients.
 40- `omg_mega`: GridArchive with GradientOperatorEmitter, uses measure gradients.
 41- `cma_mega`: GridArchive with GradientArborescenceEmitter.
 42- `cma_mega_adam`: GridArchive with GradientArborescenceEmitter using Adam
 43  Optimizer.
 44- `cma_mae`: GridArchive (learning_rate = 0.01) with EvolutionStrategyEmitter
 45  using ImprovementRanker.
 46- `cma_maega`: GridArchive (learning_rate = 0.01) with
 47  GradientArborescenceEmitter using ImprovementRanker.
 48
 49The parameters for each algorithm are stored in CONFIG. The parameters
 50reproduce the experiments presented in the paper in which each algorithm is
 51introduced.
 52
 53Outputs are saved in the `sphere_output/` directory by default. The archive is
 54saved as a CSV named `{algorithm}_{dim}_archive.csv`, while snapshots of the
 55heatmap are saved as `{algorithm}_{dim}_heatmap_{iteration}.png`. Metrics about
the run are also saved in `{algorithm}_{dim}_metrics.json`, and plots of the
metrics are saved as PNGs named `{algorithm}_{dim}_{metric_name}.png`.
 58
 59To generate a video of the heatmap from the heatmap images, use a tool like
 60ffmpeg. For example, the following will generate a 6FPS video showing the
 61heatmap for cma_me_imp with 20 dims.
 62
 63    ffmpeg -r 6 -i "sphere_output/cma_me_imp_20_heatmap_%*.png" \
 64        sphere_output/cma_me_imp_20_heatmap_video.mp4
 65
 66Usage (see sphere_main function for all args or run `python sphere.py --help`):
 67    python sphere.py ALGORITHM
 68Example:
 69    python sphere.py map_elites
 70
 71    # To make numpy and sklearn run single-threaded, set env variables for BLAS
 72    # and OpenMP:
 73    OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 python sphere.py map_elites 20
 74Help:
 75    python sphere.py --help
 76"""
 77import copy
 78import json
 79import time
 80from pathlib import Path
 81
 82import fire
 83import matplotlib.pyplot as plt
 84import numpy as np
 85import tqdm
 86
 87from ribs.archives import CVTArchive, GridArchive
 88from ribs.emitters import (EvolutionStrategyEmitter, GaussianEmitter,
 89                           GradientArborescenceEmitter, GradientOperatorEmitter,
 90                           IsoLineEmitter)
 91from ribs.schedulers import BanditScheduler, Scheduler
 92from ribs.visualize import cvt_archive_heatmap, grid_archive_heatmap
 93
 94CONFIG = {
 95    "map_elites": {
 96        "dim": 20,
 97        "iters": 4500,
 98        "archive_dims": (500, 500),
 99        "use_result_archive": False,
100        "is_dqd": False,
101        "batch_size": 37,
102        "archive": {
103            "class": GridArchive,
104            "kwargs": {
105                "threshold_min": -np.inf
106            }
107        },
108        "emitters": [{
109            "class": GaussianEmitter,
110            "kwargs": {
111                "sigma": 0.5
112            },
113            "num_emitters": 15
114        }],
115        "scheduler": {
116            "class": Scheduler,
117            "kwargs": {}
118        }
119    },
120    "line_map_elites": {
121        "dim": 20,
122        "iters": 4500,
123        "archive_dims": (500, 500),
124        "use_result_archive": False,
125        "is_dqd": False,
126        "batch_size": 37,
127        "archive": {
128            "class": GridArchive,
129            "kwargs": {
130                "threshold_min": -np.inf
131            }
132        },
133        "emitters": [{
134            "class": IsoLineEmitter,
135            "kwargs": {
136                "iso_sigma": 0.1,
137                "line_sigma": 0.2
138            },
139            "num_emitters": 15
140        }],
141        "scheduler": {
142            "class": Scheduler,
143            "kwargs": {}
144        }
145    },
146    "cvt_map_elites": {
147        "dim": 20,
148        "iters": 4500,
149        "archive_dims": (500, 500),
150        "use_result_archive": False,
151        "is_dqd": False,
152        "batch_size": 37,
153        "archive": {
154            "class": CVTArchive,
155            "kwargs": {
156                "cells": 10_000,
157                "samples": 100_000,
158                "use_kd_tree": True
159            }
160        },
161        "emitters": [{
162            "class": GaussianEmitter,
163            "kwargs": {
164                "sigma": 0.5
165            },
166            "num_emitters": 15
167        }],
168        "scheduler": {
169            "class": Scheduler,
170            "kwargs": {}
171        }
172    },
173    "line_cvt_map_elites": {
174        "dim": 20,
175        "iters": 4500,
176        "archive_dims": (500, 500),
177        "use_result_archive": False,
178        "is_dqd": False,
179        "batch_size": 37,
180        "archive": {
181            "class": CVTArchive,
182            "kwargs": {
183                "cells": 10_000,
184                "samples": 100_000,
185                "use_kd_tree": True
186            }
187        },
188        "emitters": [{
189            "class": IsoLineEmitter,
190            "kwargs": {
191                "iso_sigma": 0.1,
192                "line_sigma": 0.2
193            },
194            "num_emitters": 15
195        }],
196        "scheduler": {
197            "class": Scheduler,
198            "kwargs": {}
199        }
200    },
201    "me_map_elites": {
202        "dim": 100,
203        "iters": 20_000,
204        "archive_dims": (100, 100),
205        "use_result_archive": False,
206        "is_dqd": False,
207        "batch_size": 50,
208        "archive": {
209            "class": GridArchive,
210            "kwargs": {
211                "threshold_min": -np.inf
212            }
213        },
214        "emitters": [{
215            "class": EvolutionStrategyEmitter,
216            "kwargs": {
217                "sigma0": 0.5,
218                "ranker": "obj"
219            },
220            "num_emitters": 12
221        }, {
222            "class": EvolutionStrategyEmitter,
223            "kwargs": {
224                "sigma0": 0.5,
225                "ranker": "2rd"
226            },
227            "num_emitters": 12
228        }, {
229            "class": EvolutionStrategyEmitter,
230            "kwargs": {
231                "sigma0": 0.5,
232                "ranker": "2imp"
233            },
234            "num_emitters": 12
235        }, {
236            "class": IsoLineEmitter,
237            "kwargs": {
238                "iso_sigma": 0.01,
239                "line_sigma": 0.1
240            },
241            "num_emitters": 12
242        }],
243        "scheduler": {
244            "class": BanditScheduler,
245            "kwargs": {
246                "num_active": 12,
247                "reselect": "terminated"
248            }
249        }
250    },
251    "cma_me_mixed": {
252        "dim": 20,
253        "iters": 4500,
254        "archive_dims": (500, 500),
255        "use_result_archive": False,
256        "is_dqd": False,
257        "batch_size": 37,
258        "archive": {
259            "class": GridArchive,
260            "kwargs": {
261                "threshold_min": -np.inf
262            }
263        },
264        "emitters": [{
265            "class": EvolutionStrategyEmitter,
266            "kwargs": {
267                "sigma0": 0.5,
268                "ranker": "2rd"
269            },
270            "num_emitters": 7
271        }, {
272            "class": EvolutionStrategyEmitter,
273            "kwargs": {
274                "sigma0": 0.5,
275                "ranker": "2imp"
276            },
277            "num_emitters": 8
278        }],
279        "scheduler": {
280            "class": Scheduler,
281            "kwargs": {}
282        }
283    },
284    "cma_me_imp": {
285        "dim": 20,
286        "iters": 4500,
287        "archive_dims": (500, 500),
288        "use_result_archive": False,
289        "is_dqd": False,
290        "batch_size": 37,
291        "archive": {
292            "class": GridArchive,
293            "kwargs": {
294                "threshold_min": -np.inf
295            }
296        },
297        "emitters": [{
298            "class": EvolutionStrategyEmitter,
299            "kwargs": {
300                "sigma0": 0.5,
301                "ranker": "2imp",
302                "selection_rule": "filter",
303                "restart_rule": "no_improvement"
304            },
305            "num_emitters": 15
306        }],
307        "scheduler": {
308            "class": Scheduler,
309            "kwargs": {}
310        }
311    },
312    "cma_me_imp_mu": {
313        "dim": 20,
314        "iters": 4500,
315        "archive_dims": (500, 500),
316        "use_result_archive": False,
317        "is_dqd": False,
318        "batch_size": 37,
319        "archive": {
320            "class": GridArchive,
321            "kwargs": {
322                "threshold_min": -np.inf
323            }
324        },
325        "emitters": [{
326            "class": EvolutionStrategyEmitter,
327            "kwargs": {
328                "sigma0": 0.5,
329                "ranker": "2imp",
330                "selection_rule": "mu",
331                "restart_rule": "no_improvement"
332            },
333            "num_emitters": 15
334        }],
335        "scheduler": {
336            "class": Scheduler,
337            "kwargs": {}
338        }
339    },
340    "cma_me_rd": {
341        "dim": 20,
342        "iters": 4500,
343        "archive_dims": (500, 500),
344        "use_result_archive": False,
345        "is_dqd": False,
346        "batch_size": 37,
347        "archive": {
348            "class": GridArchive,
349            "kwargs": {
350                "threshold_min": -np.inf
351            }
352        },
353        "emitters": [{
354            "class": EvolutionStrategyEmitter,
355            "kwargs": {
356                "sigma0": 0.5,
357                "ranker": "2rd",
358                "selection_rule": "filter",
359                "restart_rule": "no_improvement"
360            },
361            "num_emitters": 15
362        }],
363        "scheduler": {
364            "class": Scheduler,
365            "kwargs": {}
366        }
367    },
368    "cma_me_rd_mu": {
369        "dim": 20,
370        "iters": 4500,
371        "archive_dims": (500, 500),
372        "use_result_archive": False,
373        "is_dqd": False,
374        "batch_size": 37,
375        "archive": {
376            "class": GridArchive,
377            "kwargs": {
378                "threshold_min": -np.inf
379            }
380        },
381        "emitters": [{
382            "class": EvolutionStrategyEmitter,
383            "kwargs": {
384                "sigma0": 0.5,
385                "ranker": "2rd",
386                "selection_rule": "mu",
387                "restart_rule": "no_improvement"
388            },
389            "num_emitters": 15
390        }],
391        "scheduler": {
392            "class": Scheduler,
393            "kwargs": {}
394        }
395    },
396    "cma_me_opt": {
397        "dim": 20,
398        "iters": 4500,
399        "archive_dims": (500, 500),
400        "use_result_archive": False,
401        "is_dqd": False,
402        "batch_size": 37,
403        "archive": {
404            "class": GridArchive,
405            "kwargs": {
406                "threshold_min": -np.inf
407            }
408        },
409        "emitters": [{
410            "class": EvolutionStrategyEmitter,
411            "kwargs": {
412                "sigma0": 0.5,
413                "ranker": "obj",
414                "selection_rule": "mu",
415                "restart_rule": "basic"
416            },
417            "num_emitters": 15
418        }],
419        "scheduler": {
420            "class": Scheduler,
421            "kwargs": {}
422        }
423    },
424    "og_map_elites": {
425        "dim": 1_000,
426        "iters": 10_000,
427        "archive_dims": (100, 100),
428        "use_result_archive": False,
429        "is_dqd": True,
430        # Divide by 2 since half of the 36 solutions are used in ask_dqd(), and
431        # the other half are used in ask().
432        "batch_size": 36 // 2,
433        "archive": {
434            "class": GridArchive,
435            "kwargs": {
436                "threshold_min": -np.inf
437            }
438        },
439        "emitters": [{
440            "class": GradientOperatorEmitter,
441            "kwargs": {
442                "sigma": 0.5,
443                "sigma_g": 0.5,
444                "measure_gradients": False,
445                "normalize_grad": False,
446            },
447            "num_emitters": 1
448        }],
449        "scheduler": {
450            "class": Scheduler,
451            "kwargs": {}
452        }
453    },
454    "omg_mega": {
455        "dim": 1_000,
456        "iters": 10_000,
457        "archive_dims": (100, 100),
458        "use_result_archive": False,
459        "is_dqd": True,
460        # Divide by 2 since half of the 36 solutions are used in ask_dqd(), and
461        # the other half are used in ask().
462        "batch_size": 36 // 2,
463        "archive": {
464            "class": GridArchive,
465            "kwargs": {
466                "threshold_min": -np.inf
467            }
468        },
469        "emitters": [{
470            "class": GradientOperatorEmitter,
471            "kwargs": {
472                "sigma": 0.0,
473                "sigma_g": 10.0,
474                "measure_gradients": True,
475                "normalize_grad": True,
476            },
477            "num_emitters": 1
478        }],
479        "scheduler": {
480            "class": Scheduler,
481            "kwargs": {}
482        }
483    },
484    "cma_mega": {
485        "dim": 1_000,
486        "iters": 10_000,
487        "archive_dims": (100, 100),
488        "use_result_archive": False,
489        "is_dqd": True,
490        "batch_size": 35,
491        "archive": {
492            "class": GridArchive,
493            "kwargs": {
494                "threshold_min": -np.inf
495            }
496        },
497        "emitters": [{
498            "class": GradientArborescenceEmitter,
499            "kwargs": {
500                "sigma0": 10.0,
501                "lr": 1.0,
502                "grad_opt": "gradient_ascent",
503                "selection_rule": "mu"
504            },
505            "num_emitters": 1
506        }],
507        "scheduler": {
508            "class": Scheduler,
509            "kwargs": {}
510        }
511    },
512    "cma_mega_adam": {
513        "dim": 1_000,
514        "iters": 10_000,
515        "archive_dims": (100, 100),
516        "use_result_archive": False,
517        "is_dqd": True,
518        "batch_size": 35,
519        "archive": {
520            "class": GridArchive,
521            "kwargs": {
522                "threshold_min": -np.inf
523            }
524        },
525        "emitters": [{
526            "class": GradientArborescenceEmitter,
527            "kwargs": {
528                "sigma0": 10.0,
529                "lr": 0.002,
530                "grad_opt": "adam",
531                "selection_rule": "mu"
532            },
533            "num_emitters": 1
534        }],
535        "scheduler": {
536            "class": Scheduler,
537            "kwargs": {}
538        }
539    },
540    "cma_mae": {
541        "dim": 100,
542        "iters": 10_000,
543        "archive_dims": (100, 100),
544        "use_result_archive": True,
545        "is_dqd": False,
546        "batch_size": 36,
547        "archive": {
548            "class": GridArchive,
549            "kwargs": {
550                "threshold_min": 0,
551                "learning_rate": 0.01
552            }
553        },
554        "emitters": [{
555            "class": EvolutionStrategyEmitter,
556            "kwargs": {
557                "sigma0": 0.5,
558                "ranker": "imp",
559                "selection_rule": "mu",
560                "restart_rule": "basic"
561            },
562            "num_emitters": 15
563        }],
564        "scheduler": {
565            "class": Scheduler,
566            "kwargs": {}
567        }
568    },
569    "cma_maega": {
570        "dim": 1_000,
571        "iters": 10_000,
572        "archive_dims": (100, 100),
573        "use_result_archive": True,
574        "is_dqd": True,
575        "batch_size": 35,
576        "archive": {
577            "class": GridArchive,
578            "kwargs": {
579                "threshold_min": 0,
580                "learning_rate": 0.01
581            }
582        },
583        "emitters": [{
584            "class": GradientArborescenceEmitter,
585            "kwargs": {
586                "sigma0": 10.0,
587                "lr": 1.0,
588                "ranker": "imp",
589                "grad_opt": "gradient_ascent",
590                "restart_rule": "basic"
591            },
592            "num_emitters": 15
593        }],
594        "scheduler": {
595            "class": Scheduler,
596            "kwargs": {}
597        }
598    }
599}
600
601
def sphere(solution_batch):
    """Evaluates the shifted Sphere objective, its measures, and gradients.

    The objective is the squared distance of each solution to the point
    (2.048, ..., 2.048), rescaled so that the optimum scores 100 and a
    solution at (-5.12, ..., -5.12) scores 0. The two measures are the sums of
    the first `dim // 2` and of the remaining entries of each solution, after
    every entry x outside [-5.12, 5.12] has been replaced by 5.12 / x.

    Args:
        solution_batch (np.ndarray): (batch_size, dim) batch of solutions. The
            array is not modified.
    Returns:
        objective_batch (np.ndarray): (batch_size,) batch of objectives.
        objective_grad_batch (np.ndarray): (batch_size, solution_dim) batch of
            objective gradients.
        measures_batch (np.ndarray): (batch_size, 2) batch of measures.
        measures_grad_batch (np.ndarray): (batch_size, 2, solution_dim) batch of
            measure gradients.
    """
    n_dims = solution_batch.shape[1]
    half = n_dims // 2

    # The optimum is moved away from the origin to x_i = 0.4 * 5.12 = 2.048.
    shift = 5.12 * 0.4
    offsets = solution_batch - shift

    # Objective, rescaled linearly so that worst_obj maps to 0 and best_obj
    # maps to 100.
    best_obj = 0.0
    worst_obj = (-5.12 - shift)**2 * n_dims
    raw_obj = np.square(offsets).sum(axis=1)
    objective_batch = (raw_obj - worst_obj) / (best_obj - worst_obj) * 100

    # Gradient of the objective with respect to the solution.
    objective_grad_batch = -2 * offsets

    # Measures: entries outside [-5.12, 5.12] are remapped to 5.12 / x on a
    # copy, so the caller's array stays untouched.
    out_of_range = (solution_batch > 5.12) | (solution_batch < -5.12)
    remapped = solution_batch.copy()
    remapped[out_of_range] = 5.12 / remapped[out_of_range]
    measures_batch = np.stack(
        (remapped[:, :half].sum(axis=1), remapped[:, half:].sum(axis=1)),
        axis=1,
    )

    # Gradient of the measures: the slope is 1 for in-range entries and
    # d(5.12 / x)/dx = -5.12 / x^2 for remapped ones.
    slopes = np.ones(solution_batch.shape)
    slopes[out_of_range] = -5.12 / np.square(solution_batch[out_of_range])

    # Each measure only depends on its own half of the solution, so the other
    # half of its gradient is zeroed out.
    in_first_half = (np.arange(n_dims) < half).astype(float)
    in_second_half = 1.0 - in_first_half
    measures_grad_batch = np.stack(
        (slopes * in_first_half, slopes * in_second_half), axis=1)

    return (
        objective_batch,
        objective_grad_batch,
        measures_batch,
        measures_grad_batch,
    )
659
660
def create_scheduler(config, algorithm, seed=None):
    """Builds the archive(s), emitters, and scheduler for one algorithm.

    Args:
        config (dict): Configuration dictionary with parameters for the various
            components (one entry of CONFIG, possibly with overrides applied).
        algorithm (string): Name of the algorithm; only used in the summary
            message that is printed.
        seed (int): Main seed for the various components.
    Returns:
        ribs.schedulers.Scheduler: A ribs scheduler for running the algorithm.
    """
    mode = "batch"
    solution_dim = config["dim"]
    archive_dims = config["archive_dims"]
    archive_class = config["archive"]["class"]
    archive_kwargs = config["archive"]["kwargs"]
    use_result_archive = config["use_result_archive"]

    # Only used in the summary message; reported as 1.0 when the config does
    # not set a learning rate.
    learning_rate = archive_kwargs.get("learning_rate", 1.0)

    # Both measures share the range [-dim / 2 * 5.12, dim / 2 * 5.12].
    max_bound = solution_dim / 2 * 5.12
    bounds = [(-max_bound, max_bound)] * 2

    # Create archive. Only GridArchive is given the grid resolution `dims`.
    grid_only_kwargs = ({
        "dims": archive_dims
    } if archive_class == GridArchive else {})
    archive = archive_class(solution_dim=solution_dim,
                            ranges=bounds,
                            seed=seed,
                            **grid_only_kwargs,
                            **archive_kwargs)

    # Create result archive (always a GridArchive) when requested.
    result_archive = None
    if use_result_archive:
        result_archive = GridArchive(solution_dim=solution_dim,
                                     dims=archive_dims,
                                     ranges=bounds,
                                     seed=seed)

    # Create emitters. Each emitter needs a different seed so that they do not
    # all do the same thing, so every emitter gets its own child of a
    # SeedSequence. The SeedSequence itself may be seeded with None or with a
    # user-provided seed.
    initial_sol = np.zeros(solution_dim)
    seed_sequence = np.random.SeedSequence(seed)
    emitters = [
        spec["class"](
            archive,
            x0=initial_sol,
            **spec["kwargs"],
            batch_size=config["batch_size"],
            seed=child_seed,
        )
        for spec in config["emitters"]
        for child_seed in seed_sequence.spawn(spec["num_emitters"])
    ]

    # Create Scheduler
    scheduler = config["scheduler"]["class"](archive,
                                             emitters,
                                             result_archive=result_archive,
                                             add_mode=mode,
                                             **config["scheduler"]["kwargs"])
    scheduler_name = scheduler.__class__.__name__

    print(f"Create {scheduler_name} for {algorithm} with learning rate "
          f"{learning_rate} and add mode {mode}, using solution dim "
          f"{solution_dim}, archive dims {archive_dims}, and "
          f"{len(emitters)} emitters.")
    return scheduler
735
736
def save_heatmap(archive, heatmap_path):
    """Renders the archive as a heatmap and writes it to an image file.

    Nothing is written if the archive is neither a GridArchive nor a
    CVTArchive.

    Args:
        archive (GridArchive or CVTArchive): The archive to save.
        heatmap_path: Image path for the heatmap.
    """
    # Pick the plotting function and figure size that match the archive type.
    if isinstance(archive, GridArchive):
        figsize, draw_heatmap = (8, 6), grid_archive_heatmap
    elif isinstance(archive, CVTArchive):
        figsize, draw_heatmap = (16, 12), cvt_archive_heatmap
    else:
        figsize, draw_heatmap = None, None

    if draw_heatmap is not None:
        plt.figure(figsize=figsize)
        # Fix the color scale to the objective's [0, 100] range.
        draw_heatmap(archive, vmin=0, vmax=100)
        plt.tight_layout()
        plt.savefig(heatmap_path)

    # Close the current figure so that figures do not pile up across calls.
    plt.close(plt.gcf())
755
756
def sphere_main(algorithm,
                dim=None,
                itrs=None,
                archive_dims=None,
                learning_rate=None,
                es=None,
                outdir="sphere_output",
                log_freq=250,
                seed=None):
    """Demo on the Sphere function.

    Runs the chosen algorithm, periodically recording the QD score and archive
    coverage and saving a heatmap of the result archive. At the end, the
    result archive is written to CSV, the metrics to JSON, and one plot per
    metric to PNG. All files are placed in `outdir` and are prefixed with
    `{algorithm}_{dim}` (plus `_{es}` when `es` is given).

    Args:
        algorithm (str): Name of the algorithm; must be a key of CONFIG.
        dim (int): Dimensionality of the sphere function. Defaults to the
            value in CONFIG for the algorithm.
        itrs (int): Iterations to run. Defaults to the value in CONFIG for the
            algorithm.
        archive_dims (tuple): Dimensionality of the archive. Defaults to the
            value in CONFIG for the algorithm.
        learning_rate (float): The archive learning rate. Defaults to the
            value in CONFIG for the algorithm, if any.
        es (str): If passed, this will set the ES for all
            EvolutionStrategyEmitter instances.
        outdir (str): Directory to save output. It is created, along with any
            missing parent directories, if it does not exist.
        log_freq (int): Number of iterations to wait before recording metrics
            and saving heatmap.
        seed (int): Seed for the algorithm. By default, there is no seed.
    """
    # Work on a copy so that the overrides below never modify CONFIG itself.
    config = copy.deepcopy(CONFIG[algorithm])

    # Each argument below replaces the algorithm's default from CONFIG only
    # when it is explicitly provided.
    if dim is not None:
        config["dim"] = dim

    if itrs is not None:
        config["iters"] = itrs

    if archive_dims is not None:
        config["archive_dims"] = archive_dims

    if learning_rate is not None:
        config["archive"]["kwargs"]["learning_rate"] = learning_rate

    # Set ES for all EvolutionStrategyEmitter.
    if es is not None:
        for e in config["emitters"]:
            if e["class"] == EvolutionStrategyEmitter:
                e["kwargs"]["es"] = es

    name = f"{algorithm}_{config['dim']}"
    if es is not None:
        name += f"_{es}"

    # parents=True makes nested output paths such as "runs/sphere" work, and
    # exist_ok=True avoids failing when the directory is already there.
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    scheduler = create_scheduler(config, algorithm, seed=seed)
    result_archive = scheduler.result_archive
    is_dqd = config["is_dqd"]
    itrs = config["iters"]
    metrics = {
        "QD Score": {
            "x": [0],
            "y": [0.0],
        },
        "Archive Coverage": {
            "x": [0],
            "y": [0.0],
        },
    }

    non_logging_time = 0.0
    save_heatmap(result_archive, str(outdir / f"{name}_heatmap_{0:05d}.png"))

    for itr in tqdm.trange(1, itrs + 1):
        # perf_counter() is monotonic, so the measured durations are not
        # affected by adjustments to the system clock.
        itr_start = time.perf_counter()

        # DQD algorithms first run a gradient step: the objective gradient and
        # the two measure gradients are stacked into a
        # (batch_size, 3, solution_dim) Jacobian.
        if is_dqd:
            solution_batch = scheduler.ask_dqd()
            (objective_batch, objective_grad_batch, measures_batch,
             measures_grad_batch) = sphere(solution_batch)
            objective_grad_batch = np.expand_dims(objective_grad_batch, axis=1)
            jacobian_batch = np.concatenate(
                (objective_grad_batch, measures_grad_batch), axis=1)
            scheduler.tell_dqd(objective_batch, measures_batch, jacobian_batch)

        solution_batch = scheduler.ask()
        objective_batch, _, measure_batch, _ = sphere(solution_batch)
        scheduler.tell(objective_batch, measure_batch)
        non_logging_time += time.perf_counter() - itr_start

        # Logging and output.
        final_itr = itr == itrs
        if itr % log_freq == 0 or final_itr:
            if final_itr:
                result_archive.data(return_type="pandas").to_csv(
                    outdir / f"{name}_archive.csv")

            # Record and display metrics.
            metrics["QD Score"]["x"].append(itr)
            metrics["QD Score"]["y"].append(result_archive.stats.qd_score)
            metrics["Archive Coverage"]["x"].append(itr)
            metrics["Archive Coverage"]["y"].append(
                result_archive.stats.coverage)
            tqdm.tqdm.write(
                f"Iteration {itr} | Archive Coverage: "
                f"{metrics['Archive Coverage']['y'][-1] * 100:.3f}% "
                f"QD Score: {metrics['QD Score']['y'][-1]:.3f}")

            save_heatmap(result_archive,
                         str(outdir / f"{name}_heatmap_{itr:05d}.png"))

    # Plot metrics.
    print(f"Algorithm Time (Excludes Logging and Setup): {non_logging_time}s")
    for metric, values in metrics.items():
        plt.plot(values["x"], values["y"])
        plt.title(metric)
        plt.xlabel("Iteration")
        plt.savefig(
            str(outdir / f"{name}_{metric.lower().replace(' ', '_')}.png"))
        # Clear the figure so the next metric starts from an empty plot.
        plt.clf()
    with (outdir / f"{name}_metrics.json").open("w") as file:
        json.dump(metrics, file, indent=2)
879
880
# Expose the arguments of sphere_main as a command-line interface.
if __name__ == "__main__":
    fire.Fire(sphere_main)