Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,5 @@ Dockerfile
start_docker.sh
start.sh

checkpoints

# Mac
.DS_Store
258 changes: 166 additions & 92 deletions app_sadtalker.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,185 @@
import os, sys
import os
import sys
import gradio as gr
from src.gradio_demo import SadTalker
import logging
from src.gradio_demo import SadTalker

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Check if running in webui environment
try:
import webui # in webui
import webui
in_webui = True
except:
except ImportError:
in_webui = False

def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):
"""Creates a Gradio interface for SadTalker."""
# Validate paths
try:
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint path {checkpoint_path} does not exist.")
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config path {config_path} does not exist.")
logger.info(f"Checkpoint path: {checkpoint_path}")
logger.info(f"Config path: {config_path}")
except FileNotFoundError as e:
logger.error(str(e))
raise

def toggle_audio_file(choice):
if choice == False:
return gr.update(visible=True), gr.update(visible=False)
else:
return gr.update(visible=False), gr.update(visible=True)

def ref_video_fn(path_of_ref_video):
if path_of_ref_video is not None:
return gr.update(value=True)
else:
return gr.update(value=False)
# Validate checkpoint files
required_checkpoints = [
'SadTalker_V0.0.2_256.safetensors', # Updated to match logs
'mapping_00229-model.pth.tar'
]
for chk in required_checkpoints:
chk_path = os.path.join(checkpoint_path, chk)
if not os.path.exists(chk_path):
logger.error(f"Missing checkpoint file: {chk_path}")
raise FileNotFoundError(f"Missing checkpoint file: {chk_path}")

def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):
# Initialize SadTalker
try:
sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)
logger.info("SadTalker initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize SadTalker: {str(e)}")
raise

sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)
with gr.Blocks() as sadtalker_interface:
gr.Markdown(
"""
<div align='center'>
<h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </h2>
<a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a>   
<a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a>   
<a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'>Github</a>
</div>
"""
)

with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
<a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
<a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
<a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </div>")

with gr.Row().style(equal_height=False):
with gr.Row():
# Input Column
with gr.Column(variant='panel'):
with gr.Tabs(elem_id="sadtalker_source_image"):
with gr.TabItem('Upload image'):
with gr.Row():
source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
gr.Markdown("### Upload Inputs")
source_image = gr.File(
label="Source Image (PNG/JPG, ideally 512x512)",
file_types=[".png", ".jpg", ".jpeg"],
interactive=True
)
image_preview = gr.Image(label="Image Preview", interactive=False)
image_status = gr.Textbox(label="Image Upload Status", interactive=False)
driven_audio = gr.Audio(
label="Input Audio (WAV/MP3)",
type="filepath",
interactive=True
)
audio_status = gr.Textbox(label="Audio Upload Status", interactive=False)
error_message = gr.Textbox(label="Error Message", interactive=False, visible=False)

with gr.Tabs(elem_id="sadtalker_driven_audio"):
with gr.TabItem('Upload OR TTS'):
with gr.Column(variant='panel'):
driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
# Log upload events
source_image.upload(
fn=lambda x: (f"Image uploaded: {x}", x),
inputs=[source_image],
outputs=[image_status, image_preview]
)
driven_audio.upload(
fn=lambda x: f"Audio uploaded: {x}",
inputs=[driven_audio],
outputs=[audio_status]
)

if sys.platform != 'win32' and not in_webui:
from src.utils.text2speech import TTSTalker
tts_talker = TTSTalker()
with gr.Column(variant='panel'):
input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS.")
tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])

with gr.Column(variant='panel'):
with gr.Tabs(elem_id="sadtalker_checkbox"):
with gr.TabItem('Settings'):
gr.Markdown("need help? please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more details")
with gr.Column(variant='panel'):
# width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
# height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) #
size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") #
preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2)
enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')

with gr.Tabs(elem_id="sadtalker_genearted"):
gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)

if warpfn:
submit.click(
fn=warpfn(sad_talker.test),
inputs=[source_image,
driven_audio,
preprocess_type,
is_still_mode,
enhancer,
batch_size,
size_of_image,
pose_style
],
outputs=[gen_video]
)
else:
submit.click(
fn=sad_talker.test,
inputs=[source_image,
driven_audio,
preprocess_type,
is_still_mode,
enhancer,
batch_size,
size_of_image,
pose_style
],
outputs=[gen_video]
)
# Settings and Output Column
with gr.Column(variant='panel'):
gr.Markdown("### Settings and Output")
pose_style = gr.Slider(
minimum=0,
maximum=46,
step=1,
label="Pose Style",
value=0
)
size_of_image = gr.Radio(
[256, 512],
value=256,
label="Face Model Resolution",
info="Use 256/512 model?"
)
preprocess_type = gr.Radio(
['crop', 'resize', 'full', 'extcrop', 'extfull'],
value='crop',
label="Preprocess",
info="How to handle input image?"
)
is_still_mode = gr.Checkbox(
label="Still Mode (fewer head motions, works with preprocess 'full')"
)
batch_size = gr.Slider(
label="Batch Size in Generation",
step=1,
minimum=1,
maximum=10,
value=2
)
enhancer = gr.Checkbox(label="GFPGAN as Face Enhancer")
submit = gr.Button('Generate', variant='primary')
gen_video = gr.Video(
label="Generated Video (MP4)",
interactive=False
)

return sadtalker_interface

