diff --git a/.gitignore b/.gitignore index 73c66b60..506d5d35 100644 --- a/.gitignore +++ b/.gitignore @@ -168,7 +168,5 @@ Dockerfile start_docker.sh start.sh -checkpoints - # Mac .DS_Store diff --git a/app_sadtalker.py b/app_sadtalker.py index 1401a600..deed3937 100644 --- a/app_sadtalker.py +++ b/app_sadtalker.py @@ -1,111 +1,185 @@ -import os, sys +import os +import sys import gradio as gr -from src.gradio_demo import SadTalker +import logging +from src.gradio_demo import SadTalker +# Set up logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) +# Check if running in webui environment try: - import webui # in webui + import webui in_webui = True -except: +except ImportError: in_webui = False +def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None): + """Creates a Gradio interface for SadTalker.""" + # Validate paths + try: + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint path {checkpoint_path} does not exist.") + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config path {config_path} does not exist.") + logger.info(f"Checkpoint path: {checkpoint_path}") + logger.info(f"Config path: {config_path}") + except FileNotFoundError as e: + logger.error(str(e)) + raise -def toggle_audio_file(choice): - if choice == False: - return gr.update(visible=True), gr.update(visible=False) - else: - return gr.update(visible=False), gr.update(visible=True) - -def ref_video_fn(path_of_ref_video): - if path_of_ref_video is not None: - return gr.update(value=True) - else: - return gr.update(value=False) + # Validate checkpoint files + required_checkpoints = [ + 'SadTalker_V0.0.2_256.safetensors', # Updated to match logs + 'mapping_00229-model.pth.tar' + ] + for chk in required_checkpoints: + chk_path = os.path.join(checkpoint_path, chk) + if not os.path.exists(chk_path): + logger.error(f"Missing checkpoint file: 
{chk_path}") + raise FileNotFoundError(f"Missing checkpoint file: {chk_path}") -def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None): + # Initialize SadTalker + try: + sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True) + logger.info("SadTalker initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize SadTalker: {str(e)}") + raise - sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True) + with gr.Blocks() as sadtalker_interface: + gr.Markdown( + """ +
+

😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)

