mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-05 20:40:12 -08:00
Fix raw.arrow missing rows (#1145)
* fix raw.arrow missing rows

---------

Co-authored-by: SWivid <swivid@qq.com>
This commit is contained in:
@@ -208,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
|
|||||||
out_dir.mkdir(exist_ok=True, parents=True)
|
out_dir.mkdir(exist_ok=True, parents=True)
|
||||||
print(f"\nSaving to {out_dir} ...")
|
print(f"\nSaving to {out_dir} ...")
|
||||||
|
|
||||||
# Save dataset with improved batch size for better I/O performance
|
|
||||||
raw_arrow_path = out_dir / "raw.arrow"
|
raw_arrow_path = out_dir / "raw.arrow"
|
||||||
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
|
with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
|
||||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||||
writer.write(line)
|
writer.write(line)
|
||||||
|
writer.finalize()
|
||||||
|
|
||||||
# Save durations to JSON
|
# Save durations to JSON
|
||||||
dur_json_path = out_dir / "duration.json"
|
dur_json_path = out_dir / "duration.json"
|
||||||
|
|||||||
@@ -181,6 +181,7 @@ def main():
|
|||||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||||
writer.write(line)
|
writer.write(line)
|
||||||
|
writer.finalize()
|
||||||
|
|
||||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ def main():
|
|||||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||||
writer.write(line)
|
writer.write(line)
|
||||||
|
writer.finalize()
|
||||||
|
|
||||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||||
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ def main():
|
|||||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||||
writer.write(line)
|
writer.write(line)
|
||||||
|
writer.finalize()
|
||||||
|
|
||||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ def main():
|
|||||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||||
writer.write(line)
|
writer.write(line)
|
||||||
|
writer.finalize()
|
||||||
|
|
||||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||||
|
|||||||
@@ -796,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
|||||||
min_second = round(min(duration_list), 2)
|
min_second = round(min(duration_list), 2)
|
||||||
max_second = round(max(duration_list), 2)
|
max_second = round(max(duration_list), 2)
|
||||||
|
|
||||||
with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
|
with ArrowWriter(path=file_raw) as writer:
|
||||||
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
|
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
|
||||||
writer.write(line)
|
writer.write(line)
|
||||||
|
writer.finalize()
|
||||||
|
|
||||||
with open(file_duration, "w") as f:
|
with open(file_duration, "w") as f:
|
||||||
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user