# Submit button logic with input validation
def validate_and_run(source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image, pose_style):
try:
logger.info(f"Received inputs: image={source_image}, audio={driven_audio}")
# Handle gr.File and gr.Audio inputs
image_path = source_image.name if hasattr(source_image, 'name') else source_image
audio_path = driven_audio if driven_audio else None
logger.info(f"Processed paths: image={image_path}, audio={audio_path}")
if not image_path or not os.path.exists(image_path):
logger.error("Invalid or missing source image")
return None, "Error: Please upload a valid source image."
if not audio_path or not os.path.exists(audio_path):
logger.error("Invalid or missing audio file")
return None, "Error: Please upload a valid audio file."
if not isinstance(audio_path, (str, bytes, os.PathLike)):
logger.error(f"Invalid audio path type: {type(audio_path)}")
return None, f"Error: Audio input must be a file path, not {type(audio_path)}."
logger.info("Inputs validated successfully")
result = (warpfn(sad_talker.test) if warpfn else sad_talker.test)(
image_path, audio_path, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image, pose_style
)
logger.info("Video generation completed")
return result, None
except Exception as e:
logger.error(f"Error during video generation: {str(e)}")
return None, f"Error: {str(e)}"

if __name__ == "__main__":
submit.click(
fn=validate_and_run,
inputs=[
source_image,
driven_audio,
preprocess_type,
is_still_mode,
enhancer,
batch_size,
size_of_image,
pose_style
],
outputs=[gen_video, error_message]
)

demo = sadtalker_demo()
demo.queue()
demo.launch()
return sadtalker_interface


if __name__ == "__main__":
try:
demo = sadtalker_demo()
demo.queue()
demo.launch()
except Exception as e:
logger.error(f"Failed to launch Gradio interface: {str(e)}")
raise
2 changes: 1 addition & 1 deletion launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,4 @@ def start():

if __name__ == "__main__":
prepare_environment()
start()
start()
2 changes: 1 addition & 1 deletion src/audio2pose_models/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ def forward(self, x):
x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
x = self.seq(x)
x = x.squeeze(1)
return x
return x
2 changes: 1 addition & 1 deletion src/audio2pose_models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,4 @@ def __init__(self, input_encoder, input_decoder, output_dim):
def forward(self, x1, x2):
out = self.conv_encoder(x1) + self.conv_decoder(x2)
out = self.conv_attn(out)
return out * x2
return out * x2
2 changes: 1 addition & 1 deletion src/audio2pose_models/res_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def forward(self, x):

output = self.output_layer(x10)

return output
return output
2 changes: 1 addition & 1 deletion src/face3d/models/arcface_torch/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ def get_model(name, **kwargs):
num_features = kwargs.get("num_features", 512)
return get_mbf(fp16=fp16, num_features=num_features)
else:
raise ValueError()
raise ValueError()
2 changes: 1 addition & 1 deletion src/face3d/models/arcface_torch/backbones/mobilefacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ def forward(self, x):


def get_mbf(fp16, num_features):
return MobileFaceNet(fp16, num_features)
return MobileFaceNet(fp16, num_features)
2 changes: 1 addition & 1 deletion src/face3d/models/arcface_torch/torch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def convert_onnx(net, path_module, output, opset=11, simplify=False):
assert isinstance(net, torch.nn.Module)
img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
img = img.astype(np.float)
img = img.astype(float)
img = (img / 255. - 0.5) / 0.5 # torch style norm
img = img.transpose((2, 0, 1))
img = torch.from_numpy(img).unsqueeze(0).float()
Expand Down
2 changes: 1 addition & 1 deletion src/face3d/models/arcface_torch/utils/utils_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ def get_config(config_file):
cfg.update(job_cfg)
if cfg.output is None:
cfg.output = osp.join('work_dirs', temp_module_name)
return cfg
return cfg
2 changes: 1 addition & 1 deletion src/face3d/models/bfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,4 +328,4 @@ def compute_for_render_woRotation(self, coeffs):


if __name__ == '__main__':
transferBFM09()
transferBFM09()
2 changes: 1 addition & 1 deletion src/face3d/util/load_mats.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,4 @@ def load_lm3d(bfm_folder):


if __name__ == '__main__':
transferBFM09()
transferBFM09()
2 changes: 1 addition & 1 deletion src/face3d/util/my_awing_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def calculate_points(heatmaps):
indexes = np.argmax(heatline, axis=2)

preds = np.stack((indexes % W, indexes // W), axis=2)
preds = preds.astype(np.float, copy=False)
preds = preds.astype(float, copy=False)

inr = indexes.ravel()

Expand Down
2 changes: 1 addition & 1 deletion src/face3d/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def print_numpy(x, val=True, shp=False):
val (bool) -- if print the values of the numpy array
shp (bool) -- if print the shape of the numpy array
"""
x = x.astype(np.float64)
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
Expand Down
2 changes: 1 addition & 1 deletion src/facerender/modules/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,4 @@ def forward(self, source_image, kp_driving, kp_source):

output_dict["prediction"] = out

return output_dict
return output_dict
2 changes: 1 addition & 1 deletion src/facerender/modules/make_animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,4 @@ def forward(self, x):
self.mapping, use_exp = True,
yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq)

return predictions_video
return predictions_video
2 changes: 1 addition & 1 deletion src/facerender/modules/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def forward(self, input_3dmm):
t = self.fc_t(out)
exp = self.fc_exp(out)

return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
Loading