File size: 5,560 Bytes
ad48e75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import random
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Type

# from lhotse import CutSet
# from lhotse.dataset.collation import collate_features
# from lhotse.dataset.input_strategies import (
#     ExecutorType,
#     PrecomputedFeatures,
#     _get_executor,
# )
# from lhotse.utils import fastcopy


class PromptedFeatures:
    """Container that keeps a ``prompts`` object and a ``features`` object together.

    The class only delegates to its members: ``to`` is forwarded to both of
    them, while ``sum`` and ``ndim`` are forwarded to ``features`` alone.
    NOTE(review): the members are presumably torch tensors (they must offer
    ``.to(device)``, and ``features`` must offer ``.sum()`` / ``.ndim``) --
    confirm against the callers.
    """

    def __init__(self, prompts, features):
        """Store the two members unchanged (no copy, no validation)."""
        self.prompts, self.features = prompts, features

    def to(self, device):
        """Return a new ``PromptedFeatures`` built from ``member.to(device)``.

        ``self`` is left untouched; prompts are converted first, then features.
        """
        moved_prompts = self.prompts.to(device)
        moved_features = self.features.to(device)
        return PromptedFeatures(moved_prompts, moved_features)

    def sum(self):
        """Return ``self.features.sum()``; the prompts do not contribute."""
        total = self.features.sum()
        return total

    @property
    def ndim(self):
        """``ndim`` of the features member (prompts are ignored)."""
        features = self.features
        return features.ndim

    @property
    def data(self):
        """Both members as the tuple ``(prompts, features)``."""
        pair = (self.prompts, self.features)
        return pair


# class PromptedPrecomputedFeatures(PrecomputedFeatures):
#     """
#     :class:`InputStrategy` that reads pre-computed features, whose manifests
#     are attached to cuts, from disk.
#
#     It automatically pads the feature matrices with pre or post feature.
#
#     .. automethod:: __call__
#     """
#
#     def __init__(
#         self,
#         dataset: str,
#         cuts: CutSet,
#         num_workers: int = 0,
#         executor_type: Type[ExecutorType] = ThreadPoolExecutor,
#     ) -> None:
#         super(PromptedPrecomputedFeatures, self).__init__(
#             num_workers, executor_type
#         )
#
#         self.utt2neighbors = defaultdict(lambda: [])
#
#         if dataset.lower() == "libritts":
#             # 909_131041_000013_000002
#             # 909_131041_000013_000003
#             speaker2utts = defaultdict(lambda: [])
#
#             utt2cut = {}
#             for cut in cuts:
#                 speaker = cut.supervisions[0].speaker
#                 speaker2utts[speaker].append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             for spk in speaker2utts:
#                 uttids = sorted(speaker2utts[spk])
#                 # Use the sorted order of the keys to find the previous utterance.
#                 # The keys have the structure speaker_book_x_y, e.g. 1089_134691_000004_000001
#                 if len(uttids) == 1:
#                     self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#                     continue
#
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2prevutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
#
#                 for utt in utt2postutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
#         elif dataset.lower() == "ljspeech":
#             utt2cut = {}
#             uttids = []
#             for cut in cuts:
#                 uttids.append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             if len(uttids) == 1:
#                 self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#             else:
#                 # Using the property of sorted keys to find previous utterance
#                 # The keys have the structure: LJ001-0010
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2postutt:
#                     postutt = utt2postutt[utt]
#                     if utt[:5] == postutt[:5]:
#                         self.utt2neighbors[utt].append(utt2cut[postutt])
#
#                 for utt in utt2prevutt:
#                     prevutt = utt2prevutt[utt]
#                     if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
#                         self.utt2neighbors[utt].append(utt2cut[prevutt])
#         else:
#             raise ValueError
#
#     def __call__(
#         self, cuts: CutSet
#     ) -> Tuple[PromptedFeatures, PromptedFeatures]:
#         """
#         Reads the pre-computed features from disk/other storage.
#         The returned shape is ``(B, T, F) => (batch_size, num_frames, num_features)``.
#
#         :return: a pair of ``PromptedFeatures``: the collated (prompts, features), and
#             the matching (prompts_lens, features_lens), i.e. ``num_frames`` of each cut before padding.
#         """
#         features, features_lens = collate_features(
#             cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         prompts_cuts = []
#         for k, cut in enumerate(cuts):
#             prompts_cut = random.choice(self.utt2neighbors[cut.id])
#             prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
#
#         mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
#         # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
#         #     max_duration=mini_duration,
#         #     offset_type="random",
#         #     preserve_id=True,
#         # )
#         prompts_cuts = CutSet(
#             cuts={k: cut for k, cut in enumerate(prompts_cuts)}
#         ).truncate(
#             max_duration=mini_duration,
#             offset_type="random",
#             preserve_id=False,
#         )
#
#         prompts, prompts_lens = collate_features(
#             prompts_cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         return PromptedFeatures(prompts, features), PromptedFeatures(
#             prompts_lens, features_lens
#         )