# GenSim / cliport/demos.py
# LeroyWaa's picture
# add gensim code
# 8fc2b4e
# raw
# history blame
# No virus
# 3.88 kB
"""Data collection script."""
import os
import hydra
import numpy as np
import random
from cliport import tasks
from cliport.dataset import RavensDataset
from cliport.environments.environment import Environment
import IPython
import random
def _print_highlighted_traceback():
    """Print the traceback of the exception currently being handled.

    Must be called from inside an ``except`` block. Uses pygments for terminal
    syntax highlighting when it is importable and falls back to the plain
    traceback text otherwise, so a missing optional dependency cannot abort
    data collection from within the error handler.
    """
    import traceback

    trace = traceback.format_exc()
    try:
        from pygments import highlight
        from pygments.formatters import TerminalFormatter
        from pygments.lexers import PythonLexer
    except ImportError:
        print(trace)
        return
    print(highlight(trace, PythonLexer(), TerminalFormatter()))


@hydra.main(config_path='./cfg', config_name='data')
def main(cfg):
    """Collect oracle demonstrations for one task and add them to a dataset.

    Episodes are rolled out with the task's scripted oracle until the dataset
    holds ``cfg['n']`` episodes or ``3 * cfg['n']`` attempts have been made in
    this run, whichever comes first. An episode that raises during rollout is
    reported and skipped; only episodes with total reward above 0.99 are
    added to the dataset (and only when ``cfg['save_data']`` is truthy).

    Args:
        cfg: Hydra config. Keys read here: ``assets_root``, ``disp``,
            ``shared_memory``, ``record`` (including ``record.save_video``),
            ``task``, ``mode``, ``save_data``, ``data_dir``, ``n`` and,
            optionally, ``regenerate_data``.

    Raises:
        Exception: If ``dataset.max_seed`` is negative and ``mode`` is not one
            of ``train``/``val``/``test``, or if a ``val`` seed would pass the
            first test seed.
    """
    # Initialize environment and task.
    env = Environment(
        cfg['assets_root'],
        disp=cfg['disp'],
        shared_memory=cfg['shared_memory'],
        hz=480,
        record_cfg=cfg['record']
    )
    task = tasks.names[cfg['task']]()
    task.mode = cfg['mode']
    record = cfg['record']['save_video']
    save_data = cfg['save_data']

    # Initialize scripted oracle agent and dataset.
    agent = task.oracle(env)
    data_path = os.path.join(cfg['data_dir'], "{}-{}".format(cfg['task'], task.mode))
    dataset = RavensDataset(data_path, cfg, n_demos=0, augment=False)
    print(f"Saving to: {data_path}")
    print(f"Mode: {task.mode}")

    # Train seeds are even and val/test seeds are odd. Test seeds are offset by 10000
    test_seed_start = -1 + 10000
    seed = dataset.max_seed
    max_eps = 3 * cfg['n']

    if seed < 0:
        if task.mode == 'train':
            seed = -2
        elif task.mode == 'val':  # NOTE: beware of increasing val set to >100
            seed = -1
        elif task.mode == 'test':
            seed = test_seed_start
        else:
            raise Exception("Invalid mode. Valid options: train, val, test")

    # NOTE(review): this triggers on the mere presence of the key, even when
    # its value is false -- confirm that is intended.
    if 'regenerate_data' in cfg:
        dataset.n_episodes = 0

    curr_run_eps = 0

    # Collect training data from oracle demonstrations.
    while dataset.n_episodes < cfg['n'] and curr_run_eps < max_eps:
        episode, total_reward = [], 0
        seed += 2

        # Set seeds.
        np.random.seed(seed)
        random.seed(seed)
        print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, cfg['n'], seed))

        # Unlikely, but a safety check to prevent leaks. It is deliberately
        # outside the try block below: inside it, the broad handler would
        # swallow this error and the loop would keep resetting the
        # environment with ever larger seeds until max_eps is exhausted.
        if task.mode == 'val' and seed > test_seed_start:
            raise Exception("!!! Seeds for val set will overlap with the test set !!!")

        # Count every attempt, including failed ones, so the loop always exits.
        curr_run_eps += 1

        try:
            env.set_task(task)
            obs = env.reset()
            info = env.info
            reward = 0

            # Start video recording (NOTE: super slow)
            if record:
                env.start_rec(f'{dataset.n_episodes+1:06d}')

            # Rollout expert policy
            for _ in range(task.max_steps):
                act = agent.act(obs, info)
                episode.append((obs, act, reward, info))
                lang_goal = info['lang_goal']
                obs, reward, done, info = env.step(act)
                total_reward += reward
                print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')
                if done:
                    break

            if record:
                env.end_rec()

        except Exception:
            # Best effort: report the failed episode and move on to the next seed.
            _print_highlighted_traceback()
            if record:
                env.end_rec()
            continue

        # Final observation of the episode; no action follows it.
        episode.append((obs, None, reward, info))

        # Only save completed demonstrations.
        if save_data and total_reward > 0.99:
            dataset.add(seed, episode)

        # NOTE(review): runs after every successful rollout, whether or not the
        # demo was saved above -- confirm the intended nesting.
        if hasattr(env, 'blender_recorder'):
            blender_path = '{}/blender_demo_{}.pkl'.format(data_path, dataset.n_episodes)
            print("blender pickle saved to ", blender_path)
            env.blender_recorder.save(blender_path)
# Script entry point: main() is called without arguments because it is wrapped
# by the @hydra.main decorator, which is expected to supply `cfg`.
if __name__ == '__main__':
    main()