[PyTorch] How to clear the module output history when using register_forward_hook?

I load a pretrained PyTorch ResNet and then use the following code to capture the hidden features.

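import torch
import torchvision.models as models

model = models.resnet18(pretrained=True)  # the pretrained ResNet
x = torch.randn(1, 3, 224, 224)           # placeholder input batch

feat_out = []

def hook_fn_forward(module, input, output):
    # Runs after the forward pass of each hooked child module
    feat_out.append(output)
    print(output)

modules = model.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

pred = model(x)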

But when I run this code the first time, len(feat_out) is 10 and the print in the hook function fires 10 times. If I run the code again, len(feat_out) is 20 and the print fires 20 times. Every run adds one more output per module, so the count keeps growing by 10. It looks as if the result is the output of the current run plus all past outputs. The history is cleared only if I re-initialize the model before running the code.
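To see why the count grows, here is a minimal sketch, assuming a single nn.Linear layer stands in for the ResNet and a loop stands in for re-running the Colab cell. The list starts empty on every run, but each run attaches one more hook to the same module:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # stand-in for one child module of the ResNet
x = torch.randn(1, 4)
lengths = []

for run in range(1, 4):
    # Re-running the cell resets the list...
    feat_out = []

    def hook_fn_forward(module, input, output):
        feat_out.append(output)

    # ...but every run registers one more hook on the same module
    layer.register_forward_hook(hook_fn_forward)
    layer(x)
    lengths.append(len(feat_out))

print(lengths)  # [1, 2, 3]
```

With the 10 child modules of a torchvision ResNet, the same pattern gives 10, 20, 30, ... entries, which matches the numbers above.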

How can I clear the output every time I run the model?

I reproduced this problem in Colab with a minimal example (5 lines to load the data, 2 lines to initialize the model, 8 lines for the problem itself).
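One way to get a fresh list on every run is to register the hooks only once and clear the list at the start of each forward pass, instead of re-running the registration code. A minimal sketch, assuming a small nn.Sequential stands in for the ResNet; `run` is a hypothetical helper:

```python
import torch
import torch.nn as nn

# Stand-in for the ResNet: any module with named children works the same way
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
feat_out = []

def hook_fn_forward(module, input, output):
    feat_out.append(output)

# Register the hooks exactly once; keep the handles so they can be removed later
handles = [m.register_forward_hook(hook_fn_forward) for _, m in model.named_children()]

def run(x):
    feat_out.clear()  # drop the outputs of the previous forward pass
    with torch.no_grad():
        return model(x)

x = torch.randn(1, 4)
run(x)
run(x)
print(len(feat_out))  # 3
```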

I worked around this problem by storing the outputs in a dict instead of a list, so each run overwrites the previous values.

import numpy as np
import torch
import torchvision.models as models

model_ft = models.resnet18(pretrained=True)

feature_out = {}
layers_name = list(model_ft._modules.keys())
layers = list(model_ft._modules.values())


def hook_fn_forward(module, input, output):
    # Look up this module's name, then overwrite its entry instead of appending
    layer = layers_name[np.argwhere([module == m for m in layers])[0, 0]]
    feature_out[layer] = output


modules = model_ft.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

model_ft.eval()
with torch.no_grad():
    pred = model_ft(dat)  # dat: the input batch
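A variant of the same dict idea avoids the `np.argwhere` reverse lookup by binding each layer name in a closure when the hook is registered. This is a sketch: `make_hook` is a hypothetical helper, and a small named nn.Sequential stands in for resnet18:

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Stand-in for resnet18; named_children() behaves the same way
model_ft = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(4, 8)),
    ("relu", nn.ReLU()),
    ("fc2", nn.Linear(8, 2)),
]))
feature_out = {}

def make_hook(name):
    # The layer name is captured here, so the hook needs no lookup over the modules
    def hook_fn_forward(module, input, output):
        feature_out[name] = output
    return hook_fn_forward

handles = [module.register_forward_hook(make_hook(name))
           for name, module in model_ft.named_children()]

model_ft.eval()
with torch.no_grad():
    pred = model_ft(torch.randn(1, 4))

for handle in handles:
    handle.remove()  # detach the hooks so a re-run starts clean

print({name: tuple(t.shape) for name, t in feature_out.items()})
# {'fc1': (1, 8), 'relu': (1, 8), 'fc2': (1, 2)}
```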

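I also found that the underlying cause was that I never removed the hooks, so each run registered another copy of the hook on every module. Keeping the handles and removing them after the forward pass fixes it: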

total_feat_out = []

def hook_fn_forward(module, input, output):
    total_feat_out.append(output)
    print(output.shape)

# Keep the handle returned by each registration so the hook can be removed later
modules = model_ft.named_children()
handles = {}
for name, module in modules:
    handles[name] = module.register_forward_hook(hook_fn_forward)

model_ft.eval()
with torch.no_grad():
    pred = model_ft(dat)

# Remove the hooks so that re-running this code does not stack duplicates
for handle in handles.values():
    handle.remove()
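To make sure the hooks are removed even if the forward pass raises an exception, the registration and removal can be wrapped in a context manager. A minimal sketch, where `capture_outputs` is a hypothetical helper and a small nn.Sequential stands in for the ResNet:

```python
from contextlib import contextmanager

import torch
import torch.nn as nn

@contextmanager
def capture_outputs(model):
    """Hypothetical helper: hook every child module, yield the collected
    outputs, and always remove the hooks on exit."""
    outputs = []

    def hook_fn_forward(module, input, output):
        outputs.append(output)

    handles = [m.register_forward_hook(hook_fn_forward) for m in model.children()]
    try:
        yield outputs
    finally:
        for handle in handles:
            handle.remove()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(1, 4)

for _ in range(3):
    with torch.no_grad(), capture_outputs(model) as feats:
        pred = model(x)
    print(len(feats))  # 3 on every run
```

The `finally` clause is the design point here: the hooks are detached whether the block exits normally or through an exception, so a failed run cannot leave stray hooks behind.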