1"""Runs various QD algorithms on the Sphere function.
2
3Install the following dependencies before running this example:
4 pip install ribs[visualize] tqdm fire
5
The sphere function in this example is adapted from Section 4 of Fontaine 2020
(https://arxiv.org/abs/1912.02400). Namely, when computing the measures, each
solution value outside the range [-5.12, 5.12] is clipped back into that range
by mapping x to 5.12 / x, and the optimum is moved from [0,..] to
[0.4 * 5.12 = 2.048,..]. Furthermore, the objectives are normalized to the
range [0, 100] (for solutions within [-5.12, 5.12]^n), where 100 is the
maximum and corresponds to 0 on the original sphere function.
12
13There are two measures in this example. The first is the sum of the first n/2
14clipped values of the solution, and the second is the sum of the last n/2
15clipped values of the solution. Having each measure depend equally on several
16values in the solution space makes the problem more difficult (refer to
17Fontaine 2020 for more info).
18
19The supported algorithms are:
20- `map_elites`: GridArchive with GaussianEmitter.
21- `line_map_elites`: GridArchive with IsoLineEmitter.
22- `cvt_map_elites`: CVTArchive with GaussianEmitter.
23- `line_cvt_map_elites`: CVTArchive with IsoLineEmitter.
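- `me_map_elites`: GridArchive with BanditScheduler, which selects among
  EvolutionStrategyEmitters (using ObjectiveRanker,
  TwoStageRandomDirectionRanker, and TwoStageImprovementRanker) and
  IsoLineEmitters.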
- `cma_me_imp`: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker.
- `cma_me_imp_mu`: GridArchive with EvolutionStrategyEmitter using
  TwoStageImprovementRanker and mu selection rule.
- `cma_me_rd`: GridArchive with EvolutionStrategyEmitter using
  TwoStageRandomDirectionRanker.
31- `cma_me_rd_mu`: GridArchive with EvolutionStrategyEmitter using
32 TwoStageRandomDirectionRanker and mu selection rule.
33- `cma_me_opt`: GridArchive with EvolutionStrategyEmitter using ObjectiveRanker
34 with mu selection rule.
- `cma_me_mixed`: GridArchive with EvolutionStrategyEmitter, where half (7) of
  the emitters use TwoStageRandomDirectionRanker and the other half (8) use
  TwoStageImprovementRanker.
38- `og_map_elites`: GridArchive with GradientOperatorEmitter, does not use
39 measure gradients.
40- `omg_mega`: GridArchive with GradientOperatorEmitter, uses measure gradients.
41- `cma_mega`: GridArchive with GradientArborescenceEmitter.
42- `cma_mega_adam`: GridArchive with GradientArborescenceEmitter using Adam
43 Optimizer.
44- `cma_mae`: GridArchive (learning_rate = 0.01) with EvolutionStrategyEmitter
45 using ImprovementRanker.
46- `cma_maega`: GridArchive (learning_rate = 0.01) with
47 GradientArborescenceEmitter using ImprovementRanker.
48
49The parameters for each algorithm are stored in CONFIG. The parameters
50reproduce the experiments presented in the paper in which each algorithm is
51introduced.
52
Outputs are saved in the `sphere_output/` directory by default. The archive is
saved as a CSV named `{algorithm}_{dim}_archive.csv`, while snapshots of the
heatmap are saved as `{algorithm}_{dim}_heatmap_{iteration}.png`, with the
iteration zero-padded to five digits. Metrics about the run are also saved in
`{algorithm}_{dim}_metrics.json`, and plots of the metrics are saved as PNGs
named `{algorithm}_{dim}_{metric_name}.png` (e.g., `qd_score` or
`archive_coverage`).
58
59To generate a video of the heatmap from the heatmap images, use a tool like
60ffmpeg. For example, the following will generate a 6FPS video showing the
61heatmap for cma_me_imp with 20 dims.
62
63 ffmpeg -r 6 -i "sphere_output/cma_me_imp_20_heatmap_%*.png" \
64 sphere_output/cma_me_imp_20_heatmap_video.mp4
65
66Usage (see sphere_main function for all args or run `python sphere.py --help`):
67 python sphere.py ALGORITHM
68Example:
69 python sphere.py map_elites
70
71 # To make numpy and sklearn run single-threaded, set env variables for BLAS
72 # and OpenMP:
73 OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 python sphere.py map_elites 20
74Help:
75 python sphere.py --help
76"""
77import copy
78import json
79import time
80from pathlib import Path
81
82import fire
83import matplotlib.pyplot as plt
84import numpy as np
85import tqdm
86
87from ribs.archives import CVTArchive, GridArchive
88from ribs.emitters import (EvolutionStrategyEmitter, GaussianEmitter,
89 GradientArborescenceEmitter, GradientOperatorEmitter,
90 IsoLineEmitter)
91from ribs.schedulers import BanditScheduler, Scheduler
92from ribs.visualize import cvt_archive_heatmap, grid_archive_heatmap
93
# Builders for the CONFIG table below. Every call returns freshly created
# dicts, so no two CONFIG entries share mutable state.


def _archive_spec(archive_class, **kwargs):
    """Describes an archive: its class plus extra constructor kwargs."""
    return {"class": archive_class, "kwargs": kwargs}


def _emitter_group(emitter_class, num_emitters, **kwargs):
    """Describes `num_emitters` emitters of one class sharing the same kwargs."""
    return {
        "class": emitter_class,
        "kwargs": kwargs,
        "num_emitters": num_emitters,
    }


def _scheduler_spec(scheduler_class=Scheduler, **kwargs):
    """Describes a scheduler; defaults to a plain Scheduler without kwargs."""
    return {"class": scheduler_class, "kwargs": kwargs}


def _algorithm_config(*,
                      dim,
                      iters,
                      archive_dims,
                      use_result_archive,
                      is_dqd,
                      batch_size,
                      archive,
                      emitters,
                      scheduler=None):
    """Assembles a single CONFIG entry from its parts."""
    return {
        "dim": dim,
        "iters": iters,
        "archive_dims": archive_dims,
        "use_result_archive": use_result_archive,
        "is_dqd": is_dqd,
        "batch_size": batch_size,
        "archive": archive,
        "emitters": emitters,
        "scheduler": _scheduler_spec() if scheduler is None else scheduler,
    }


def _dim20_config(archive, emitters):
    """Entry for the 20-dim experiments: 4500 iters, archive_dims (500, 500),
    batches of 37, no result archive, no gradients, plain Scheduler."""
    return _algorithm_config(dim=20,
                             iters=4500,
                             archive_dims=(500, 500),
                             use_result_archive=False,
                             is_dqd=False,
                             batch_size=37,
                             archive=archive,
                             emitters=emitters)


def _dim1000_dqd_config(batch_size, emitters):
    """Entry for the 1000-dim DQD experiments on a GridArchive with
    threshold_min=-inf, 10_000 iters, and no result archive."""
    return _algorithm_config(dim=1_000,
                             iters=10_000,
                             archive_dims=(100, 100),
                             use_result_archive=False,
                             is_dqd=True,
                             batch_size=batch_size,
                             archive=_archive_spec(GridArchive,
                                                   threshold_min=-np.inf),
                             emitters=emitters)


def _cma_me_emitters(ranker, selection_rule, restart_rule):
    """15 EvolutionStrategyEmitters (sigma0=0.5) using the given rules."""
    return [
        _emitter_group(EvolutionStrategyEmitter,
                       15,
                       sigma0=0.5,
                       ranker=ranker,
                       selection_rule=selection_rule,
                       restart_rule=restart_rule)
    ]


# Per-algorithm settings, keyed by the algorithm name accepted on the command
# line. Each entry reproduces the setup of the paper introducing the algorithm.
CONFIG = {
    "map_elites":
        _dim20_config(
            _archive_spec(GridArchive, threshold_min=-np.inf),
            [_emitter_group(GaussianEmitter, 15, sigma=0.5)],
        ),
    "line_map_elites":
        _dim20_config(
            _archive_spec(GridArchive, threshold_min=-np.inf),
            [
                _emitter_group(IsoLineEmitter,
                               15,
                               iso_sigma=0.1,
                               line_sigma=0.2)
            ],
        ),
    "cvt_map_elites":
        _dim20_config(
            _archive_spec(CVTArchive,
                          cells=10_000,
                          samples=100_000,
                          use_kd_tree=True),
            [_emitter_group(GaussianEmitter, 15, sigma=0.5)],
        ),
    "line_cvt_map_elites":
        _dim20_config(
            _archive_spec(CVTArchive,
                          cells=10_000,
                          samples=100_000,
                          use_kd_tree=True),
            [
                _emitter_group(IsoLineEmitter,
                               15,
                               iso_sigma=0.1,
                               line_sigma=0.2)
            ],
        ),
    "me_map_elites":
        _algorithm_config(
            dim=100,
            iters=20_000,
            archive_dims=(100, 100),
            use_result_archive=False,
            is_dqd=False,
            batch_size=50,
            archive=_archive_spec(GridArchive, threshold_min=-np.inf),
            emitters=[
                _emitter_group(EvolutionStrategyEmitter,
                               12,
                               sigma0=0.5,
                               ranker="obj"),
                _emitter_group(EvolutionStrategyEmitter,
                               12,
                               sigma0=0.5,
                               ranker="2rd"),
                _emitter_group(EvolutionStrategyEmitter,
                               12,
                               sigma0=0.5,
                               ranker="2imp"),
                _emitter_group(IsoLineEmitter,
                               12,
                               iso_sigma=0.01,
                               line_sigma=0.1),
            ],
            scheduler=_scheduler_spec(BanditScheduler,
                                      num_active=12,
                                      reselect="terminated"),
        ),
    "cma_me_mixed":
        _dim20_config(
            _archive_spec(GridArchive, threshold_min=-np.inf),
            [
                _emitter_group(EvolutionStrategyEmitter,
                               7,
                               sigma0=0.5,
                               ranker="2rd"),
                _emitter_group(EvolutionStrategyEmitter,
                               8,
                               sigma0=0.5,
                               ranker="2imp"),
            ],
        ),
    "cma_me_imp":
        _dim20_config(
            _archive_spec(GridArchive, threshold_min=-np.inf),
            _cma_me_emitters("2imp", "filter", "no_improvement"),
        ),
    "cma_me_imp_mu":
        _dim20_config(
            _archive_spec(GridArchive, threshold_min=-np.inf),
            _cma_me_emitters("2imp", "mu", "no_improvement"),
        ),
    "cma_me_rd":
        _dim20_config(
            _archive_spec(GridArchive, threshold_min=-np.inf),
            _cma_me_emitters("2rd", "filter", "no_improvement"),
        ),
    "cma_me_rd_mu":
        _dim20_config(
            _archive_spec(GridArchive, threshold_min=-np.inf),
            _cma_me_emitters("2rd", "mu", "no_improvement"),
        ),
    "cma_me_opt":
        _dim20_config(
            _archive_spec(GridArchive, threshold_min=-np.inf),
            _cma_me_emitters("obj", "mu", "basic"),
        ),
    "og_map_elites":
        _dim1000_dqd_config(
            # Divide by 2 since half of the 36 solutions are used in
            # ask_dqd(), and the other half are used in ask().
            36 // 2,
            [
                _emitter_group(GradientOperatorEmitter,
                               1,
                               sigma=0.5,
                               sigma_g=0.5,
                               measure_gradients=False,
                               normalize_grad=False)
            ],
        ),
    "omg_mega":
        _dim1000_dqd_config(
            # Divide by 2 since half of the 36 solutions are used in
            # ask_dqd(), and the other half are used in ask().
            36 // 2,
            [
                _emitter_group(GradientOperatorEmitter,
                               1,
                               sigma=0.0,
                               sigma_g=10.0,
                               measure_gradients=True,
                               normalize_grad=True)
            ],
        ),
    "cma_mega":
        _dim1000_dqd_config(
            35,
            [
                _emitter_group(GradientArborescenceEmitter,
                               1,
                               sigma0=10.0,
                               lr=1.0,
                               grad_opt="gradient_ascent",
                               selection_rule="mu")
            ],
        ),
    "cma_mega_adam":
        _dim1000_dqd_config(
            35,
            [
                _emitter_group(GradientArborescenceEmitter,
                               1,
                               sigma0=10.0,
                               lr=0.002,
                               grad_opt="adam",
                               selection_rule="mu")
            ],
        ),
    "cma_mae":
        _algorithm_config(
            dim=100,
            iters=10_000,
            archive_dims=(100, 100),
            use_result_archive=True,
            is_dqd=False,
            batch_size=36,
            archive=_archive_spec(GridArchive,
                                  threshold_min=0,
                                  learning_rate=0.01),
            emitters=_cma_me_emitters("imp", "mu", "basic"),
        ),
    "cma_maega":
        _algorithm_config(
            dim=1_000,
            iters=10_000,
            archive_dims=(100, 100),
            use_result_archive=True,
            is_dqd=True,
            batch_size=35,
            archive=_archive_spec(GridArchive,
                                  threshold_min=0,
                                  learning_rate=0.01),
            emitters=[
                _emitter_group(GradientArborescenceEmitter,
                               15,
                               sigma0=10.0,
                               lr=1.0,
                               ranker="imp",
                               grad_opt="gradient_ascent",
                               restart_rule="basic")
            ],
        ),
}
600
601
def sphere(solution_batch):
    """Evaluates the shifted, normalized Sphere function on a batch.

    The optimum sits at x_i = 0.4 * 5.12 = 2.048 in every coordinate. The
    objective is rescaled so that the optimum scores 100 and the point with
    every x_i = -5.12 (the corner of [-5.12, 5.12]^dim farthest from the
    optimum) scores 0.

    The two measures are the sums of the "clipped" coordinates in the first
    dim // 2 entries and in the remaining entries, where a coordinate x with
    |x| > 5.12 is replaced by 5.12 / x.

    Args:
        solution_batch (np.ndarray): (batch_size, dim) batch of solutions.
    Returns:
        objective_batch (np.ndarray): (batch_size,) batch of objectives.
        objective_grad_batch (np.ndarray): (batch_size, dim) batch of
            objective gradients. These are gradients of the negated,
            un-normalized sphere -sum((x - 2.048)^2). They point in the same
            direction as the gradient of the normalized objective, but are not
            multiplied by the normalization factor.
        measures_batch (np.ndarray): (batch_size, 2) batch of measures.
        measures_grad_batch (np.ndarray): (batch_size, 2, dim) batch of
            measure gradients.
    """
    num_dims = solution_batch.shape[1]
    half = num_dims // 2
    bound = 5.12

    # The optimum is moved away from the origin to x_i = 2.048.
    optimum = bound * 0.4
    offset = solution_batch - optimum

    # Rescale raw sphere values onto [0, 100], where 100 is the optimum.
    best_value = 0.0
    worst_value = (-bound - optimum)**2 * num_dims
    raw_value = np.sum(np.square(offset), axis=1)
    objective_batch = ((raw_value - worst_value) /
                       (best_value - worst_value) * 100)

    objective_grad_batch = -2 * offset

    # Coordinates outside [-5.12, 5.12] are folded back in via x -> 5.12 / x.
    outside = (solution_batch < -bound) | (solution_batch > bound)
    clipped = solution_batch.copy()
    clipped[outside] = bound / clipped[outside]
    measures_batch = np.stack(
        [clipped[:, :half].sum(axis=1), clipped[:, half:].sum(axis=1)],
        axis=1,
    )

    # The fold has derivative 1 inside the bounds and -5.12 / x^2 outside.
    derivatives = np.ones(solution_batch.shape)
    derivatives[outside] = -bound / np.square(solution_batch[outside])

    # Each measure only depends on its own half of the coordinates.
    in_first_half = np.arange(num_dims) < half
    measures_grad_batch = np.stack(
        [derivatives * in_first_half, derivatives * ~in_first_half],
        axis=1,
    )

    return (
        objective_batch,
        objective_grad_batch,
        measures_batch,
        measures_grad_batch,
    )
659
660
def create_scheduler(config, algorithm, seed=None):
    """Creates a scheduler based on the algorithm.

    Args:
        config (dict): Configuration dictionary with parameters for the various
            components (an entry of CONFIG, possibly with overrides).
        algorithm (string): Name of the algorithm; only used in the log
            message.
        seed (int): Main seed for the various components. It is passed to the
            archive and the result archive (if any), and it seeds an rng that
            draws a separate seed for every emitter. None means unseeded.
    Returns:
        ribs.schedulers.Scheduler: A ribs scheduler for running the algorithm.
    """
    solution_dim = config["dim"]
    archive_dims = config["archive_dims"]
    archive_kwargs = config["archive"]["kwargs"]
    # Only reported in the log message below; when present, the value reaches
    # the archive through archive_kwargs.
    learning_rate = archive_kwargs.get("learning_rate", 1.0)
    use_result_archive = config["use_result_archive"]
    # Both measures sum roughly dim / 2 coordinates whose clipped values lie in
    # [-5.12, 5.12], so both measures share these bounds.
    max_bound = solution_dim / 2 * 5.12
    bounds = [(-max_bound, max_bound), (-max_bound, max_bound)]
    initial_sol = np.zeros(solution_dim)
    mode = "batch"

    # Create archive. Every archive class receives the seed so that seeded runs
    # are reproducible (for CVTArchive this should also make the generated
    # centroids reproducible -- see the CVTArchive `seed` docs). Only
    # GridArchive takes explicit per-measure dimensions.
    archive_class = config["archive"]["class"]
    grid_kwargs = ({
        "dims": archive_dims
    } if issubclass(archive_class, GridArchive) else {})
    archive = archive_class(solution_dim=solution_dim,
                            ranges=bounds,
                            seed=seed,
                            **grid_kwargs,
                            **archive_kwargs)

    # Create result archive.
    result_archive = None
    if use_result_archive:
        result_archive = GridArchive(solution_dim=solution_dim,
                                     dims=archive_dims,
                                     ranges=bounds,
                                     seed=seed)

    # Create emitters. Each emitter needs a different seed so that they do not
    # all do the same thing, hence we create an rng here to generate seeds. The
    # rng may be seeded with None or with a user-provided seed.
    rng = np.random.default_rng(seed)
    emitters = []
    for emitter_config in config["emitters"]:
        emitter_class = emitter_config["class"]
        emitters += [
            emitter_class(archive,
                          x0=initial_sol,
                          **emitter_config["kwargs"],
                          batch_size=config["batch_size"],
                          seed=s)
            for s in rng.integers(0, 1_000_000, emitter_config["num_emitters"])
        ]

    # Create Scheduler
    scheduler_class = config["scheduler"]["class"]
    scheduler = scheduler_class(archive,
                                emitters,
                                result_archive=result_archive,
                                add_mode=mode,
                                **config["scheduler"]["kwargs"])
    scheduler_name = scheduler.__class__.__name__

    print(f"Create {scheduler_name} for {algorithm} with learning rate "
          f"{learning_rate} and add mode {mode}, using solution dim "
          f"{solution_dim}, archive dims {archive_dims}, and "
          f"{len(emitters)} emitters.")
    return scheduler
733
734
def save_heatmap(archive, heatmap_path):
    """Saves a heatmap of the archive to the given path.

    The color scale is fixed to [0, 100], the range of the normalized sphere
    objective, so heatmaps from different iterations share one scale.

    Args:
        archive (GridArchive or CVTArchive): The archive to save.
        heatmap_path: Image path for the heatmap.
    Raises:
        TypeError: If the archive is neither a GridArchive nor a CVTArchive.
            Without this check, no image would be written at all.
    """
    if isinstance(archive, GridArchive):
        figsize, plot_heatmap = (8, 6), grid_archive_heatmap
    elif isinstance(archive, CVTArchive):
        figsize, plot_heatmap = (16, 12), cvt_archive_heatmap
    else:
        raise TypeError("Cannot plot a heatmap for an archive of type "
                        f"{type(archive).__name__}")

    fig = plt.figure(figsize=figsize)
    try:
        plot_heatmap(archive, vmin=0, vmax=100)
        plt.tight_layout()
        plt.savefig(heatmap_path)
    finally:
        # Always release the figure; this runs once per logging step, so
        # leaked figures would accumulate over a long run.
        plt.close(fig)
753
754
def sphere_main(algorithm,
                dim=None,
                itrs=None,
                archive_dims=None,
                learning_rate=None,
                outdir="sphere_output",
                log_freq=250,
                seed=None):
    """Demo on the Sphere function.

    Runs the algorithm for the configured number of iterations. Every
    `log_freq` iterations (and on the final iteration) it records QD score and
    coverage of the result archive and saves a heatmap. At the end it writes
    the final archive as CSV, one plot per metric, and the metrics as JSON into
    `outdir`.

    Args:
        algorithm (str): Name of the algorithm; must be a key of CONFIG.
        dim (int): Dimensionality of the sphere function. Defaults to the
            algorithm's value in CONFIG.
        itrs (int): Iterations to run. Defaults to the algorithm's value in
            CONFIG.
        archive_dims (tuple): Dimensionality of the archive. Defaults to the
            algorithm's value in CONFIG.
        learning_rate (float): The archive learning rate. If omitted, the
            archive's own setting from CONFIG (if any) is used.
        outdir (str): Directory to save output; it is created, including
            missing parent directories, if it does not exist.
        log_freq (int): Number of iterations to wait before recording metrics
            and saving heatmap. Must be at least 1.
        seed (int): Seed for the algorithm. By default, there is no seed.
    Raises:
        ValueError: If log_freq is less than 1.
    """
    # Fail fast instead of hitting a ZeroDivisionError after setup.
    if log_freq < 1:
        raise ValueError(f"log_freq must be at least 1, got {log_freq}")

    # Deep copy so the overrides below never mutate the shared CONFIG entry.
    config = copy.deepcopy(CONFIG[algorithm])

    # Override the algorithm's defaults with any values that were given.
    if dim is not None:
        config["dim"] = dim
    if itrs is not None:
        config["iters"] = itrs
    if archive_dims is not None:
        config["archive_dims"] = archive_dims
    if learning_rate is not None:
        config["archive"]["kwargs"]["learning_rate"] = learning_rate

    name = f"{algorithm}_{config['dim']}"
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    scheduler = create_scheduler(config, algorithm, seed=seed)
    result_archive = scheduler.result_archive
    is_dqd = config["is_dqd"]
    itrs = config["iters"]
    metrics = {
        "QD Score": {
            "x": [0],
            "y": [0.0],
        },
        "Archive Coverage": {
            "x": [0],
            "y": [0.0],
        },
    }

    non_logging_time = 0.0
    save_heatmap(result_archive, str(outdir / f"{name}_heatmap_{0:05d}.png"))

    for itr in tqdm.trange(1, itrs + 1):
        itr_start = time.time()

        if is_dqd:
            # DQD step: evaluate the ask_dqd() solutions together with their
            # gradients, stacked into a (batch_size, 3, dim) Jacobian with the
            # objective gradient first, followed by the two measure gradients.
            solution_batch = scheduler.ask_dqd()
            (objective_batch, objective_grad_batch, measures_batch,
             measures_grad_batch) = sphere(solution_batch)
            objective_grad_batch = np.expand_dims(objective_grad_batch, axis=1)
            jacobian_batch = np.concatenate(
                (objective_grad_batch, measures_grad_batch), axis=1)
            scheduler.tell_dqd(objective_batch, measures_batch, jacobian_batch)

        solution_batch = scheduler.ask()
        objective_batch, _, measures_batch, _ = sphere(solution_batch)
        scheduler.tell(objective_batch, measures_batch)
        non_logging_time += time.time() - itr_start

        # Logging and output.
        final_itr = itr == itrs
        if itr % log_freq == 0 or final_itr:
            if final_itr:
                result_archive.as_pandas(include_solutions=True).to_csv(
                    outdir / f"{name}_archive.csv")

            # Record and display metrics. Values are cast to float so that
            # json.dump accepts them whatever numpy scalar type the archive
            # stats use.
            metrics["QD Score"]["x"].append(itr)
            metrics["QD Score"]["y"].append(
                float(result_archive.stats.qd_score))
            metrics["Archive Coverage"]["x"].append(itr)
            metrics["Archive Coverage"]["y"].append(
                float(result_archive.stats.coverage))
            tqdm.tqdm.write(
                f"Iteration {itr} | Archive Coverage: "
                f"{metrics['Archive Coverage']['y'][-1] * 100:.3f}% "
                f"QD Score: {metrics['QD Score']['y'][-1]:.3f}")

            save_heatmap(result_archive,
                         str(outdir / f"{name}_heatmap_{itr:05d}.png"))

    # Plot metrics.
    print(f"Algorithm Time (Excludes Logging and Setup): {non_logging_time}s")
    fig = plt.figure()
    for metric, values in metrics.items():
        plt.plot(values["x"], values["y"])
        plt.title(metric)
        plt.xlabel("Iteration")
        plt.savefig(
            str(outdir / f"{name}_{metric.lower().replace(' ', '_')}.png"))
        plt.clf()
    plt.close(fig)
    with (outdir / f"{name}_metrics.json").open("w") as file:
        json.dump(metrics, file, indent=2)
866
867
if __name__ == "__main__":
    # Expose sphere_main's parameters as command-line arguments via Fire.
    fire.Fire(component=sphere_main)