+ Arxiv    + Homepage    + Github +
+ """ + ) - with gr.Blocks(analytics_enabled=False) as sadtalker_interface: - gr.Markdown("

😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)

\ - Arxiv       \ - Homepage       \ - Github
") - - with gr.Row().style(equal_height=False): + with gr.Row(): + # Input Column with gr.Column(variant='panel'): - with gr.Tabs(elem_id="sadtalker_source_image"): - with gr.TabItem('Upload image'): - with gr.Row(): - source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512) + gr.Markdown("### Upload Inputs") + source_image = gr.File( + label="Source Image (PNG/JPG, ideally 512x512)", + file_types=[".png", ".jpg", ".jpeg"], + interactive=True + ) + image_preview = gr.Image(label="Image Preview", interactive=False) + image_status = gr.Textbox(label="Image Upload Status", interactive=False) + driven_audio = gr.Audio( + label="Input Audio (WAV/MP3)", + type="filepath", + interactive=True + ) + audio_status = gr.Textbox(label="Audio Upload Status", interactive=False) + error_message = gr.Textbox(label="Error Message", interactive=False, visible=False) - with gr.Tabs(elem_id="sadtalker_driven_audio"): - with gr.TabItem('Upload OR TTS'): - with gr.Column(variant='panel'): - driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath") + # Log upload events + source_image.upload( + fn=lambda x: (f"Image uploaded: {x}", x), + inputs=[source_image], + outputs=[image_status, image_preview] + ) + driven_audio.upload( + fn=lambda x: f"Audio uploaded: {x}", + inputs=[driven_audio], + outputs=[audio_status] + ) - if sys.platform != 'win32' and not in_webui: - from src.utils.text2speech import TTSTalker - tts_talker = TTSTalker() - with gr.Column(variant='panel'): - input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we genreate the audio from text using @Coqui.ai TTS.") - tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary') - tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio]) - - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="sadtalker_checkbox"): - with 
gr.TabItem('Settings'): - gr.Markdown("need help? please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more detials") - with gr.Column(variant='panel'): - # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width - # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width - pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) # - size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") # - preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?") - is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)") - batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2) - enhancer = gr.Checkbox(label="GFPGAN as Face enhancer") - submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary') - - with gr.Tabs(elem_id="sadtalker_genearted"): - gen_video = gr.Video(label="Generated video", format="mp4").style(width=256) - - if warpfn: - submit.click( - fn=warpfn(sad_talker.test), - inputs=[source_image, - driven_audio, - preprocess_type, - is_still_mode, - enhancer, - batch_size, - size_of_image, - pose_style - ], - outputs=[gen_video] - ) - else: - submit.click( - fn=sad_talker.test, - inputs=[source_image, - driven_audio, - preprocess_type, - is_still_mode, - enhancer, - batch_size, - size_of_image, - pose_style - ], - outputs=[gen_video] - ) + # Settings and Output Column + with gr.Column(variant='panel'): + gr.Markdown("### Settings and Output") + pose_style = gr.Slider( + minimum=0, + maximum=46, + step=1, + label="Pose Style", + value=0 + ) + size_of_image = gr.Radio( + [256, 512], + value=256, + label="Face 
Model Resolution", + info="Use 256/512 model?" + ) + preprocess_type = gr.Radio( + ['crop', 'resize', 'full', 'extcrop', 'extfull'], + value='crop', + label="Preprocess", + info="How to handle input image?" + ) + is_still_mode = gr.Checkbox( + label="Still Mode (fewer head motions, works with preprocess 'full')" + ) + batch_size = gr.Slider( + label="Batch Size in Generation", + step=1, + minimum=1, + maximum=10, + value=2 + ) + enhancer = gr.Checkbox(label="GFPGAN as Face Enhancer") + submit = gr.Button('Generate', variant='primary') + gen_video = gr.Video( + label="Generated Video (MP4)", + interactive=False + ) - return sadtalker_interface - + # Submit button logic with input validation + def validate_and_run(source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image, pose_style): + try: + logger.info(f"Received inputs: image={source_image}, audio={driven_audio}") + # Handle gr.File and gr.Audio inputs + image_path = source_image.name if hasattr(source_image, 'name') else source_image + audio_path = driven_audio if driven_audio else None + logger.info(f"Processed paths: image={image_path}, audio={audio_path}") + if not image_path or not os.path.exists(image_path): + logger.error("Invalid or missing source image") + return None, "Error: Please upload a valid source image." + if not audio_path or not os.path.exists(audio_path): + logger.error("Invalid or missing audio file") + return None, "Error: Please upload a valid audio file." + if not isinstance(audio_path, (str, bytes, os.PathLike)): + logger.error(f"Invalid audio path type: {type(audio_path)}") + return None, f"Error: Audio input must be a file path, not {type(audio_path)}." 
+ logger.info("Inputs validated successfully") + result = (warpfn(sad_talker.test) if warpfn else sad_talker.test)( + image_path, audio_path, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image, pose_style + ) + logger.info("Video generation completed") + return result, None + except Exception as e: + logger.error(f"Error during video generation: {str(e)}") + return None, f"Error: {str(e)}" -if __name__ == "__main__": + submit.click( + fn=validate_and_run, + inputs=[ + source_image, + driven_audio, + preprocess_type, + is_still_mode, + enhancer, + batch_size, + size_of_image, + pose_style + ], + outputs=[gen_video, error_message] + ) - demo = sadtalker_demo() - demo.queue() - demo.launch() + return sadtalker_interface +if __name__ == "__main__": + try: + demo = sadtalker_demo() + demo.queue() + demo.launch() + except Exception as e: + logger.error(f"Failed to launch Gradio interface: {str(e)}") + raise diff --git a/launcher.py b/launcher.py index 17ce9f1a..222d5813 100644 --- a/launcher.py +++ b/launcher.py @@ -201,4 +201,4 @@ def start(): if __name__ == "__main__": prepare_environment() - start() \ No newline at end of file + start() diff --git a/src/audio2pose_models/discriminator.py b/src/audio2pose_models/discriminator.py index 339c38e4..a73fde93 100644 --- a/src/audio2pose_models/discriminator.py +++ b/src/audio2pose_models/discriminator.py @@ -73,4 +73,4 @@ def forward(self, x): x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) x = self.seq(x) x = x.squeeze(1) - return x \ No newline at end of file + return x diff --git a/src/audio2pose_models/networks.py b/src/audio2pose_models/networks.py index 8aa0b139..3390eb0b 100644 --- a/src/audio2pose_models/networks.py +++ b/src/audio2pose_models/networks.py @@ -137,4 +137,4 @@ def __init__(self, input_encoder, input_decoder, output_dim): def forward(self, x1, x2): out = self.conv_encoder(x1) + self.conv_decoder(x2) out = self.conv_attn(out) - return out * x2 \ No newline at end of file + return 
out * x2 diff --git a/src/audio2pose_models/res_unet.py b/src/audio2pose_models/res_unet.py index f2611e1d..6f046062 100644 --- a/src/audio2pose_models/res_unet.py +++ b/src/audio2pose_models/res_unet.py @@ -62,4 +62,4 @@ def forward(self, x): output = self.output_layer(x10) - return output \ No newline at end of file + return output diff --git a/src/face3d/models/arcface_torch/backbones/__init__.py b/src/face3d/models/arcface_torch/backbones/__init__.py index 55bd4c5d..fe2e219f 100644 --- a/src/face3d/models/arcface_torch/backbones/__init__.py +++ b/src/face3d/models/arcface_torch/backbones/__init__.py @@ -22,4 +22,4 @@ def get_model(name, **kwargs): num_features = kwargs.get("num_features", 512) return get_mbf(fp16=fp16, num_features=num_features) else: - raise ValueError() \ No newline at end of file + raise ValueError() diff --git a/src/face3d/models/arcface_torch/backbones/mobilefacenet.py b/src/face3d/models/arcface_torch/backbones/mobilefacenet.py index 87731491..36ccc303 100644 --- a/src/face3d/models/arcface_torch/backbones/mobilefacenet.py +++ b/src/face3d/models/arcface_torch/backbones/mobilefacenet.py @@ -127,4 +127,4 @@ def forward(self, x): def get_mbf(fp16, num_features): - return MobileFaceNet(fp16, num_features) \ No newline at end of file + return MobileFaceNet(fp16, num_features) diff --git a/src/face3d/models/arcface_torch/torch2onnx.py b/src/face3d/models/arcface_torch/torch2onnx.py index fc26ab82..60bd090a 100644 --- a/src/face3d/models/arcface_torch/torch2onnx.py +++ b/src/face3d/models/arcface_torch/torch2onnx.py @@ -6,7 +6,7 @@ def convert_onnx(net, path_module, output, opset=11, simplify=False): assert isinstance(net, torch.nn.Module) img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) - img = img.astype(np.float) + img = img.astype(float) img = (img / 255. 
- 0.5) / 0.5 # torch style norm img = img.transpose((2, 0, 1)) img = torch.from_numpy(img).unsqueeze(0).float() diff --git a/src/face3d/models/arcface_torch/utils/utils_config.py b/src/face3d/models/arcface_torch/utils/utils_config.py index 0c02eaf7..8d0efbf7 100644 --- a/src/face3d/models/arcface_torch/utils/utils_config.py +++ b/src/face3d/models/arcface_torch/utils/utils_config.py @@ -13,4 +13,4 @@ def get_config(config_file): cfg.update(job_cfg) if cfg.output is None: cfg.output = osp.join('work_dirs', temp_module_name) - return cfg \ No newline at end of file + return cfg diff --git a/src/face3d/models/bfm.py b/src/face3d/models/bfm.py index a75db682..13a78e06 100644 --- a/src/face3d/models/bfm.py +++ b/src/face3d/models/bfm.py @@ -328,4 +328,4 @@ def compute_for_render_woRotation(self, coeffs): if __name__ == '__main__': - transferBFM09() \ No newline at end of file + transferBFM09() diff --git a/src/face3d/util/load_mats.py b/src/face3d/util/load_mats.py index f9a6fcc7..d4a0f83c 100644 --- a/src/face3d/util/load_mats.py +++ b/src/face3d/util/load_mats.py @@ -117,4 +117,4 @@ def load_lm3d(bfm_folder): if __name__ == '__main__': - transferBFM09() \ No newline at end of file + transferBFM09() diff --git a/src/face3d/util/my_awing_arch.py b/src/face3d/util/my_awing_arch.py index cd565617..308752f6 100644 --- a/src/face3d/util/my_awing_arch.py +++ b/src/face3d/util/my_awing_arch.py @@ -15,7 +15,7 @@ def calculate_points(heatmaps): indexes = np.argmax(heatline, axis=2) preds = np.stack((indexes % W, indexes // W), axis=2) - preds = preds.astype(np.float, copy=False) + preds = preds.astype(float, copy=False) inr = indexes.ravel() diff --git a/src/face3d/util/util.py b/src/face3d/util/util.py index 0d689ca1..975bf2c9 100644 --- a/src/face3d/util/util.py +++ b/src/face3d/util/util.py @@ -120,7 +120,7 @@ def print_numpy(x, val=True, shp=False): val (bool) -- if print the values of the numpy array shp (bool) -- if print the shape of the numpy array """ - x = 
x.astype(np.float64) + x = x.astype(np.float64) if shp: print('shape,', x.shape) if val: diff --git a/src/facerender/modules/generator.py b/src/facerender/modules/generator.py index 5a9edcb3..345b23e9 100644 --- a/src/facerender/modules/generator.py +++ b/src/facerender/modules/generator.py @@ -252,4 +252,4 @@ def forward(self, source_image, kp_driving, kp_source): output_dict["prediction"] = out - return output_dict \ No newline at end of file + return output_dict diff --git a/src/facerender/modules/make_animation.py b/src/facerender/modules/make_animation.py index 3360c535..4808253b 100644 --- a/src/facerender/modules/make_animation.py +++ b/src/facerender/modules/make_animation.py @@ -167,4 +167,4 @@ def forward(self, x): self.mapping, use_exp = True, yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq) - return predictions_video \ No newline at end of file + return predictions_video diff --git a/src/facerender/modules/mapping.py b/src/facerender/modules/mapping.py index 0e3a1c2d..728d4b0f 100644 --- a/src/facerender/modules/mapping.py +++ b/src/facerender/modules/mapping.py @@ -44,4 +44,4 @@ def forward(self, input_3dmm): t = self.fc_t(out) exp = self.fc_exp(out) - return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} \ No newline at end of file + return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} diff --git a/src/facerender/modules/util.py b/src/facerender/modules/util.py index b916deef..db2cbb1b 100644 --- a/src/facerender/modules/util.py +++ b/src/facerender/modules/util.py @@ -561,4 +561,4 @@ def forward(self, source_image, target_audio): kp_source = self.keypoint_transformation(kp_canonical, pose_source) kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated) generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated) - return generated \ No newline at end of file + return generated diff --git a/src/gradio_demo.py b/src/gradio_demo.py index 
9c9ae056..59fc3805 100644 --- a/src/gradio_demo.py +++ b/src/gradio_demo.py @@ -152,4 +152,4 @@ def test(self, source_image, driven_audio, preprocess='crop', return return_path - \ No newline at end of file + diff --git a/src/utils/init_path.py b/src/utils/init_path.py index 5f38d119..62ecc1b7 100644 --- a/src/utils/init_path.py +++ b/src/utils/init_path.py @@ -44,4 +44,4 @@ def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preproces sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') - return sadtalker_paths \ No newline at end of file + return sadtalker_paths diff --git a/src/utils/model2safetensor.py b/src/utils/model2safetensor.py index 50c48500..ccf2802c 100644 --- a/src/utils/model2safetensor.py +++ b/src/utils/model2safetensor.py @@ -138,4 +138,4 @@ def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon): save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors") ### test -load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None) \ No newline at end of file +load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None) diff --git a/src/utils/safetensor_helper.py b/src/utils/safetensor_helper.py index 3cdbdd21..9959a330 100644 --- a/src/utils/safetensor_helper.py +++ b/src/utils/safetensor_helper.py @@ -5,4 +5,4 @@ def load_x_from_safetensor(checkpoint, key): for k,v in checkpoint.items(): if key in k: x_generator[k.replace(key+'.', '')] = v - return x_generator \ No newline at end of file + return x_generator diff --git a/src/utils/videoio.py b/src/utils/videoio.py index 08bfbdd7..98cf746a 100644 --- a/src/utils/videoio.py +++ b/src/utils/videoio.py @@ -38,4 +38,4 @@ def 
save_video_with_watermark(video, audio, save_path, watermark=False): cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path) os.system(cmd) - os.remove(temp_file) \ No newline at end of file + os.remove(temp_file)