Add Apex import Guard for MM collection #9099

Closed · wants to merge 4 commits
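The diff applies a standard optional-dependency pattern to each touched module in the multimodal (MM) collection: the Apex import is wrapped in a try/except that records availability in a module-level HAVE_APEX flag, and classes that actually need Apex raise an ImportError at construction time instead of breaking the import of the whole module. A minimal sketch of the pattern follows; the module and class names are illustrative, not part of this PR:

import torch.nn as nn

try:
    # Optional dependency: Apex provides a CUDA-optimized GroupNorm.
    from apex.contrib.group_norm import GroupNorm

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


class MyBlock(nn.Module):  # illustrative class, not from this PR
    def __init__(self, channels):
        # Fail only when the Apex-backed feature is actually requested.
        if not HAVE_APEX:
            raise ImportError(
                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
            )
        super().__init__()
        self.norm = GroupNorm(num_groups=32, num_channels=channels)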
@@ -19,7 +19,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+
+try:
+    from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+
+    HAVE_APEX = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
+
 from megatron.core import parallel_state
 from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
 from omegaconf.dictconfig import DictConfig

@@ -43,8 +43,13 @@
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.contrib.group_norm import GroupNorm
 
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    HAVE_APEX = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
 
 def conv_nd(dims, *args, **kwargs):
     """

@@ -16,7 +16,14 @@
 
 import torch
 import torch.nn.functional as F
-from apex.contrib.group_norm import GroupNorm
+
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    HAVE_APEX = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
+
 from einops import rearrange, repeat
 from torch import einsum, nn
 from torch._dynamo import disable
@@ -142,6 +149,11 @@ def forward(self, x):
 
 class SpatialSelfAttention(nn.Module):
     def __init__(self, in_channels):
+        if not HAVE_APEX:
+            raise ImportError(
+                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
+            )
+
         super().__init__()
         self.in_channels = in_channels
 
@@ -432,6 +444,11 @@ def __init__(
         use_flash_attention=False,
         lora_network_alpha=None,
     ):
+        if not HAVE_APEX:
+            raise ImportError(
+                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
+            )
+
         super().__init__()
         logging.info(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"

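With the guard in place, importing the attention module no longer fails in an Apex-free environment; the error surfaces only when an Apex-dependent class is constructed. Roughly, assuming Apex is not installed (illustrative snippet, not output from this PR's CI):

# Importing the guarded module now succeeds without Apex.
from nemo.collections.multimodal.modules.stable_diffusion import attention

try:
    attention.SpatialSelfAttention(in_channels=64)
except ImportError as err:
    # "Apex was not found. Please see the NeMo README for installation instructions: ..."
    print(err)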
@@ -17,7 +17,14 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.contrib.group_norm import GroupNorm
+
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    HAVE_APEX = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
+
 from einops import rearrange
 
 from nemo.collections.multimodal.modules.stable_diffusion.attention import LinearAttention

@@ -22,7 +22,13 @@
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.contrib.group_norm import GroupNorm
+
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    HAVE_APEX = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
 
 from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer
 from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import (
@@ -554,6 +560,11 @@ def __init__(
         lora_network_alpha=None,
         timesteps=1000,
     ):
+        if not HAVE_APEX:
+            raise ImportError(
+                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
+            )
+
         super().__init__()
         from omegaconf.listconfig import ListConfig
 
@@ -1229,6 +1240,11 @@ def __init__(
         *args,
         **kwargs,
     ):
+        if not HAVE_APEX:
+            raise ImportError(
+                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
+            )
+
         super().__init__()
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads

@@ -29,7 +29,14 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.contrib.group_norm import GroupNorm
+
+try:
+    from apex.contrib.group_norm import GroupNorm
+
+    HAVE_APEX = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
+
 from einops import repeat
 from torch._dynamo import disable
 from torch.cuda.amp import custom_bwd, custom_fwd
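One way to exercise these guards without uninstalling Apex is to hide the package from the import machinery and re-import a guarded module. The test below is a hypothetical sketch along those lines, not a test added by this PR:

import builtins
import importlib
import sys


def test_guarded_import_without_apex(monkeypatch):
    # Hypothetical pytest: simulate a missing apex package and re-import a guarded module.
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name == "apex" or name.startswith("apex."):
            raise ImportError(f"simulating missing package: {name}")
        return real_import(name, *args, **kwargs)

    # Drop any cached apex modules so the fake importer is actually consulted.
    for mod in [m for m in sys.modules if m == "apex" or m.startswith("apex.")]:
        monkeypatch.delitem(sys.modules, mod)
    monkeypatch.setattr(builtins, "__import__", fake_import)

    from nemo.collections.multimodal.modules.stable_diffusion import attention

    attention = importlib.reload(attention)
    assert attention.HAVE_APEX is False  # the guard caught the (simulated) missing package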