1"""Runs various QD algorithms on the Sphere function.
2
3Install the following dependencies before running this example:
4 pip install ribs[visualize] tqdm fire
5
6The sphere function in this example is adapted from Section 4 of Fontaine 2020
7(https://arxiv.org/abs/1912.02400). Namely, each solution value is clipped to
8the range [-5.12, 5.12], and the optimum is moved from [0,..] to [0.4 * 5.12 =
92.048,..]. Furthermore, the objectives are normalized to the range [0,
10100] where 100 is the maximum and corresponds to 0 on the original sphere
11function.
12
13There are two measures in this example. The first is the sum of the first n/2
14clipped values of the solution, and the second is the sum of the last n/2
15clipped values of the solution. Having each measure depend equally on several
16values in the solution space makes the problem more difficult (refer to
17Fontaine 2020 for more info).
18
19The supported algorithms are:
20- `map_elites`: GridArchive with GaussianEmitter.
21- `line_map_elites`: GridArchive with IsoLineEmitter.
22- `cvt_map_elites`: CVTArchive with GaussianEmitter.
23- `line_cvt_map_elites`: CVTArchive with IsoLineEmitter.
24- `me_map_elites`: MAP-Elites with Bandit Scheduler.
- `cma_me_imp`: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker.
- `cma_me_imp_mu`: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker and mu selection rule.
- `cma_me_rd`: GridArchive with EvolutionStrategyEmitter using
  TwoStageRandomDirectionRanker.
31- `cma_me_rd_mu`: GridArchive with EvolutionStrategyEmitter using
32 TwoStageRandomDirectionRanker and mu selection rule.
33- `cma_me_opt`: GridArchive with EvolutionStrategyEmitter using ObjectiveRanker
34 with mu selection rule.
- `cma_me_mixed`: GridArchive with EvolutionStrategyEmitter, where 7 of the 15
  emitters use TwoStageRandomDirectionRanker and the other 8 use
  TwoStageImprovementRanker.
38- `og_map_elites`: GridArchive with GradientOperatorEmitter, does not use
39 measure gradients.
40- `omg_mega`: GridArchive with GradientOperatorEmitter, uses measure gradients.
41- `cma_mega`: GridArchive with GradientArborescenceEmitter.
42- `cma_mega_adam`: GridArchive with GradientArborescenceEmitter using Adam
43 Optimizer.
44- `cma_mae`: GridArchive (learning_rate = 0.01) with EvolutionStrategyEmitter
45 using ImprovementRanker.
46- `cma_maega`: GridArchive (learning_rate = 0.01) with
47 GradientArborescenceEmitter using ImprovementRanker.
48
49The parameters for each algorithm are stored in CONFIG. The parameters
50reproduce the experiments presented in the paper in which each algorithm is
51introduced.
52
53Outputs are saved in the `sphere_output/` directory by default. The archive is
54saved as a CSV named `{algorithm}_{dim}_archive.csv`, while snapshots of the
55heatmap are saved as `{algorithm}_{dim}_heatmap_{iteration}.png`. Metrics about
the run are also saved in `{algorithm}_{dim}_metrics.json`, and plots of the
metrics are saved as PNG files named `{algorithm}_{dim}_{metric_name}.png`.
58
59To generate a video of the heatmap from the heatmap images, use a tool like
60ffmpeg. For example, the following will generate a 6FPS video showing the
61heatmap for cma_me_imp with 20 dims.
62
63 ffmpeg -r 6 -i "sphere_output/cma_me_imp_20_heatmap_%*.png" \
64 sphere_output/cma_me_imp_20_heatmap_video.mp4
65
66Usage (see sphere_main function for all args or run `python sphere.py --help`):
67 python sphere.py ALGORITHM
68Example:
69 python sphere.py map_elites
70
71 # To make numpy and sklearn run single-threaded, set env variables for BLAS
72 # and OpenMP:
73 OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 python sphere.py map_elites 20
74Help:
75 python sphere.py --help
76"""
77import copy
78import json
79import time
80from pathlib import Path
81
82import fire
83import matplotlib.pyplot as plt
84import numpy as np
85import tqdm
86
87from ribs.archives import CVTArchive, GridArchive
88from ribs.emitters import (EvolutionStrategyEmitter, GaussianEmitter,
89 GradientArborescenceEmitter, GradientOperatorEmitter,
90 IsoLineEmitter)
91from ribs.schedulers import BanditScheduler, Scheduler
92from ribs.visualize import cvt_archive_heatmap, grid_archive_heatmap
93
def _archive_spec(archive_class, **kwargs):
    """Builds the "archive" section of an algorithm config."""
    return {"class": archive_class, "kwargs": kwargs}


