"""Integration tests for dvnets tasks.""" from absl.testing import absltest from absl.testing import parameterized from cliport import tasks from cliport.environments import environment ASSETS_PATH = 'cliport/environments/assets/' class TaskTest(parameterized.TestCase): def _create_env(self): assets_root = ASSETS_PATH env = environment.Environment(assets_root) env.seed(0) return env def _run_oracle_in_env(self, env): agent = env.task.oracle(env) obs = env.reset() info = None done = False for _ in range(10): act = agent.act(obs, info) obs, _, done, info = env.step(act) if done: break @parameterized.named_parameters(( # demo conditioned 'AlignBoxCorner', tasks.AlignBoxCorner(), ), ( 'AssemblingKits', tasks.AssemblingKits(), ), ( 'AssemblingKitsEasy', tasks.AssemblingKitsEasy(), ), ( 'BlockInsertion', tasks.BlockInsertion(), ), ( 'ManipulatingRope', tasks.ManipulatingRope(), ), ( 'PackingBoxes', tasks.PackingBoxes(), ), ( 'PalletizingBoxes', tasks.PalletizingBoxes(), ), ( 'PlaceRedInGreen', tasks.PlaceRedInGreen(), ), ( 'StackBlockPyramid', tasks.StackBlockPyramid(), ), ( 'SweepingPiles', tasks.SweepingPiles(), ), ( 'TowersOfHanoi', tasks.TowersOfHanoi(), # goal conditioned ), ( 'AlignRope', tasks.AlignRope(), ), ( 'AssemblingKitsSeqSeenColors', tasks.AssemblingKitsSeqSeenColors(), ), ( 'AssemblingKitsSeqUnseenColors', tasks.AssemblingKitsSeqUnseenColors(), ), ( 'AssemblingKitsSeqFull', tasks.AssemblingKitsSeqFull(), ), ( 'PackingShapes', tasks.PackingShapes(), ), ( 'PackingBoxesPairsSeenColors', tasks.PackingBoxesPairsSeenColors(), ), ( 'PackingBoxesPairsUnseenColors', tasks.PackingBoxesPairsUnseenColors(), ), ( 'PackingBoxesPairsFull', tasks.PackingBoxesPairsFull(), ), ( 'PackingSeenGoogleObjectsSeq', tasks.PackingSeenGoogleObjectsSeq(), ), ( 'PackingUnseenGoogleObjectsSeq', tasks.PackingUnseenGoogleObjectsSeq(), ), ( 'PackingSeenGoogleObjectsGroup', tasks.PackingSeenGoogleObjectsGroup(), ), ( 'PackingUnseenGoogleObjectsGroup', tasks.PackingUnseenGoogleObjectsGroup(), ), ( 'PutBlockInBowlSeenColors', tasks.PutBlockInBowlSeenColors(), ), ( 'PutBlockInBowlUnseenColors', tasks.PutBlockInBowlUnseenColors(), ), ( 'PutBlockInBowlFull', tasks.PutBlockInBowlFull(), ), ( 'StackBlockPyramidSeqSeenColors', tasks.StackBlockPyramidSeqSeenColors(), ), ( 'StackBlockPyramidSeqUnseenColors', tasks.StackBlockPyramidSeqUnseenColors(), ), ( 'StackBlockPyramidSeqFull', tasks.StackBlockPyramidSeqFull(), ), ( 'SeparatingPilesSeenColors', tasks.SeparatingPilesUnseenColors(), ), ( 'SeparatingPilesUnseenColors', tasks.SeparatingPilesUnseenColors(), ), ( 'SeparatingPilesFull', tasks.SeparatingPilesFull(), ), ( 'TowersOfHanoiSeqSeenColors', tasks.TowersOfHanoiSeqSeenColors(), ), ( 'TowersOfHanoiSeqUnseenColors', tasks.TowersOfHanoiSeqUnseenColors(), ), ( 'TowersOfHanoiSeqFull', tasks.TowersOfHanoiSeqFull(), )) def test_all_tasks(self, dvnets_task): env = self._create_env() env.set_task(dvnets_task) self._run_oracle_in_env(env) if __name__ == '__main__': absltest.main()