More options of input/output types in coord_check #80

@francois-rozet

Hello 👋

I am having issues using mup.get_coord_data because some of my modules return dataclass objects. Currently, only dicts, lists, tuples and tensors are supported. It would be great, and fairly easy, to also support dataclasses.

I think that the only code to modify would be

mup/mup/coord_check.py

Lines 129 to 148 in 1981497

    def get_stat(d, x, fdict):
        if isinstance(x, (tuple, list)):
            for i, _x in enumerate(x):
                _d = copy(d)
                _d['module'] += f'[{i}]'
                get_stat(_d, _x, fdict)
        elif isinstance(x, dict):
            for name, _x in x.items():
                _d = copy(d)
                _d['module'] += f'[{name}]'
                get_stat(_d, _x, fdict)
        elif isinstance(x, torch.Tensor):
            _d = copy(d)
            for fname, f in fdict.items():
                _d[fname] = f(x).item()
            records.append(_d)
        elif x is None:
            pass
        else:
            raise NotImplementedError(f'Unexpected output type: {type(x)}')

I can do a PR for that.
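
For illustration, here is a rough sketch of what the extra dataclass branch could look like. It is not the actual patch. To keep it runnable without torch, `records` and a hypothetical `is_leaf` predicate are passed in explicitly; in coord_check.py the leaf test would stay `isinstance(x, torch.Tensor)` with `f(x).item()`, and `records` would still come from the enclosing scope. The `.field_name` suffix and the `Output` class are also just assumptions for the example.

```python
from copy import copy
from dataclasses import dataclass, fields, is_dataclass


def get_stat(d, x, fdict, records, is_leaf):
    """Sketch of get_stat with a dataclass branch (records/is_leaf are stand-ins)."""
    if is_dataclass(x) and not isinstance(x, type):
        # Read fields one by one (shallow); dataclasses.asdict would
        # deep-copy the field values, which is wasteful for tensors.
        for field in fields(x):
            _d = copy(d)
            _d['module'] += f'.{field.name}'
            get_stat(_d, getattr(x, field.name), fdict, records, is_leaf)
    elif isinstance(x, (tuple, list)):
        for i, _x in enumerate(x):
            _d = copy(d)
            _d['module'] += f'[{i}]'
            get_stat(_d, _x, fdict, records, is_leaf)
    elif isinstance(x, dict):
        for name, _x in x.items():
            _d = copy(d)
            _d['module'] += f'[{name}]'
            get_stat(_d, _x, fdict, records, is_leaf)
    elif is_leaf(x):
        # Stand-in for the torch.Tensor branch of the original.
        _d = copy(d)
        for fname, f in fdict.items():
            _d[fname] = f(x)
        records.append(_d)
    elif x is None:
        pass
    else:
        raise NotImplementedError(f'Unexpected output type: {type(x)}')


@dataclass
class Output:  # hypothetical module output
    logits: float
    hidden: list


records = []
get_stat(
    {'module': 'net'},
    Output(logits=-2.0, hidden=[1.0, None]),
    {'l1': abs},
    records,
    lambda v: isinstance(v, float),
)
print(records)
```

Each dataclass field ends up as its own record, named after the field, and nested containers inside a field are handled by the existing branches.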