def _emitter_spec(emitter_class, num_emitters, **kwargs):
    """Builds one entry of the "emitters" list of an algorithm config."""
    return {
        "class": emitter_class,
        "kwargs": kwargs,
        "num_emitters": num_emitters,
    }


def _scheduler_spec(scheduler_class, **kwargs):
    """Builds the "scheduler" section of an algorithm config."""
    return {"class": scheduler_class, "kwargs": kwargs}


def _unbounded_grid():
    """GridArchive section with threshold_min = -inf."""
    return _archive_spec(GridArchive, threshold_min=-np.inf)


def _annealed_grid():
    """GridArchive section with threshold_min = 0 and learning_rate = 0.01."""
    return _archive_spec(GridArchive, threshold_min=0, learning_rate=0.01)


def _cvt():
    """CVTArchive section shared by the CVT-MAP-Elites variants."""
    return _archive_spec(CVTArchive,
                         cells=10_000,
                         samples=100_000,
                         use_kd_tree=True)


def _es_emitters(num_emitters, ranker, **kwargs):
    """EvolutionStrategyEmitter entry with sigma0 = 0.5 and the given ranker."""
    return _emitter_spec(EvolutionStrategyEmitter,
                         num_emitters,
                         sigma0=0.5,
                         ranker=ranker,
                         **kwargs)


def _algorithm_config(dim,
                      iters,
                      archive_dims,
                      batch_size,
                      archive,
                      emitters,
                      *,
                      is_dqd=False,
                      use_result_archive=False,
                      scheduler=None):
    """Assembles the full config dict for one algorithm.

    When no scheduler section is given, a plain Scheduler with no extra kwargs
    is used.
    """
    if scheduler is None:
        scheduler = _scheduler_spec(Scheduler)
    return {
        "dim": dim,
        "iters": iters,
        "archive_dims": archive_dims,
        "use_result_archive": use_result_archive,
        "is_dqd": is_dqd,
        "batch_size": batch_size,
        "archive": archive,
        "emitters": emitters,
        "scheduler": scheduler,
    }


def _small_config(archive, emitters):
    """Config with dim 20, 4500 iters, (500, 500) archive dims, batch size 37."""
    return _algorithm_config(20, 4500, (500, 500), 37, archive, emitters)


def _dqd_config(batch_size, archive, emitters, use_result_archive=False):
    """DQD config with dim 1000, 10000 iters, and (100, 100) archive dims."""
    return _algorithm_config(1_000,
                             10_000, (100, 100),
                             batch_size,
                             archive,
                             emitters,
                             is_dqd=True,
                             use_result_archive=use_result_archive)


