expl.py
import tensorflow as tf
from tensorflow_probability import distributions as tfd

import agent
import common


class Random(common.Module):
  """Baseline exploration policy that ignores the world model and samples
  uniform random actions."""

  def __init__(self, config, act_space, wm, tfstep, reward):
    self.config = config
    self.act_space = act_space

  def actor(self, feat):
    shape = feat.shape[:-1] + self.act_space.shape
    if self.config.actor.dist == 'onehot':
      # Uniform categorical over discrete actions (all logits zero).
      return common.OneHotDist(tf.zeros(shape))
    else:
      # Uniform continuous actions in [-1, 1].
      dist = tfd.Uniform(-tf.ones(shape), tf.ones(shape))
      return tfd.Independent(dist, 1)

  def train(self, start, context, data):
    return None, {}


class Plan2Explore(common.Module):
  """Trains an ensemble of predictor heads and uses their disagreement as an
  intrinsic reward for the exploration actor-critic."""

  def __init__(self, config, act_space, wm, tfstep, reward):
    self.config = config
    self.reward = reward
    self.wm = wm
    self.ac = agent.ActorCritic(config, act_space, tfstep)
    self.actor = self.ac.actor
    stoch_size = config.rssm.stoch
    if config.rssm.discrete:
      stoch_size *= config.rssm.discrete
    # Output size of each ensemble head depends on which model quantity the
    # ensemble is trained to predict.
    size = {
        'embed': 32 * config.encoder.cnn_depth,
        'stoch': stoch_size,
        'deter': config.rssm.deter,
        'feat': config.rssm.stoch + config.rssm.deter,
    }[self.config.disag_target]
    self._networks = [
        common.MLP(size, **config.expl_head)
        for _ in range(config.disag_models)]
    self.opt = common.Optimizer('expl', **config.expl_opt)
    # Running normalizers for the extrinsic and intrinsic reward streams.
    self.extr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm)
    self.intr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm)

  def train(self, start, context, data):
    metrics = {}
    stoch = start['stoch']
    if self.config.rssm.discrete:
      # Flatten the discrete latent from (..., groups, classes) to a vector.
      stoch = tf.reshape(
          stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1],))
    target = {
        'embed': context['embed'],
        'stoch': stoch,
        'deter': start['deter'],
        'feat': context['feat'],
    }[self.config.disag_target]
    inputs = context['feat']
    if self.config.disag_action_cond:
      action = tf.cast(data['action'], inputs.dtype)
      inputs = tf.concat([inputs, action], -1)
    metrics.update(self._train_ensemble(inputs, target))
    # Train the exploration actor-critic on imagined rollouts scored by the
    # intrinsic disagreement reward.
    metrics.update(self.ac.train(
        self.wm, start, data['is_terminal'], self._intr_reward))
    return None, metrics

  def _intr_reward(self, seq):
    inputs = seq['feat']
    if self.config.disag_action_cond:
      action = tf.cast(seq['action'], inputs.dtype)
      inputs = tf.concat([inputs, action], -1)
    preds = [head(inputs).mode() for head in self._networks]
    # Disagreement: standard deviation across ensemble predictions, averaged
    # over the feature dimension.
    disag = tf.math.reduce_mean(tf.math.reduce_std(tf.stack(preds), 0), -1)
    if self.config.disag_log:
      disag = tf.math.log(disag)
    reward = self.config.expl_intr_scale * self.intr_rewnorm(disag)[0]
    if self.config.expl_extr_scale:
      reward += self.config.expl_extr_scale * self.extr_rewnorm(
          self.reward(seq))[0]
    return reward

  def _train_ensemble(self, inputs, targets):
    if self.config.disag_offset:
      # Shift targets forward in time so the ensemble predicts future steps.
      targets = targets[:, self.config.disag_offset:]
      inputs = inputs[:, :-self.config.disag_offset]
    targets = tf.stop_gradient(targets)
    inputs = tf.stop_gradient(inputs)
    with tf.GradientTape() as tape:
      preds = [head(inputs) for head in self._networks]
      # Maximize the log-likelihood of the targets under every ensemble head.
      loss = -sum([tf.reduce_mean(pred.log_prob(targets)) for pred in preds])
    metrics = self.opt(tape, loss, self._networks)
    return metrics
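
# Editor's note (sketch, not part of the original file): with disag_offset = 1
# the ensemble above is trained as a one-step forward model. For a sequence of
# length T, inputs[:, :-1] (features, optionally concatenated with actions)
# predict targets[:, 1:], i.e. the pairs (x_0 -> y_1), ..., (x_{T-2} -> y_{T-1}).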


class ModelLoss(common.Module):
  """Uses a learned prediction of the world-model loss as intrinsic reward."""

  def __init__(self, config, act_space, wm, tfstep, reward):
    self.config = config
    self.reward = reward
    self.wm = wm
    self.ac = agent.ActorCritic(config, act_space, tfstep)
    self.actor = self.ac.actor
    # Scalar head that regresses the chosen model-loss signal from features.
    self.head = common.MLP([], **self.config.expl_head)
    self.opt = common.Optimizer('expl', **self.config.expl_opt)

  def train(self, start, context, data):
    metrics = {}
    target = tf.cast(context[self.config.expl_model_loss], tf.float32)
    with tf.GradientTape() as tape:
      loss = -tf.reduce_mean(self.head(context['feat']).log_prob(target))
    metrics.update(self.opt(tape, loss, self.head))
    metrics.update(self.ac.train(
        self.wm, start, data['is_terminal'], self._intr_reward))
    return None, metrics

  def _intr_reward(self, seq):
    reward = self.config.expl_intr_scale * self.head(seq['feat']).mode()
    if self.config.expl_extr_scale:
      reward += self.config.expl_extr_scale * self.reward(seq)
    return reward
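

# --- Illustrative sketch (editor's addition, not part of the original file) ---
# A minimal, self-contained demonstration of the ensemble-disagreement
# intrinsic reward computed by Plan2Explore._intr_reward above, using random
# stand-in predictions instead of trained MLP heads. All shapes and the sizes
# below are hypothetical.
if __name__ == '__main__':
  num_models, batch, horizon, dim = 5, 2, 10, 16
  # Stand-in for [head(inputs).mode() for head in self._networks].
  preds = [tf.random.normal((batch, horizon, dim)) for _ in range(num_models)]
  # Std across ensemble members, averaged over feature dimensions.
  disag = tf.math.reduce_mean(tf.math.reduce_std(tf.stack(preds), 0), -1)
  print('intrinsic reward shape:', disag.shape)  # (batch, horizon)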