In this tutorial, we build a safety-critical reinforcement learning pipeline that learns entirely from fixed, offline data rather than live exploration. We design a custom environment, generate a behavior dataset from a constrained policy, and then train both a Behavior Cloning baseline and a Conservative Q-Learning agent using d3rlpy. By structuring the workflow around offline datasets, careful evaluation, and conservative learning objectives, we demonstrate how robust decision-making policies can be trained in settings where unsafe exploration is not an option. Check out the FULL CODES here.
!pip -q install -U "d3rlpy" "gymnasium" "numpy" "torch" "matplotlib" "scikit-learn"
import inspect
import os
import random
import time

import d3rlpy
import gymnasium as gymnasium
import matplotlib.pyplot as plt
import numpy as np
import torch
from gymnasium import spaces
# One seed shared by every random source used in this file (stdlib, NumPy, PyTorch).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
def pick_device():
    """Return the device string to train on.

    Returns:
        ``"cuda:0"`` when ``torch.cuda.is_available()`` reports a usable GPU,
        otherwise ``"cpu"``.
    """
    return "cuda:0" if torch.cuda.is_available() else "cpu"
# Resolve the compute device once at import time and report the library versions in use.
DEVICE = pick_device()
print(
    "d3rlpy:", getattr(d3rlpy, "__version__", "unknown"),
    "| torch:", torch.__version__,
    "| device:", DEVICE,
)
def make_config(cls, **kwargs):
    """Instantiate ``cls`` using only the keyword arguments its ``__init__`` names.

    Keyword arguments whose names do not appear in the signature of
    ``cls.__init__`` are silently dropped, so one call site can list options
    that only some versions of a config class accept.

    Args:
        cls: Class to instantiate.
        **kwargs: Candidate constructor arguments.

    Returns:
        ``cls(**accepted)`` where ``accepted`` is the subset of ``kwargs`` whose
        keys are parameter names of ``cls.__init__`` (``self`` excluded).
    """
    accepted_names = set(inspect.signature(cls.__init__).parameters)
    accepted_names.discard("self")
    accepted = {key: value for key, value in kwargs.items() if key in accepted_names}
    return cls(**accepted)