# Parameters of every supported algorithm, keyed by algorithm name. Each entry
# holds the problem size ("dim", "iters", "archive_dims", "batch_size"), the
# flags "use_result_archive" and "is_dqd", and the class + kwargs used to build
# the archive, the emitters, and the scheduler.
CONFIG = {
    "map_elites":
        _small_config(
            _unbounded_grid(),
            [_emitter_spec(GaussianEmitter, 15, sigma=0.5)],
        ),
    "line_map_elites":
        _small_config(
            _unbounded_grid(),
            [_emitter_spec(IsoLineEmitter, 15, iso_sigma=0.1, line_sigma=0.2)],
        ),
    "cvt_map_elites":
        _small_config(
            _cvt(),
            [_emitter_spec(GaussianEmitter, 15, sigma=0.5)],
        ),
    "line_cvt_map_elites":
        _small_config(
            _cvt(),
            [_emitter_spec(IsoLineEmitter, 15, iso_sigma=0.1, line_sigma=0.2)],
        ),
    "me_map_elites":
        _algorithm_config(
            100,
            20_000,
            (100, 100),
            50,
            _unbounded_grid(),
            [
                _es_emitters(12, "obj"),
                _es_emitters(12, "2rd"),
                _es_emitters(12, "2imp"),
                _emitter_spec(IsoLineEmitter,
                              12,
                              iso_sigma=0.01,
                              line_sigma=0.1),
            ],
            scheduler=_scheduler_spec(BanditScheduler,
                                      num_active=12,
                                      reselect="terminated"),
        ),
    "cma_me_mixed":
        _small_config(
            _unbounded_grid(),
            [_es_emitters(7, "2rd"),
             _es_emitters(8, "2imp")],
        ),
    "cma_me_imp":
        _small_config(
            _unbounded_grid(),
            [
                _es_emitters(15,
                             "2imp",
                             selection_rule="filter",
                             restart_rule="no_improvement")
            ],
        ),
    "cma_me_imp_mu":
        _small_config(
            _unbounded_grid(),
            [
                _es_emitters(15,
                             "2imp",
                             selection_rule="mu",
                             restart_rule="no_improvement")
            ],
        ),
    "cma_me_rd":
        _small_config(
            _unbounded_grid(),
            [
                _es_emitters(15,
                             "2rd",
                             selection_rule="filter",
                             restart_rule="no_improvement")
            ],
        ),
    "cma_me_rd_mu":
        _small_config(
            _unbounded_grid(),
            [
                _es_emitters(15,
                             "2rd",
                             selection_rule="mu",
                             restart_rule="no_improvement")
            ],
        ),
    "cma_me_opt":
        _small_config(
            _unbounded_grid(),
            [
                _es_emitters(15,
                             "obj",
                             selection_rule="mu",
                             restart_rule="basic")
            ],
        ),
    "og_map_elites":
        _dqd_config(
            # Divide by 2 since half of the 36 solutions are used in
            # ask_dqd(), and the other half are used in ask().
            36 // 2,
            _unbounded_grid(),
            [
                _emitter_spec(GradientOperatorEmitter,
                              1,
                              sigma=0.5,
                              sigma_g=0.5,
                              measure_gradients=False,
                              normalize_grad=False)
            ],
        ),
    "omg_mega":
        _dqd_config(
            # Divide by 2 since half of the 36 solutions are used in
            # ask_dqd(), and the other half are used in ask().
            36 // 2,
            _unbounded_grid(),
            [
                _emitter_spec(GradientOperatorEmitter,
                              1,
                              sigma=0.0,
                              sigma_g=10.0,
                              measure_gradients=True,
                              normalize_grad=True)
            ],
        ),
    "cma_mega":
        _dqd_config(
            35,
            _unbounded_grid(),
            [
                _emitter_spec(GradientArborescenceEmitter,
                              1,
                              sigma0=10.0,
                              lr=1.0,
                              grad_opt="gradient_ascent",
                              selection_rule="mu")
            ],
        ),
    "cma_mega_adam":
        _dqd_config(
            35,
            _unbounded_grid(),
            [
                _emitter_spec(GradientArborescenceEmitter,
                              1,
                              sigma0=10.0,
                              lr=0.002,
                              grad_opt="adam",
                              selection_rule="mu")
            ],
        ),
    "cma_mae":
        _algorithm_config(
            100,
            10_000,
            (100, 100),
            36,
            _annealed_grid(),
            [
                _es_emitters(15,
                             "imp",
                             selection_rule="mu",
                             restart_rule="basic")
            ],
            use_result_archive=True,
        ),
    "cma_maega":
        _dqd_config(
            35,
            _annealed_grid(),
            [
                _emitter_spec(GradientArborescenceEmitter,
                              15,
                              sigma0=10.0,
                              lr=1.0,
                              ranker="imp",
                              grad_opt="gradient_ascent",
                              restart_rule="basic")
            ],
            use_result_archive=True,
        ),
}
600
601
def sphere(solution_batch):
    """Evaluates the shifted, normalized Sphere function on a batch.

    Args:
        solution_batch (np.ndarray): (batch_size, dim) batch of solutions.
    Returns:
        objective_batch (np.ndarray): (batch_size,) batch of objectives,
            normalized so that 100 corresponds to the optimum.
        objective_grad_batch (np.ndarray): (batch_size, solution_dim) batch of
            objective gradients.
        measures_batch (np.ndarray): (batch_size, 2) batch of measures.
        measures_grad_batch (np.ndarray): (batch_size, 2, solution_dim) batch of
            measure gradients.
    """
    n_dims = solution_batch.shape[1]
    half = n_dims // 2

    # The optimum is moved from x_i = 0 to x_i = 0.4 * 5.12 = 2.048.
    optimum = 5.12 * 0.4
    offsets = solution_batch - optimum

    # Objective: rescale the raw sphere value so that the best value (0) maps
    # to 100 and the value at x_i = -5.12 maps to 0.
    best_obj = 0.0
    worst_obj = (-5.12 - optimum)**2 * n_dims
    raw_obj = np.square(offsets).sum(axis=1)
    objective_batch = (raw_obj - worst_obj) / (best_obj - worst_obj) * 100

    # Objective gradient. Note that this is the gradient of the negated raw
    # sphere value; the normalization factor above is not applied to it.
    objective_grad_batch = -2 * offsets

    # Measures: values outside [-5.12, 5.12] are replaced by 5.12 / x before
    # summing each half of the solution.
    out_of_range = np.logical_or(solution_batch < -5.12, solution_batch > 5.12)
    clipped = solution_batch.copy()
    clipped[out_of_range] = 5.12 / clipped[out_of_range]
    measures_batch = np.stack(
        (clipped[:, :half].sum(axis=1), clipped[:, half:].sum(axis=1)),
        axis=1,
    )

    # Measure gradients: d(clip)/dx is 1 inside the range and -5.12 / x^2
    # outside of it.
    derivatives = np.ones(solution_batch.shape)
    derivatives[out_of_range] = -5.12 / np.square(solution_batch[out_of_range])

    # Row 0 selects the first half of the coordinates, row 1 the rest; each
    # measure only depends on its own half.
    in_first_half = (np.arange(n_dims) < half).astype(float)
    half_selector = np.stack((in_first_half, 1.0 - in_first_half))
    measures_grad_batch = derivatives[:, np.newaxis, :] * half_selector

    return (
        objective_batch,
        objective_grad_batch,
        measures_batch,
        measures_grad_batch,
    )
