diff --git a/decalib/deca.py b/decalib/deca.py index 0a0c0d32..474b3ff2 100644 --- a/decalib/deca.py +++ b/decalib/deca.py @@ -51,7 +51,9 @@ def _setup_renderer(self, model_cfg): self.render = SRenderY(self.image_size, obj_filename=model_cfg.topology_path, uv_size=model_cfg.uv_size).to(self.device) # face mask for rendering details mask = imread(model_cfg.face_eye_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous() - self.uv_face_eye_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device) + np_eye_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).detach().numpy() + np_eye_mask[np_eye_mask > 0] = 1 + self.uv_face_eye_mask = torch.from_numpy(np_eye_mask).cuda() mask = imread(model_cfg.face_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous() self.uv_face_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device) # displacement correction @@ -193,9 +195,11 @@ def decode(self, codedict): uv_gt = F.grid_sample(images, uv_pverts.permute(0,2,3,1)[:,:,:,:2], mode='bilinear') if self.cfg.model.use_tex: ## TODO: poisson blending should give better-looking results - uv_texture_gt = uv_gt[:,:3,:,:]*self.uv_face_eye_mask + (uv_texture[:,:3,:,:]*(1-self.uv_face_eye_mask)*0.7) + uv_texture_gt = uv_gt[:, :3, :, :] * self.uv_face_eye_mask else: - uv_texture_gt = uv_gt[:,:3,:,:]*self.uv_face_eye_mask + (torch.ones_like(uv_gt[:,:3,:,:])*(1-self.uv_face_eye_mask)*0.7) + uv_texture_gt = uv_gt[:, :3, :, :] * self.uv_face_eye_mask + + texture_iamge = F.grid_sample(uv_gt, ops['grid'], align_corners=False) ## output opdict = { @@ -217,7 +221,8 @@ def decode(self, codedict): 'landmarks2d': util.tensor_vis_landmarks(images, landmarks2d, isScale=False), 'landmarks3d': util.tensor_vis_landmarks(images, landmarks3d, isScale=False), 'shape_images': shape_images, - 'shape_detail_images': shape_detail_images + 'shape_detail_images': shape_detail_images, + 'texture_iamge': texture_iamge } if self.cfg.model.use_tex: visdict['rendered_images'] = ops['images']