We set up the environment by installing dependencies, importing libraries, and fixing random seeds for reproducibility. We detect and configure the computation device to ensure consistent execution across systems. We also define a utility to create configuration objects safely across different d3rlpy versions. Check out the FULL CODES here.
class SafetyCriticalGridWorld(gymnasium.Env):
    """Square grid world with episode-ending hazard cells, a goal cell and slippery moves.

    Observations are the agent's ``(x, y)`` cell as a ``float32`` vector.
    Actions are ``0`` (+y), ``1`` (+x), ``2`` (-y) and ``3`` (-x); moves are
    clipped to the grid. With probability ``slip_prob`` the requested action is
    replaced by a uniformly drawn one. Every step yields -1, entering a hazard
    yields -100 and entering the goal yields +50; hazard and goal both
    terminate the episode. The episode is truncated once ``max_steps`` steps
    have been taken.

    Args:
        measurement: Side length of the grid, in cells.
        max_steps: Step count at which ``truncated`` becomes true.
        hazard_coords: Iterable of ``(x, y)`` hazard cells. When ``None``,
            ``max(1, measurement // 2)`` interior cells are drawn from ``seed``
            (duplicates collapse, so fewer distinct hazards are possible).
        begin: Start cell ``(x, y)``.
        purpose: Goal cell ``(x, y)``; defaults to the far corner
            ``(measurement - 1, measurement - 1)``.
        slip_prob: Probability that an action is replaced by a random one.
        seed: Seed for hazard generation and for the initial dynamics RNG.

    Raises:
        ValueError: If hazards must be generated but ``measurement`` is below 5,
            which leaves no interior cell range to sample from.
    """

    metadata = {"render_modes": []}

    # Position delta (dx, dy) applied by each discrete action.
    _MOVES = {0: (0, 1), 1: (1, 0), 2: (0, -1), 3: (-1, 0)}

    def __init__(
        self,
        measurement=15,
        max_steps=80,
        hazard_coords=None,
        begin=(0, 0),
        purpose=None,
        slip_prob=0.05,
        seed=0,
    ):
        super().__init__()
        self.measurement = int(measurement)
        self.max_steps = int(max_steps)
        self.begin = tuple(begin)
        far_corner = self.measurement - 1
        self.purpose = tuple(purpose) if purpose is not None else (far_corner, far_corner)
        self.slip_prob = float(slip_prob)

        if hazard_coords is None:
            if self.measurement < 5:
                # rng.integers(2, measurement - 2) needs a non-empty range.
                raise ValueError(
                    "measurement must be >= 5 to auto-generate hazards; "
                    "pass hazard_coords explicitly for smaller grids."
                )
            rng = np.random.default_rng(seed)
            hazards = set()
            for _ in range(max(1, self.measurement // 2)):
                # Draw x before y so a given seed always yields the same cells.
                x = rng.integers(2, self.measurement - 2)
                y = rng.integers(2, self.measurement - 2)
                hazards.add((int(x), int(y)))
            self.hazards = hazards
        else:
            self.hazards = {tuple(cell) for cell in hazard_coords}

        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(
            low=0.0,
            high=float(self.measurement - 1),
            shape=(2,),
            dtype=np.float32,
        )
        self._rng = np.random.default_rng(seed)
        self._pos = None  # set by reset()
        self._t = 0

    @property
    def size(self):
        """Alias for ``measurement``."""
        return self.measurement

    @property
    def start(self):
        """Alias for ``begin``."""
        return self.begin

    @property
    def goal(self):
        """Alias for ``purpose``."""
        return self.purpose

    def reset(self, *, seed=None, options=None):
        """Move the agent back to the start cell and zero the step counter.

        Args:
            seed: When given, re-seeds the RNG that drives action slips.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with the start position as ``float32`` and
            an empty info dict.
        """
        if seed is not None:
            self._rng = np.random.default_rng(seed)
        self._pos = [int(self.begin[0]), int(self.begin[1])]
        self._t = 0
        return np.array(self._pos, dtype=np.float32), {}

    def _clip(self):
        """Clamp the current position to the grid bounds."""
        upper = self.measurement - 1
        self._pos[0] = int(np.clip(self._pos[0], 0, upper))
        self._pos[1] = int(np.clip(self._pos[1], 0, upper))

    def step(self, motion):
        """Apply one action and return the resulting transition.

        Args:
            motion: Action index; values outside 0-3 leave the position
                unchanged.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
        """
        self._t += 1
        action = int(motion)
        if self._rng.random() < self.slip_prob:
            action = int(self._rng.integers(0, 4))

        dx, dy = self._MOVES.get(action, (0, 0))
        self._pos[0] += dx
        self._pos[1] += dy
        self._clip()

        cell = (int(self._pos[0]), int(self._pos[1]))
        terminated = False
        truncated = self._t >= self.max_steps
        reward = -1.0
        if cell in self.hazards:
            reward = -100.0
            terminated = True
        # Checked after hazards, so a goal cell that is also a hazard counts as the goal.
        if cell == self.purpose:
            reward = +50.0
            terminated = True

        obs = np.array(cell, dtype=np.float32)
        return obs, float(reward), terminated, truncated, {}
We define a safety-critical GridWorld environment with hazards, terminal states, and stochastic transitions. We encode penalties for unsafe states and rewards for successful task completion. We ensure the environment strictly controls dynamics to reflect real-world safety constraints. Check out the FULL CODES here.
def safe_behavior_policy(obs, env: "SafetyCriticalGridWorld", epsilon=0.15):
    """Pick an action that heads toward the goal while avoiding adjacent hazards.

    The goal-directed ("preferred") actions are those that reduce the x or y
    distance to ``env.purpose``. With probability ``epsilon`` a uniformly
    random action is returned instead. Otherwise one preferred action whose
    destination cell is not in ``env.hazards`` is chosen at random; if every
    preferred action leads into a hazard, the first preferred action is
    returned anyway.

    Args:
        obs: Current observation; ``obs[0]`` and ``obs[1]`` are read as x and y.
        env: Environment providing ``purpose``, ``measurement``, ``hazards`` and
            the ``_rng`` generator used for the exploration draws.
        epsilon: Probability of returning a uniformly random action.

    Returns:
        Action index in ``{0, 1, 2, 3}`` (0: +y, 1: +x, 2: -y, 3: -x).
    """
    x, y = int(obs[0]), int(obs[1])
    gx, gy = env.purpose

    preferred = []
    if gx > x:
        preferred.append(1)
    elif gx < x:
        preferred.append(3)
    if gy > y:
        preferred.append(0)
    elif gy < y:
        preferred.append(2)
    if not preferred:
        # Already on the goal cell: no direction is better than another.
        preferred = [int(env._rng.integers(0, 4))]

    if env._rng.random() < epsilon:
        return int(env._rng.integers(0, 4))

    deltas = {0: (0, 1), 1: (1, 0), 2: (0, -1), 3: (-1, 0)}
    upper = env.measurement - 1
    candidates = []
    for action in preferred:
        dx, dy = deltas[action]
        nx = int(np.clip(x + dx, 0, upper))
        ny = int(np.clip(y + dy, 0, upper))
        if (nx, ny) not in env.hazards:
            candidates.append(action)

    if not candidates:
        return preferred[0]
    return int(random.choice(candidates))
def generate_offline_episodes(env, n_episodes=400, epsilon=0.20, seed=0):
    """Roll out ``safe_behavior_policy`` and record each episode as arrays.

    Episode ``i`` is started with ``env.reset(seed=seed + i)`` and runs until
    the environment reports termination or truncation.

    Args:
        env: Environment to roll out in.
        n_episodes: Number of episodes to collect.
        epsilon: Exploration rate forwarded to ``safe_behavior_policy``.
        seed: Base seed; episode ``i`` uses ``seed + i``.

    Returns:
        List of dicts, one per episode, with keys ``"observations"``
        (``float32``), ``"actions"`` (``int64``, shape ``(T, 1)``),
        ``"rewards"`` (``float32``, shape ``(T, 1)``) and ``"terminals"``
        (``float32``, shape ``(T, 1)``). The terminal flag is 1.0 on the last
        step whether the episode terminated or was truncated.
    """
    episodes = []
    for i in range(n_episodes):
        obs, _ = env.reset(seed=int(seed + i))
        obs_list, act_list, rew_list, done_list = [], [], [], []
        done = False
        while not done:
            action = safe_behavior_policy(obs, env, epsilon=epsilon)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = bool(terminated or truncated)
            # Store the observation the action was taken from, not the successor.
            obs_list.append(np.array(obs, dtype=np.float32))
            act_list.append(np.array([action], dtype=np.int64))
            rew_list.append(np.array([reward], dtype=np.float32))
            done_list.append(np.array([1.0 if done else 0.0], dtype=np.float32))
            obs = next_obs
        episodes.append(
            {
                "observations": np.stack(obs_list, axis=0),
                "actions": np.stack(act_list, axis=0),
                "rewards": np.stack(rew_list, axis=0),
                "terminals": np.stack(done_list, axis=0),
            }
        )
    return episodes
def build_mdpdataset(episodes):
    """Concatenate per-episode arrays into a single ``d3rlpy.dataset.MDPDataset``.

    Args:
        episodes: Non-empty list of dicts with ``"observations"``,
            ``"actions"``, ``"rewards"`` and ``"terminals"`` arrays.

    Returns:
        The ``MDPDataset`` built from the concatenated arrays.

    Raises:
        ValueError: If ``episodes`` is empty.
        RuntimeError: If the installed d3rlpy does not expose
            ``d3rlpy.dataset.MDPDataset``.
    """
    if not episodes:
        raise ValueError("episodes must contain at least one episode.")
    if not (hasattr(d3rlpy, "dataset") and hasattr(d3rlpy.dataset, "MDPDataset")):
        raise RuntimeError("d3rlpy.dataset.MDPDataset not found. Upgrade d3rlpy.")

    def _flatten(key, dtype):
        # Join every episode's array for this key along the time axis.
        return np.concatenate([ep[key] for ep in episodes], axis=0).astype(dtype)

    return d3rlpy.dataset.MDPDataset(
        observations=_flatten("observations", np.float32),
        actions=_flatten("actions", np.int64),
        rewards=_flatten("rewards", np.float32),
        terminals=_flatten("terminals", np.float32),
    )
We design a constrained behavior policy that generates offline data without risky exploration. We roll out this policy to collect trajectories and structure them into episodes. We then convert these episodes into a format compatible with d3rlpy’s offline learning APIs. Check out the FULL CODES here.
def _get_episodes_from_dataset(dataset):
if hasattr(dataset, "episodes") and dataset.episodes isn't None:
return dataset.episodes
if hasattr(dataset, "get_episodes"):
return dataset.get_episodes()
elevate AttributeError("Couldn't discover episodes in dataset (d3rlpy model mismatch).")
def _iter_all_observations(dataset):
    """Yield every observation of every episode in ``dataset``, in order.

    Episodes without an ``observations`` attribute (or with it set to
    ``None``) are skipped.
    """
    for episode in _get_episodes_from_dataset(dataset):
        observations = getattr(episode, "observations", None)
        if observations is None:
            continue
        yield from observations
def _iter_all_transitions(dataset):
    """Yield ``(observation, action, reward)`` for every step in ``dataset``.

    Episodes missing ``observations`` or ``actions`` are skipped. Within an
    episode, iteration stops at the shorter of the two arrays. ``reward`` is
    ``None`` when the episode has no ``rewards`` or the rewards array is
    shorter than the current index.
    """
    for episode in _get_episodes_from_dataset(dataset):
        observations = getattr(episode, "observations", None)
        actions = getattr(episode, "actions", None)
        rewards = getattr(episode, "rewards", None)
        if observations is None or actions is None:
            continue
        for i in range(min(len(observations), len(actions))):
            has_reward = rewards is not None and i < len(rewards)
            yield observations[i], actions[i], (rewards[i] if has_reward else None)
def visualize_dataset(dataset, env, title="Offline Dataset"):
    """Plot a state-visitation heat map and a reward histogram for ``dataset``.

    Args:
        dataset: Dataset whose episodes are read through
            ``_iter_all_observations`` and ``_iter_all_transitions``.
        env: Environment supplying ``measurement``, ``begin``, ``purpose`` and
            ``hazards`` for the grid extent and the overlay markers.
        title: Prefix used in both figure titles.
    """
    upper = env.measurement - 1
    state_visits = np.zeros((env.measurement, env.measurement), dtype=np.float32)
    for obs in _iter_all_observations(dataset):
        x = int(np.clip(int(obs[0]), 0, upper))
        y = int(np.clip(int(obs[1]), 0, upper))
        # Row index is y and column index is x so the image axes read as (x, y).
        state_visits[y, x] += 1

    plt.figure(figsize=(6, 5))
    plt.imshow(state_visits, origin="lower")
    plt.colorbar(label="Visits")
    plt.scatter([env.begin[0]], [env.begin[1]], marker="o", label="start")
    plt.scatter([env.purpose[0]], [env.purpose[1]], marker="*", label="goal")
    if env.hazards:
        hz = np.array(list(env.hazards), dtype=np.int32)
        plt.scatter(hz[:, 0], hz[:, 1], marker="x", label="hazards")
    plt.title(f"{title} — State visitation")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.show()

    # Rewards may arrive as 1-element arrays; take the first element explicitly
    # instead of relying on float() converting a non-scalar array.
    rewards = [
        float(np.asarray(r).reshape(-1)[0])
        for _, _, r in _iter_all_transitions(dataset)
        if r is not None
    ]
    if rewards:
        plt.figure(figsize=(6, 4))
        plt.hist(rewards, bins=60)
        plt.title(f"{title} — Reward distribution")
        plt.xlabel("reward")
        plt.ylabel("count")
        plt.show()
We implement dataset utilities that correctly iterate through episodes rather than assuming flat arrays. We visualize state visitation to understand coverage and data bias in the offline dataset. We also analyze reward distributions to inspect the learning signal available to the agent. Check out the FULL CODES here.
def rollout_eval(env, algo, n_episodes=25, seed=0):
    """Run ``algo`` greedily in ``env`` and summarise return and safety metrics.

    Args:
        env: Environment with ``reset(seed=...)``, ``step(action)``, ``hazards``
            and ``purpose``.
        algo: Object whose ``predict(batch)`` returns one action per row of a
            ``float32`` observation batch.
        n_episodes: Number of evaluation episodes.
        seed: Base seed; episode ``i`` is reset with ``seed + i``.

    Returns:
        Dict with ``return_mean``, ``return_std``, ``len_mean``,
        ``hazard_rate``, ``goal_rate`` (floats) and ``returns`` (``float32``
        array with one entry per episode).
    """
    returns = []
    lengths = []
    hazard_hits = 0
    goal_hits = 0
    for i in range(n_episodes):
        obs, _ = env.reset(seed=seed + i)
        done = False
        total = 0.0
        steps = 0
        while not done:
            batch = np.asarray(obs, dtype=np.float32)[None, ...]
            action = int(algo.predict(batch)[0])
            obs, reward, terminated, truncated, _ = env.step(action)
            total += float(reward)
            steps += 1
            done = bool(terminated or truncated)
            if terminated:
                # Classify how the episode ended from the final cell.
                cell = (int(obs[0]), int(obs[1]))
                if cell in env.hazards:
                    hazard_hits += 1
                if cell == env.purpose:
                    goal_hits += 1
        returns.append(total)
        lengths.append(steps)

    denom = max(1, n_episodes)  # avoid division by zero when no episodes are run
    return {
        "return_mean": float(np.mean(returns)),
        "return_std": float(np.std(returns)),
        "len_mean": float(np.mean(lengths)),
        "hazard_rate": float(hazard_hits / denom),
        "goal_rate": float(goal_hits / denom),
        "returns": np.asarray(returns, dtype=np.float32),
    }
def action_mismatch_rate_vs_data(dataset, algo, sample_obs=7000, seed=0):
    """Estimate how often ``algo`` disagrees with the actions logged in ``dataset``.

    At most the first 80,000 transitions are collected; a random subset of up
    to ``sample_obs`` of them (drawn without replacement) is then scored.

    Args:
        dataset: Dataset read through ``_iter_all_transitions``.
        algo: Object whose ``predict(observations)`` returns one action per row.
        sample_obs: Maximum number of transitions to compare.
        seed: Seed for the subset draw.

    Returns:
        Fraction of sampled transitions where the predicted action differs
        from the logged one.
    """
    rng = np.random.default_rng(seed)
    obs_all = []
    act_all = []
    for obs, action, _ in _iter_all_transitions(dataset):
        obs_all.append(np.asarray(obs, dtype=np.float32))
        # Logged actions may be 1-element arrays; keep the scalar value.
        act_all.append(int(np.asarray(action).reshape(-1)[0]))
        if len(obs_all) >= 80_000:
            break

    obs_all = np.stack(obs_all, axis=0)
    act_all = np.asarray(act_all, dtype=np.int64)
    n_probe = min(sample_obs, len(obs_all))
    idx = rng.choice(len(obs_all), size=n_probe, replace=False)

    predicted = algo.predict(obs_all[idx]).astype(np.int64)
    mismatch = (predicted != act_all[idx]).astype(np.float32)
    return float(mismatch.mean())
def create_discrete_bc(gadget):
    """Build a discrete Behavior Cloning learner for the installed d3rlpy.

    Prefers ``d3rlpy.algos.DiscreteBCConfig`` and falls back to a
    default-constructed ``d3rlpy.algos.DiscreteBC``.

    Args:
        gadget: Device identifier (e.g. ``"cpu"`` or ``"cuda:0"``) passed to
            the config's ``create(device=...)``. Unused on the fallback path.

    Returns:
        The constructed learner.

    Raises:
        RuntimeError: If neither class is available.
    """
    if hasattr(d3rlpy.algos, "DiscreteBCConfig"):
        cfg = make_config(
            d3rlpy.algos.DiscreteBCConfig,
            learning_rate=3e-4,
            batch_size=256,
        )
        return cfg.create(device=gadget)
    if hasattr(d3rlpy.algos, "DiscreteBC"):
        return d3rlpy.algos.DiscreteBC()
    raise RuntimeError("DiscreteBC not available in this d3rlpy version.")
def create_discrete_cql(gadget, conservative_weight=6.0):
    """Build a discrete Conservative Q-Learning learner for the installed d3rlpy.

    Prefers ``d3rlpy.algos.DiscreteCQLConfig``; the options listed below are
    filtered through ``make_config``, so any that the config class does not
    accept are dropped. Falls back to a default-constructed
    ``d3rlpy.algos.DiscreteCQL``.

    Args:
        gadget: Device identifier (e.g. ``"cpu"`` or ``"cuda:0"``) passed to
            the config's ``create(device=...)``. Unused on the fallback path.
        conservative_weight: Weight of the conservative penalty term.

    Returns:
        The constructed learner.

    Raises:
        RuntimeError: If neither class is available.
    """
    if hasattr(d3rlpy.algos, "DiscreteCQLConfig"):
        cfg = make_config(
            d3rlpy.algos.DiscreteCQLConfig,
            learning_rate=3e-4,
            actor_learning_rate=3e-4,
            critic_learning_rate=3e-4,
            temp_learning_rate=3e-4,
            alpha_learning_rate=3e-4,
            batch_size=256,
            conservative_weight=float(conservative_weight),
            n_action_samples=10,
            rollout_interval=0,
        )
        return cfg.create(device=gadget)
    if hasattr(d3rlpy.algos, "DiscreteCQL"):
        algo = d3rlpy.algos.DiscreteCQL()
        if hasattr(algo, "conservative_weight"):
            try:
                algo.conservative_weight = float(conservative_weight)
            except Exception:
                # Best effort: if the attribute cannot be assigned, keep the
                # library default instead of failing construction.
                pass
        return algo
    raise RuntimeError("DiscreteCQL not available in this d3rlpy version.")
We define controlled evaluation routines to measure policy performance without uncontrolled exploration. We compute returns and safety metrics, including hazard and goal rates. We also introduce a mismatch diagnostic to quantify how often learned actions deviate from the dataset behavior. Check out the FULL CODES here.
def primary():
    """Run the whole experiment: collect data, train BC and CQL, evaluate, save.

    Steps, in order: build the grid world, generate 500 behavior-policy
    episodes, wrap them in an MDPDataset and plot it, train Discrete BC and
    Discrete CQL on that dataset, evaluate both with 30 rollouts each, report
    the action-mismatch diagnostic, plot the returns, and save both policies
    under ``/content/offline_rl_artifacts``.
    """
    env = SafetyCriticalGridWorld(
        measurement=15,
        max_steps=80,
        slip_prob=0.05,
        seed=SEED,
    )
    raw_eps = generate_offline_episodes(env, n_episodes=500, epsilon=0.22, seed=SEED)
    dataset = build_mdpdataset(raw_eps)
    print("dataset built:", type(dataset).__name__)
    visualize_dataset(dataset, env, title="Behavior Dataset (Offline)")

    bc = create_discrete_bc(DEVICE)
    cql = create_discrete_cql(DEVICE, conservative_weight=6.0)

    print("\nTraining Discrete BC (offline)...")
    t0 = time.time()
    bc.fit(
        dataset,
        n_steps=25_000,
        n_steps_per_epoch=2_500,
        experiment_name="grid_bc_offline",
    )
    print("BC train sec:", round(time.time() - t0, 2))

    print("\nTraining Discrete CQL (offline)...")
    t0 = time.time()
    cql.fit(
        dataset,
        n_steps=80_000,
        n_steps_per_epoch=8_000,
        experiment_name="grid_cql_offline",
    )
    print("CQL train sec:", round(time.time() - t0, 2))

    print("\nControlled online evaluation (small number of rollouts):")
    bc_metrics = rollout_eval(env, bc, n_episodes=30, seed=SEED + 1000)
    cql_metrics = rollout_eval(env, cql, n_episodes=30, seed=SEED + 2000)
    # The raw per-episode returns are plotted below, so leave them out of the summary.
    print("BC :", {k: v for k, v in bc_metrics.items() if k != "returns"})
    print("CQL:", {k: v for k, v in cql_metrics.items() if k != "returns"})

    print("\nOOD-ish diagnostic (policy action mismatch vs data action at same states):")
    bc_mismatch = action_mismatch_rate_vs_data(dataset, bc, sample_obs=7000, seed=SEED + 1)
    cql_mismatch = action_mismatch_rate_vs_data(dataset, cql, sample_obs=7000, seed=SEED + 2)
    print("BC mismatch rate :", bc_mismatch)
    print("CQL mismatch rate:", cql_mismatch)

    plt.figure(figsize=(6, 4))
    labels = ["BC", "CQL"]
    means = [bc_metrics["return_mean"], cql_metrics["return_mean"]]
    stds = [bc_metrics["return_std"], cql_metrics["return_std"]]
    plt.bar(labels, means, yerr=stds)
    plt.ylabel("Return")
    plt.title("Online Rollout Return (Controlled)")
    plt.show()

    plt.figure(figsize=(6, 4))
    plt.plot(np.sort(bc_metrics["returns"]), label="BC")
    plt.plot(np.sort(cql_metrics["returns"]), label="CQL")
    plt.xlabel("Episode (sorted)")
    plt.ylabel("Return")
    plt.title("Return Distribution (Sorted)")
    plt.legend()
    plt.show()

    out_dir = "/content/offline_rl_artifacts"
    os.makedirs(out_dir, exist_ok=True)
    bc_path = os.path.join(out_dir, "grid_bc_policy.pt")
    cql_path = os.path.join(out_dir, "grid_cql_policy.pt")
    # save_policy is guarded because not every learner object exposes it.
    if hasattr(bc, "save_policy"):
        bc.save_policy(bc_path)
        print("Saved BC policy:", bc_path)
    if hasattr(cql, "save_policy"):
        cql.save_policy(cql_path)
        print("Saved CQL policy:", cql_path)
    print("\nDone.")
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    primary()
We train both Behavior Cloning and Conservative Q-Learning agents purely from offline data. We compare their performance using controlled rollouts and diagnostic metrics. We finalize the workflow by saving trained policies and summarizing safety-aware learning outcomes.
In conclusion, we demonstrated that Conservative Q-Learning yields a more reliable policy than simple imitation when learning from historical data in safety-sensitive environments. By comparing offline training outcomes, controlled online evaluations, and action-distribution mismatches, we illustrated how conservatism helps reduce risky, out-of-distribution behavior. Overall, we presented a complete, reproducible offline RL workflow that we can extend to more complex domains such as robotics, healthcare, or finance without compromising safety.
Check out the FULL CODES here. Also, feel free to follow us on Twitter and don’t forget to join our 100k+ ML SubReddit and subscribe to our Newsletter. Wait! Are you on Telegram? Now you can join us on Telegram as well.