659
660
def create_scheduler(config, algorithm, seed=None):
    """Builds the archive(s), emitters, and scheduler for an algorithm.

    Args:
        config (dict): Configuration dictionary with parameters for the various
            components (see CONFIG for the expected layout).
        algorithm (string): Name of the algorithm; only used in the summary
            message printed at the end.
        seed (int): Main seed for the various components.
    Returns:
        ribs.schedulers.Scheduler: A ribs scheduler for running the algorithm.
    """
    solution_dim = config["dim"]
    archive_dims = config["archive_dims"]
    archive_spec = config["archive"]
    scheduler_spec = config["scheduler"]
    add_mode = "batch"

    # Only used for the summary message; the archive itself receives the
    # learning rate (if any) through its kwargs.
    learning_rate = archive_spec["kwargs"].get("learning_rate", 1.0)

    # Both measures share the same range of +/- (solution_dim / 2 * 5.12).
    half_extent = solution_dim / 2 * 5.12
    measure_ranges = [(-half_extent, half_extent)] * 2

    # Create archive. Only GridArchive takes the grid resolution (dims); other
    # archive classes rely entirely on their configured kwargs.
    archive_class = archive_spec["class"]
    archive_args = {
        "solution_dim": solution_dim,
        "ranges": measure_ranges,
        "seed": seed,
    }
    if archive_class == GridArchive:
        archive_args["dims"] = archive_dims
    archive = archive_class(**archive_args, **archive_spec["kwargs"])

    # Create result archive, if the algorithm asks for a separate one.
    result_archive = (GridArchive(solution_dim=solution_dim,
                                  dims=archive_dims,
                                  ranges=measure_ranges,
                                  seed=seed)
                      if config["use_result_archive"] else None)

    # Create emitters. Each emitter needs a different seed so that they do not
    # all do the same thing, hence child seeds are spawned from a SeedSequence.
    # The SeedSequence may be seeded with None or with a user-provided seed.
    seed_sequence = np.random.SeedSequence(seed)
    initial_sol = np.zeros(solution_dim)
    emitters = [
        spec["class"](
            archive,
            x0=initial_sol,
            **spec["kwargs"],
            batch_size=config["batch_size"],
            seed=child_seed,
        )
        for spec in config["emitters"]
        for child_seed in seed_sequence.spawn(spec["num_emitters"])
    ]

    # Create scheduler.
    scheduler = scheduler_spec["class"](archive,
                                        emitters,
                                        result_archive=result_archive,
                                        add_mode=add_mode,
                                        **scheduler_spec["kwargs"])
    scheduler_name = type(scheduler).__name__

    print(f"Create {scheduler_name} for {algorithm} with learning rate "
          f"{learning_rate} and add mode {add_mode}, using solution dim "
          f"{solution_dim}, archive dims {archive_dims}, and "
          f"{len(emitters)} emitters.")
    return scheduler
735
736
def save_heatmap(archive, heatmap_path):
    """Renders the archive as a heatmap and writes it to an image file.

    Archives that are neither a GridArchive nor a CVTArchive are not plotted.

    Args:
        archive (GridArchive or CVTArchive): The archive to save.
        heatmap_path: Image path for the heatmap.
    """
    # (archive type, plotting function, figure size), checked in this order.
    renderers = (
        (GridArchive, grid_archive_heatmap, (8, 6)),
        (CVTArchive, cvt_archive_heatmap, (16, 12)),
    )
    for archive_type, draw_heatmap, figsize in renderers:
        if isinstance(archive, archive_type):
            plt.figure(figsize=figsize)
            draw_heatmap(archive, vmin=0, vmax=100)
            plt.tight_layout()
            plt.savefig(heatmap_path)
            break

    # Release the figure so that repeated calls do not accumulate open figures.
    plt.close(plt.gcf())
755
756
def sphere_main(algorithm,
                dim=None,
                itrs=None,
                archive_dims=None,
                learning_rate=None,
                es=None,
                outdir="sphere_output",
                log_freq=250,
                seed=None):
    """Demo on the Sphere function.

    Runs the chosen algorithm, periodically records metrics and heatmaps of the
    result archive, and finally writes the archive CSV, metric plots, and a
    metrics JSON file to `outdir`.

    Args:
        algorithm (str): Name of the algorithm; must be a key of CONFIG.
        dim (int): Dimensionality of the sphere function. Overrides the
            algorithm's default when given.
        itrs (int): Iterations to run. Overrides the algorithm's default when
            given.
        archive_dims (tuple): Dimensionality of the archive. Overrides the
            algorithm's default when given.
        learning_rate (float): The archive learning rate. Overrides the
            algorithm's default when given.
        es (str): If passed, this will set the ES for all
            EvolutionStrategyEmitter instances.
        outdir (str): Directory to save output. It is created (including any
            missing parent directories) if it does not exist.
        log_freq (int): Number of iterations to wait before recording metrics
            and saving heatmap.
        seed (int): Seed for the algorithm. By default, there is no seed.
    Raises:
        KeyError: If `algorithm` is not a key of CONFIG.
    """
    # Work on a copy so that the overrides below never leak into CONFIG.
    config = copy.deepcopy(CONFIG[algorithm])

    # Each of the following arguments overrides the algorithm's default only
    # when it is explicitly provided.
    if dim is not None:
        config["dim"] = dim
    if itrs is not None:
        config["iters"] = itrs
    if archive_dims is not None:
        config["archive_dims"] = archive_dims
    if learning_rate is not None:
        config["archive"]["kwargs"]["learning_rate"] = learning_rate

    # Set ES for all EvolutionStrategyEmitter.
    if es is not None:
        for e in config["emitters"]:
            if e["class"] == EvolutionStrategyEmitter:
                e["kwargs"]["es"] = es

    name = f"{algorithm}_{config['dim']}"
    if es is not None:
        name += f"_{es}"

    # parents/exist_ok make nested output paths work and avoid failing when the
    # directory already exists.
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    scheduler = create_scheduler(config, algorithm, seed=seed)
    result_archive = scheduler.result_archive
    is_dqd = config["is_dqd"]
    itrs = config["iters"]
    metrics = {
        "QD Score": {
            "x": [0],
            "y": [0.0],
        },
        "Archive Coverage": {
            "x": [0],
            "y": [0.0],
        },
    }

    non_logging_time = 0.0
    save_heatmap(result_archive, str(outdir / f"{name}_heatmap_{0:05d}.png"))

    for itr in tqdm.trange(1, itrs + 1):
        # perf_counter is monotonic, so the measured duration is not affected
        # by system clock adjustments.
        itr_start = time.perf_counter()

        # DQD algorithms first run a gradient step: the objective gradient and
        # the two measure gradients are stacked into a (batch, 3, dim) jacobian.
        if is_dqd:
            solution_batch = scheduler.ask_dqd()
            (objective_batch, objective_grad_batch, measures_batch,
             measures_grad_batch) = sphere(solution_batch)
            objective_grad_batch = np.expand_dims(objective_grad_batch, axis=1)
            jacobian_batch = np.concatenate(
                (objective_grad_batch, measures_grad_batch), axis=1)
            scheduler.tell_dqd(objective_batch, measures_batch, jacobian_batch)

        solution_batch = scheduler.ask()
        objective_batch, _, measures_batch, _ = sphere(solution_batch)
        scheduler.tell(objective_batch, measures_batch)
        non_logging_time += time.perf_counter() - itr_start

        # Logging and output.
        final_itr = itr == itrs
        if itr % log_freq == 0 or final_itr:
            if final_itr:
                result_archive.data(return_type="pandas").to_csv(
                    outdir / f"{name}_archive.csv")

            # Record and display metrics. The values are converted to plain
            # Python floats so that they can always be serialized to JSON.
            metrics["QD Score"]["x"].append(itr)
            metrics["QD Score"]["y"].append(
                float(result_archive.stats.qd_score))
            metrics["Archive Coverage"]["x"].append(itr)
            metrics["Archive Coverage"]["y"].append(
                float(result_archive.stats.coverage))
            tqdm.tqdm.write(
                f"Iteration {itr} | Archive Coverage: "
                f"{metrics['Archive Coverage']['y'][-1] * 100:.3f}% "
                f"QD Score: {metrics['QD Score']['y'][-1]:.3f}")

            save_heatmap(result_archive,
                         str(outdir / f"{name}_heatmap_{itr:05d}.png"))

    # Plot metrics.
    print(f"Algorithm Time (Excludes Logging and Setup): {non_logging_time}s")
    for metric, values in metrics.items():
        plt.plot(values["x"], values["y"])
        plt.title(metric)
        plt.xlabel("Iteration")
        plt.savefig(
            str(outdir / f"{name}_{metric.lower().replace(' ', '_')}.png"))
        plt.clf()
    with (outdir / f"{name}_metrics.json").open("w") as file:
        json.dump(metrics, file, indent=2)
879
880
# Expose sphere_main's arguments as a command-line interface when run as a
# script.
if __name__ == "__main__":
    fire.Fire(sphere_main)